virtio_transport_common.c 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089
  1. /*
  2. * common code for virtio vsock
  3. *
  4. * Copyright (C) 2013-2015 Red Hat, Inc.
  5. * Author: Asias He <asias@redhat.com>
  6. * Stefan Hajnoczi <stefanha@redhat.com>
  7. *
  8. * This work is licensed under the terms of the GNU GPL, version 2.
  9. */
  10. #include <linux/spinlock.h>
  11. #include <linux/module.h>
  12. #include <linux/sched/signal.h>
  13. #include <linux/ctype.h>
  14. #include <linux/list.h>
  15. #include <linux/virtio.h>
  16. #include <linux/virtio_ids.h>
  17. #include <linux/virtio_config.h>
  18. #include <linux/virtio_vsock.h>
  19. #include <uapi/linux/vsockmon.h>
  20. #include <net/sock.h>
  21. #include <net/af_vsock.h>
  22. #define CREATE_TRACE_POINTS
  23. #include <trace/events/vsock_virtio_transport_common.h>
  24. /* How long to wait for graceful shutdown of a connection */
  25. #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
/* Return the currently registered transport's virtio-specific ops.
 * container_of() is only valid if the registered core transport really is
 * embedded in a struct virtio_transport, which is the case for all users
 * of this file.
 */
static const struct virtio_transport *virtio_transport_get_ops(void)
{
	const struct vsock_transport *t = vsock_core_get_transport();

	return container_of(t, struct virtio_transport, transport);
}
/* Allocate a packet with a fully populated (little-endian) header for the
 * given source/destination address pair.
 *
 * If info->msg is set and len > 0, a payload buffer of len bytes is
 * allocated and filled from the message.  Returns NULL on allocation or
 * copy failure; both error paths release everything allocated so far.
 */
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_pkt *pkt;
	int err;

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	/* The wire format is little-endian regardless of host endianness. */
	pkt->hdr.type = cpu_to_le16(info->type);
	pkt->hdr.op = cpu_to_le16(info->op);
	pkt->hdr.src_cid = cpu_to_le64(src_cid);
	pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
	pkt->hdr.src_port = cpu_to_le32(src_port);
	pkt->hdr.dst_port = cpu_to_le32(dst_port);
	pkt->hdr.flags = cpu_to_le32(info->flags);
	pkt->len = len;
	pkt->hdr.len = cpu_to_le32(len);
	pkt->reply = info->reply;
	pkt->vsk = info->vsk;

	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;
		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}
/* Packet capture: build an AF_VSOCK monitoring skb (vsockmon format) from a
 * virtio vsock packet.  Layout: af_vsockmon_hdr, raw virtio header, then the
 * remaining payload.  Called from tap delivery; may run in atomic context,
 * hence GFP_ATOMIC.  Returns NULL if the skb cannot be allocated.
 */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_pkt *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so we can retrieve
	 * the payload length from the header and the buffer pointer taking
	 * care of the offset in the original packet.
	 */
	payload_len = le32_to_cpu(pkt->hdr.len);
	payload_buf = pkt->buf + pkt->off;

	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt->hdr.src_cid;
	hdr->src_port = pkt->hdr.src_port;
	hdr->dst_cid = pkt->hdr.dst_cid;
	hdr->dst_port = pkt->hdr.dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	/* Map virtio vsock ops onto the coarser vsockmon op categories. */
	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

	if (payload_len) {
		skb_put_data(skb, payload_buf, payload_len);
	}

	return skb;
}
/* Hand a packet to any attached vsock taps (packet capture). The skb is
 * built lazily by virtio_transport_build_skb() only if a tap is listening.
 */
void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
	vsock_deliver_tap(virtio_transport_build_skb, pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
/* Build a packet described by @info and submit it via the transport's
 * send_pkt() op.
 *
 * Destination defaults to vsk's remote address unless info->remote_cid is
 * set.  The requested length is capped at VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE
 * and then reduced to the credit actually available, so fewer bytes than
 * requested may be sent.  Returns bytes queued (>= 0) or a negative errno.
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	src_cid = vm_sockets_get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		/* Return the credit reserved above, it was never used. */
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return virtio_transport_get_ops()->send_pkt(pkt);
}
/* Account a packet just queued on the rx queue; caller holds vvs->rx_lock. */
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes += pkt->len;
}

