/* net/xfrm/espintcp.c */
  1. // SPDX-License-Identifier: GPL-2.0
  2. #include <net/tcp.h>
  3. #include <net/strparser.h>
  4. #include <net/xfrm.h>
  5. #include <net/esp.h>
  6. #include <net/espintcp.h>
  7. #include <linux/skmsg.h>
  8. #include <net/inet_common.h>
  9. #include <trace/events/sock.h>
  10. #if IS_ENABLED(CONFIG_IPV6)
  11. #include <net/ipv6_stubs.h>
  12. #endif
  13. #include <net/hotdata.h>
  14. static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
  15. struct sock *sk)
  16. {
  17. if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
  18. !sk_rmem_schedule(sk, skb, skb->truesize)) {
  19. XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
  20. kfree_skb(skb);
  21. return;
  22. }
  23. skb_set_owner_r(skb, sk);
  24. memset(skb->cb, 0, sizeof(skb->cb));
  25. skb_queue_tail(&ctx->ike_queue, skb);
  26. ctx->saved_data_ready(sk);
  27. }
/* Hand one framed ESP packet to the xfrm receive path.
 *
 * Runs from strparser (process context); the xfrm input functions
 * expect BH context, hence the local_bh_disable() around the call.
 */
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;

	skb_reset_transport_header(skb);

	/* restore IP CB, we need at least IP6CB->nhoff */
	memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));

	rcu_read_lock();
	/* re-attach the incoming device; may be NULL if it went away */
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	else
#endif
		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}
/* strparser callback: dispatch one complete espintcp record.
 *
 * Record layout (RFC 8229): a 2-byte length prefix (already validated
 * by espintcp_parse()), then either a 4-byte all-zero non-ESP marker
 * followed by an IKE message, or an ESP packet whose SPI is nonzero.
 * A 1-byte 0xff payload is a NAT keepalive and is silently dropped.
 *
 * Consumes the skb in all cases.
 */
static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	int len = rxm->full_len - 2;	/* payload length, without the prefix */
	u32 nonesp_marker;
	int err;

	/* keepalive packet? */
	if (unlikely(len == 1)) {
		u8 data;

		err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
		if (err < 0) {
			XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
			kfree_skb(skb);
			return;
		}

		if (data == 0xff) {
			kfree_skb(skb);
			return;
		}
	}

	/* drop other short messages */
	if (unlikely(len <= sizeof(nonesp_marker))) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	/* peek at the first 4 payload bytes: zero means non-ESP (IKE) */
	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	/* remove header, leave non-ESP marker/SPI */
	if (!pskb_pull(skb, rxm->offset + 2)) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	/* trim any trailing bytes beyond this record */
	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}
  97. static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
  98. {
  99. struct strp_msg *rxm = strp_msg(skb);
  100. __be16 blen;
  101. u16 len;
  102. int err;
  103. if (skb->len < rxm->offset + 2)
  104. return 0;
  105. err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
  106. if (err < 0)
  107. return err;
  108. len = be16_to_cpu(blen);
  109. if (len < 2)
  110. return -EINVAL;
  111. return len;
  112. }
  113. static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
  114. int flags, int *addr_len)
  115. {
  116. struct espintcp_ctx *ctx = espintcp_getctx(sk);
  117. struct sk_buff *skb;
  118. int err = 0;
  119. int copied;
  120. int off = 0;
  121. skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
  122. if (!skb) {
  123. if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
  124. return 0;
  125. return err;
  126. }
  127. copied = len;
  128. if (copied > skb->len)
  129. copied = skb->len;
  130. else if (copied < skb->len)
  131. msg->msg_flags |= MSG_TRUNC;
  132. err = skb_copy_datagram_msg(skb, 0, msg, copied);
  133. if (unlikely(err)) {
  134. kfree_skb(skb);
  135. return err;
  136. }
  137. if (flags & MSG_TRUNC)
  138. copied = skb->len;
  139. kfree_skb(skb);
  140. return copied;
  141. }
  142. int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
  143. {
  144. struct espintcp_ctx *ctx = espintcp_getctx(sk);
  145. if (skb_queue_len(&ctx->out_queue) >=
  146. READ_ONCE(net_hotdata.max_backlog))
  147. return -ENOBUFS;
  148. __skb_queue_tail(&ctx->out_queue, skb);
  149. return 0;
  150. }
  151. EXPORT_SYMBOL_GPL(espintcp_queue_out);