/* Undo rx accounting for a fully consumed packet and bump fwd_cnt, which
 * advertises freed receive-buffer space back to the peer (credit
 * mechanism).  Caller holds vvs->rx_lock.
 */
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}
/* Stamp an outgoing packet with our current credit state (fwd_cnt and
 * buf_alloc) so the peer can update its view of our receive space.
 */
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->tx_lock);
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
  188. u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
  189. {
  190. u32 ret;
  191. spin_lock_bh(&vvs->tx_lock);
  192. ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
  193. if (ret > credit)
  194. ret = credit;
  195. vvs->tx_cnt += ret;
  196. spin_unlock_bh(&vvs->tx_lock);
  197. return ret;
  198. }
  199. EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
/* Release tx credit previously reserved with virtio_transport_get_credit()
 * but not consumed by an actual send.
 */
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
/* Send a CREDIT_UPDATE control packet so the peer learns how much receive
 * space we have freed.  The @hdr parameter is currently unused; callers in
 * this file pass NULL.
 */
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
					       int type,
					       struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
/* Copy up to @len bytes of queued rx data into @msg.
 *
 * Fully consumed packets are unlinked, their bytes credited back to the
 * peer (fwd_cnt), and freed; a credit update packet is sent afterwards.
 * Returns the number of bytes copied, or a negative errno only when the
 * very first copy to userspace failed.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		/* Copy no more than what this packet still holds. */
		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}
	spin_unlock_bh(&vvs->rx_lock);

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);

	return total;

out:
	/* Partial success wins over the copy error. */
	if (total)
		err = total;
	return err;
}
  260. ssize_t
  261. virtio_transport_stream_dequeue(struct vsock_sock *vsk,
  262. struct msghdr *msg,
  263. size_t len, int flags)
  264. {
  265. if (flags & MSG_PEEK)
  266. return -EOPNOTSUPP;
  267. return virtio_transport_stream_do_dequeue(vsk, msg, len);
  268. }
  269. EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
/* Number of payload bytes currently queued for reading on this socket. */
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
  288. static s64 virtio_transport_has_space(struct vsock_sock *vsk)
  289. {
  290. struct virtio_vsock_sock *vvs = vsk->trans;
  291. s64 bytes;
  292. bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
  293. if (bytes < 0)
  294. bytes = 0;
  295. return bytes;
  296. }
/* Locked wrapper around virtio_transport_has_space(). */
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
/* Allocate and initialise per-socket virtio transport state.
 *
 * Buffer-size parameters are inherited from the parent @psk when present
 * (accepted child sockets), otherwise defaults are used.  Returns 0 on
 * success or -ENOMEM.
 */
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->buf_size = ptrans->buf_size;
		vvs->buf_size_min = ptrans->buf_size_min;
		vvs->buf_size_max = ptrans->buf_size_max;
		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	} else {
		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
	}

	/* Advertised receive space starts out equal to the buffer size. */
	vvs->buf_alloc = vvs->buf_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	INIT_LIST_HEAD(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
/* Accessors for the per-socket buffer-size knobs (SO_VM_SOCKETS_BUFFER_*). */

u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
  352. void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
  353. {
  354. struct virtio_vsock_sock *vvs = vsk->trans;
  355. if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
  356. val = VIRTIO_VSOCK_MAX_BUF_SIZE;
  357. if (val < vvs->buf_size_min)
  358. vvs->buf_size_min = val;
  359. if (val > vvs->buf_size_max)
  360. vvs->buf_size_max = val;
  361. vvs->buf_size = val;
  362. vvs->buf_alloc = val;
  363. }
  364. EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
/* Raise the minimum buffer size (clamped to VIRTIO_VSOCK_MAX_BUF_SIZE),
 * pushing the current buffer size up if it falls below the new minimum.
 */
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val > vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
/* Lower the maximum buffer size (clamped to VIRTIO_VSOCK_MAX_BUF_SIZE),
 * pulling the current buffer size down if it exceeds the new maximum.
 */
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
  385. int
  386. virtio_transport_notify_poll_in(struct vsock_sock *vsk,
  387. size_t target,
  388. bool *data_ready_now)
  389. {
  390. if (vsock_stream_has_data(vsk))
  391. *data_ready_now = true;
  392. else
  393. *data_ready_now = false;
  394. return 0;
  395. }
  396. EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
  397. int
  398. virtio_transport_notify_poll_out(struct vsock_sock *vsk,
  399. size_t target,
  400. bool *space_avail_now)
  401. {
  402. s64 free_space;
  403. free_space = vsock_stream_has_space(vsk);
  404. if (free_space > 0)
  405. *space_avail_now = true;
  406. else if (free_space == 0)
  407. *space_avail_now = false;
  408. return 0;
  409. }
  410. EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
/* Receive-path notification hooks required by the vsock core.  The virtio
 * transport needs no bookkeeping here, so they are all no-ops.
 */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
/* Send-path notification hooks required by the vsock core; no-ops for the
 * virtio transport.
 */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
/* Receive high-water mark: equal to the configured buffer size. */
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
/* Stream sockets are always considered active on this transport. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

/* No CID/port restrictions for stream connections. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

/* Datagram sockets are not supported by the virtio transport. */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

/* Datagram traffic is never allowed on this transport. */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
/* Initiate a connection by sending an OP_REQUEST to the remote address. */
int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
  497. int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
  498. {
  499. struct virtio_vsock_pkt_info info = {
  500. .op = VIRTIO_VSOCK_OP_SHUTDOWN,
  501. .type = VIRTIO_VSOCK_TYPE_STREAM,
  502. .flags = (mode & RCV_SHUTDOWN ?
  503. VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
  504. (mode & SEND_SHUTDOWN ?
  505. VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
  506. .vsk = vsk,
  507. };
  508. return virtio_transport_send_pkt_info(vsk, &info);
  509. }
  510. EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
/* Queue up to @len bytes from @msg as an OP_RW packet.  May transmit less
 * than requested depending on available credit; returns bytes queued or a
 * negative errno.
 */
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
/* Free the per-socket transport state allocated in
 * virtio_transport_do_socket_init().
 */
void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);
/* Send an OP_RST for this socket.  @pkt is the packet being answered, or
 * NULL if the reset is locally generated; a reset is never sent in reply
 * to a reset to avoid RST ping-pong.
 */
static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}
  555. /* Normally packets are associated with a socket. There may be no socket if an
  556. * attempt was made to connect to a socket that does not exist.
  557. */
  558. static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
  559. struct virtio_vsock_pkt *pkt)
  560. {
  561. struct virtio_vsock_pkt *reply;
  562. struct virtio_vsock_pkt_info info = {
  563. .op = VIRTIO_VSOCK_OP_RST,
  564. .type = le16_to_cpu(pkt->hdr.type),
  565. .reply = true,
  566. };
  567. /* Send RST only if the original pkt is not a RST pkt */
  568. if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
  569. return 0;
  570. reply = virtio_transport_alloc_pkt(&info, 0,
  571. le64_to_cpu(pkt->hdr.dst_cid),
  572. le32_to_cpu(pkt->hdr.dst_port),
  573. le64_to_cpu(pkt->hdr.src_cid),
  574. le32_to_cpu(pkt->hdr.src_port));
  575. if (!reply)
  576. return -ENOMEM;
  577. if (!t) {
  578. virtio_transport_free_pkt(reply);
  579. return -ENOTCONN;
  580. }
  581. return t->send_pkt(reply);
  582. }
/* Sleep until the peer acknowledges the close (SOCK_DONE set), a signal
 * arrives, or @timeout expires.  Used to honour SO_LINGER.
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}
/* Mark the connection as fully shut down and, if a close timeout was
 * scheduled, complete the teardown (optionally cancelling the pending
 * delayed work first).  Caller holds the socket lock.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	/* Keep the socket readable until queued data is drained. */
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}
/* Delayed-work handler: the peer did not complete the graceful shutdown
 * within VSOCK_CLOSE_TIMEOUT, so force the connection closed with a RST.
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}
/* User context, vsk->sk is locked.
 *
 * Start (or complete) closing a connected socket.  Returns true if the
 * socket can be removed immediately, false if a graceful-shutdown timeout
 * was scheduled and removal happens later (in do_close or the timeout
 * handler), in which case an extra sock reference is held for the work.
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	/* Honour SO_LINGER unless the process is exiting. */
	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}
/* Release callback: initiate close for stream sockets, drop any packets
 * still queued for receive, and remove the socket unless a graceful-close
 * timeout is pending.
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt, *tmp;
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);

	/* Discard and free all undelivered rx packets. */
	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
/* Handle a packet received while this socket is in TCP_SYN_SENT
 * (connect in progress).  RESPONSE completes the handshake; RST or an
 * unexpected op tears the socket down and reports the error to userspace.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		/* Silently ignored. */
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
/* Handle a packet received on an established connection.
 *
 * Packet ownership: an OP_RW packet is queued on the socket's rx queue and
 * NOT freed here (the early return skips the final free); all other ops
 * fall through to virtio_transport_free_pkt().
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;

		spin_lock_bh(&vvs->rx_lock);
		virtio_transport_inc_rx_pkt(vvs, pkt);
		list_add_tail(&pkt->list, &vvs->rx_queue);
		spin_unlock_bh(&vvs->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Credit state was already absorbed by
		 * virtio_transport_space_update(); just wake writers.
		 */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		/* Both directions closed and no pending data: finish now. */
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0 &&
		    !sock_flag(sk, SOCK_DONE)) {
			(void)virtio_transport_reset(vsk, NULL);

			virtio_transport_do_close(vsk, true);
		}
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}
  749. static void
  750. virtio_transport_recv_disconnecting(struct sock *sk,
  751. struct virtio_vsock_pkt *pkt)
  752. {
  753. struct vsock_sock *vsk = vsock_sk(sk);
  754. if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
  755. virtio_transport_do_close(vsk, true);
  756. }