/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
  154. static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
  155. int flags)
  156. {
  157. do {
  158. int ret;
  159. ret = skb_send_sock_locked(sk, emsg->skb,
  160. emsg->offset, emsg->len);
  161. if (ret < 0)
  162. return ret;
  163. emsg->len -= ret;
  164. emsg->offset += ret;
  165. } while (emsg->len > 0);
  166. kfree_skb(emsg->skb);
  167. memset(emsg, 0, sizeof(*emsg));
  168. return 0;
  169. }
/* Transmit (the remainder of) a queued sk_msg as a series of
 * MSG_SPLICE_PAGES sends. Caller holds the socket lock.
 *
 * On failure the intra-fragment offset and the count of fully sent
 * fragments are saved back into @emsg/@skmsg so that a later call
 * resumes exactly where this one stopped.
 */
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct msghdr msghdr = {
		.msg_flags = flags | MSG_SPLICE_PAGES | MSG_MORE,
	};
	struct sk_msg *skmsg = &emsg->skmsg;
	bool more = flags & MSG_MORE;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	/* resume from the first fragment that is not fully sent yet */
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		struct bio_vec bvec;
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		/* the saved offset only applies to the first fragment */
		emsg->offset = 0;

		/* drop MSG_MORE on the last fragment unless asked to keep it */
		if (sg_is_last(sg) && !more)
			msghdr.msg_flags &= ~MSG_MORE;

		p = sg_page(sg);
retry:
		bvec_set_page(&bvec, p, size, offset);
		iov_iter_bvec(&msghdr.msg_iter, ITER_SOURCE, &bvec, 1, size);
		ret = tcp_sendmsg_locked(sk, &msghdr, size);
		if (ret < 0) {
			/* remember progress within the current fragment */
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		/* short send: retry the rest of this fragment */
		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
/* Try to flush the pending partial message, if any.
 *
 * ctx->tx_running serializes concurrent flush attempts (sendmsg path
 * vs the TX worker); callers hold the socket lock. Returns 0 when
 * there is nothing (left) to send, -EAGAIN when another flusher is
 * already running, or the underlying send error.
 */
static int espintcp_push_msgs(struct sock *sk, int flags)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, flags);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, flags);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		/* blocking callers treat a full send queue as "try later" */
		return flags & MSG_DONTWAIT ? -EAGAIN : 0;
	}
	if (!err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}
/* Queue one ESP skb for transmission on the TCP socket.
 *
 * Only a single partial message can be pending: if flushing the
 * previous one did not complete, the new skb is dropped with -ENOBUFS.
 * Consumes the skb in all cases. Caller holds the socket lock.
 */
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	/* try to make room by flushing whatever is still pending */
	espintcp_push_msgs(sk, 0);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	/* best effort: a partial send stays queued for the TX worker */
	espintcp_push_msgs(sk, 0);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);
/* sendmsg for the userspace IKE side: prepend the 2-byte length
 * header and queue the message for transmission.
 *
 * Only MSG_DONTWAIT is supported among the flags, and ancillary data
 * is rejected. A message that was only partially sent stays queued
 * as ctx->partial and is completed later by the TX worker, so success
 * is reported as the full @size.
 */
static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;	/* payload plus length header */
	char buf[2] = {0};
	int err, end;

	if (msg->msg_flags & ~MSG_DONTWAIT)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	/* a previous partial message must be flushed before a new one */
	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	if (err < 0) {
		if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
			err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* only -ENOMEM is possible since we don't coalesce */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	/* big-endian length prefix that counts itself (RFC 8229) */
	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, ITER_SOURCE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	/* mark the last scatterlist entry so the send loop terminates */
	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	/* this message could be partially sent, keep it */

	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}
/* The IPv4 variants are filled in once at boot by espintcp_init();
 * the IPv6 variants are built lazily on first use, serialized by
 * tcpv6_prot_mutex.
 */
static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
static struct proto espintcp6_prot;
static struct proto_ops espintcp6_ops;
static DEFINE_MUTEX(tcpv6_prot_mutex);
/* sk_data_ready hook: feed newly arrived TCP data to the strparser */
static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	trace_sk_data_ready(sk);

	strp_data_ready(&ctx->strp);
}
/* TX worker: retry flushing a pending partial message from process
 * context (scheduled by espintcp_write_space()). Skips the flush if
 * another flusher currently holds tx_running.
 */
static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk, 0);
	release_sock(sk);
}
/* sk_write_space hook: send buffer space became available — kick the
 * TX worker, then run the original callback.
 */
static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}
/* sk_destruct hook: run the original destructor, then free the ctx */
static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	ctx->saved_destruct(sk);
	kfree(ctx);
}
/* Return true if @sk is a TCP socket with the espintcp ULP attached
 * (identified by the swapped-in proto, v4 or v6 variant).
 */
bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
/* forward declaration — defined below, needed by espintcp_init_sk() */
static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);
/* ULP init: attach the espintcp context to an established TCP socket.
 *
 * Installs the strparser on the receive side and swaps in the
 * espintcp proto/proto_ops so that sendmsg/recvmsg/close/poll go
 * through the wrappers above. Invoked via setsockopt(TCP_ULP,
 * "espintcp") with the socket locked.
 */
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	/* force a route re-lookup now that encapsulation changed */
	__sk_dst_reset(sk);

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);

	if (sk->sk_family == AF_INET) {
		sk->sk_prot = &espintcp_prot;
		sk->sk_socket->ops = &espintcp_ops;
	} else {
		/* IPv6 variants are built lazily on first use */
		mutex_lock(&tcpv6_prot_mutex);
		if (!espintcp6_prot.recvmsg)
			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
		mutex_unlock(&tcpv6_prot_mutex);

		sk->sk_prot = &espintcp6_prot;
		sk->sk_socket->ops = &espintcp6_ops;
	}
	/* save the original callbacks so they can be chained/restored */
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;
	sk->sk_use_task_frag = false;

	return 0;

free:
	kfree(ctx);
	return err;
}
/* release_cb: runs when the socket lock is released. Flush the skbs
 * that espintcp_queue_out() accumulated while the lock was owned,
 * then run the regular TCP release work.
 */
static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}
/* close: tear down espintcp state, then hand off to tcp_close().
 *
 * sk_prot is restored to plain tcp_prot before the worker and parser
 * are stopped; the barrier() keeps the compiler from reordering that
 * store past the teardown below.
 */
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	/* free whatever part of the last message was never sent */
	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}
  442. static __poll_t espintcp_poll(struct file *file, struct socket *sock,
  443. poll_table *wait)
  444. {
  445. __poll_t mask = datagram_poll(file, sock, wait);
  446. struct sock *sk = sock->sk;
  447. struct espintcp_ctx *ctx = espintcp_getctx(sk);
  448. if (!skb_queue_empty(&ctx->ike_queue))
  449. mask |= EPOLLIN | EPOLLRDNORM;
  450. return mask;
  451. }
  452. static void build_protos(struct proto *espintcp_prot,
  453. struct proto_ops *espintcp_ops,
  454. const struct proto *orig_prot,
  455. const struct proto_ops *orig_ops)
  456. {
  457. memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
  458. memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
  459. espintcp_prot->sendmsg = espintcp_sendmsg;
  460. espintcp_prot->recvmsg = espintcp_recvmsg;
  461. espintcp_prot->close = espintcp_close;
  462. espintcp_prot->release_cb = espintcp_release;
  463. espintcp_ops->poll = espintcp_poll;
  464. }
/* ULP registration: selected by userspace via setsockopt(TCP_ULP) */
static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};
/* Boot-time setup: build the IPv4 proto variants and register the
 * "espintcp" ULP.
 */
void __init espintcp_init(void)
{
	build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);

	tcp_register_ulp(&espintcp_ulp);
}