/* Send an OP_RESPONSE accepting the connection request carried by @pkt;
 * the reply is addressed to the request's source CID/port.
 */
static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
		.remote_port = le32_to_cpu(pkt->hdr.src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
/* Handle server socket: a packet arrived for a listening socket.
 *
 * Anything other than OP_REQUEST, or a full accept queue, is answered with
 * a reset.  Otherwise a child socket is created in TCP_ESTABLISHED state,
 * inserted into the connected table, queued for accept(), and an
 * OP_RESPONSE is sent back.  The caller still owns @pkt and frees it.
 */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset(vsk, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
			       sk->sk_type, 0);
	if (!child) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	sk->sk_ack_backlog++;

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	/* Child addresses come from the request packet, swapped. */
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}
/* Refresh our view of the peer's receive space from an incoming packet's
 * header and report whether any tx space is now available.
 */
static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}
  821. /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
  822. * lock.
  823. */
  824. void virtio_transport_recv_pkt(struct virtio_transport *t,
  825. struct virtio_vsock_pkt *pkt)
  826. {
  827. struct sockaddr_vm src, dst;
  828. struct vsock_sock *vsk;
  829. struct sock *sk;
  830. bool space_available;
  831. vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
  832. le32_to_cpu(pkt->hdr.src_port));
  833. vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
  834. le32_to_cpu(pkt->hdr.dst_port));
  835. trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
  836. dst.svm_cid, dst.svm_port,
  837. le32_to_cpu(pkt->hdr.len),
  838. le16_to_cpu(pkt->hdr.type),
  839. le16_to_cpu(pkt->hdr.op),
  840. le32_to_cpu(pkt->hdr.flags),
  841. le32_to_cpu(pkt->hdr.buf_alloc),
  842. le32_to_cpu(pkt->hdr.fwd_cnt));
  843. if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
  844. (void)virtio_transport_reset_no_sock(t, pkt);
  845. goto free_pkt;
  846. }
  847. /* The socket must be in connected or bound table
  848. * otherwise send reset back
  849. */
  850. sk = vsock_find_connected_socket(&src, &dst);
  851. if (!sk) {
  852. sk = vsock_find_bound_socket(&dst);
  853. if (!sk) {
  854. (void)virtio_transport_reset_no_sock(t, pkt);
  855. goto free_pkt;
  856. }
  857. }
  858. vsk = vsock_sk(sk);
  859. lock_sock(sk);
  860. space_available = virtio_transport_space_update(sk, pkt);
  861. /* Update CID in case it has changed after a transport reset event */
  862. vsk->local_addr.svm_cid = dst.svm_cid;
  863. if (space_available)
  864. sk->sk_write_space(sk);
  865. switch (sk->sk_state) {
  866. case TCP_LISTEN:
  867. virtio_transport_recv_listen(sk, pkt);
  868. virtio_transport_free_pkt(pkt);
  869. break;
  870. case TCP_SYN_SENT:
  871. virtio_transport_recv_connecting(sk, pkt);
  872. virtio_transport_free_pkt(pkt);
  873. break;
  874. case TCP_ESTABLISHED:
  875. virtio_transport_recv_connected(sk, pkt);
  876. break;
  877. case TCP_CLOSING:
  878. virtio_transport_recv_disconnecting(sk, pkt);
  879. virtio_transport_free_pkt(pkt);
  880. break;
  881. default:
  882. (void)virtio_transport_reset_no_sock(t, pkt);
  883. virtio_transport_free_pkt(pkt);
  884. break;
  885. }
  886. release_sock(sk);
  887. /* Release refcnt obtained when we fetched this socket out of the
  888. * bound or connected list.
  889. */
  890. sock_put(sk);
  891. return;
  892. free_pkt:
  893. virtio_transport_free_pkt(pkt);
  894. }
  895. EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
/* Free a packet and its payload buffer (buf may be NULL; kfree handles it). */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
  902. MODULE_LICENSE("GPL v2");
  903. MODULE_AUTHOR("Asias He");
  904. MODULE_DESCRIPTION("common code for virtio vsock");