netlink.c 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Generic netlink handshake service
  4. *
  5. * Author: Chuck Lever <chuck.lever@oracle.com>
  6. *
  7. * Copyright (c) 2023, Oracle and/or its affiliates.
  8. */
  9. #include <linux/types.h>
  10. #include <linux/socket.h>
  11. #include <linux/kernel.h>
  12. #include <linux/module.h>
  13. #include <linux/skbuff.h>
  14. #include <linux/mm.h>
  15. #include <net/sock.h>
  16. #include <net/genetlink.h>
  17. #include <net/netns/generic.h>
  18. #include <kunit/visibility.h>
  19. #include <uapi/linux/handshake.h>
  20. #include "handshake.h"
  21. #include "genl.h"
  22. #include <trace/events/handshake.h>
  23. /**
  24. * handshake_genl_notify - Notify handlers that a request is waiting
  25. * @net: target network namespace
  26. * @proto: handshake protocol
  27. * @flags: memory allocation control flags
  28. *
  29. * Returns zero on success or a negative errno if notification failed.
  30. */
  31. int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
  32. gfp_t flags)
  33. {
  34. struct sk_buff *msg;
  35. void *hdr;
  36. /* Disable notifications during unit testing */
  37. if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
  38. return 0;
  39. if (!genl_has_listeners(&handshake_nl_family, net,
  40. proto->hp_handler_class))
  41. return -ESRCH;
  42. msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, flags);
  43. if (!msg)
  44. return -ENOMEM;
  45. hdr = genlmsg_put(msg, 0, 0, &handshake_nl_family, 0,
  46. HANDSHAKE_CMD_READY);
  47. if (!hdr)
  48. goto out_free;
  49. if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
  50. proto->hp_handler_class) < 0) {
  51. genlmsg_cancel(msg, hdr);
  52. goto out_free;
  53. }
  54. genlmsg_end(msg, hdr);
  55. return genlmsg_multicast_netns(&handshake_nl_family, net, msg,
  56. 0, proto->hp_handler_class, flags);
  57. out_free:
  58. nlmsg_free(msg);
  59. return -EMSGSIZE;
  60. }
  61. /**
  62. * handshake_genl_put - Create a generic netlink message header
  63. * @msg: buffer in which to create the header
  64. * @info: generic netlink message context
  65. *
  66. * Returns a ready-to-use header, or NULL.
  67. */
  68. struct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
  69. struct genl_info *info)
  70. {
  71. return genlmsg_put(msg, info->snd_portid, info->snd_seq,
  72. &handshake_nl_family, 0, info->genlhdr->cmd);
  73. }
  74. EXPORT_SYMBOL(handshake_genl_put);
  75. int handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
  76. {
  77. struct net *net = sock_net(skb->sk);
  78. struct handshake_net *hn = handshake_pernet(net);
  79. struct handshake_req *req = NULL;
  80. struct socket *sock;
  81. int class, fd, err;
  82. err = -EOPNOTSUPP;
  83. if (!hn)
  84. goto out_status;
  85. err = -EINVAL;
  86. if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
  87. goto out_status;
  88. class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);
  89. err = -EAGAIN;
  90. req = handshake_req_next(hn, class);
  91. if (!req)
  92. goto out_status;
  93. sock = req->hr_sk->sk_socket;
  94. fd = get_unused_fd_flags(O_CLOEXEC);
  95. if (fd < 0) {
  96. err = fd;
  97. goto out_complete;
  98. }
  99. err = req->hr_proto->hp_accept(req, info, fd);
  100. if (err) {
  101. put_unused_fd(fd);
  102. goto out_complete;
  103. }
  104. fd_install(fd, get_file(sock->file));
  105. trace_handshake_cmd_accept(net, req, req->hr_sk, fd);
  106. return 0;
  107. out_complete:
  108. handshake_complete(req, -EIO, NULL);
  109. out_status:
  110. trace_handshake_cmd_accept_err(net, req, NULL, err);
  111. return err;
  112. }
  113. int handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
  114. {
  115. struct net *net = sock_net(skb->sk);
  116. struct handshake_req *req;
  117. struct socket *sock;
  118. int fd, status, err;
  119. if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
  120. return -EINVAL;
  121. fd = nla_get_s32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);
  122. sock = sockfd_lookup(fd, &err);
  123. if (!sock)
  124. return err;
  125. req = handshake_req_hash_lookup(sock->sk);
  126. if (!req) {
  127. err = -EBUSY;
  128. trace_handshake_cmd_done_err(net, req, sock->sk, err);
  129. sockfd_put(sock);
  130. return err;
  131. }
  132. trace_handshake_cmd_done(net, req, sock->sk, fd);
  133. status = -EIO;
  134. if (info->attrs[HANDSHAKE_A_DONE_STATUS])
  135. status = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
  136. handshake_complete(req, status, info);
  137. sockfd_put(sock);
  138. return 0;
  139. }
  140. static unsigned int handshake_net_id;
  141. static int __net_init handshake_net_init(struct net *net)
  142. {
  143. struct handshake_net *hn = net_generic(net, handshake_net_id);
  144. unsigned long tmp;
  145. struct sysinfo si;
  146. /*
  147. * Arbitrary limit to prevent handshakes that do not make
  148. * progress from clogging up the system. The cap scales up
  149. * with the amount of physical memory on the system.
  150. */
  151. si_meminfo(&si);
  152. tmp = si.totalram / (25 * si.mem_unit);
  153. hn->hn_pending_max = clamp(tmp, 3UL, 50UL);
  154. spin_lock_init(&hn->hn_lock);
  155. hn->hn_pending = 0;
  156. hn->hn_flags = 0;
  157. INIT_LIST_HEAD(&hn->hn_requests);
  158. return 0;
  159. }
  160. static void __net_exit handshake_net_exit(struct net *net)
  161. {
  162. struct handshake_net *hn = net_generic(net, handshake_net_id);
  163. struct handshake_req *req;
  164. LIST_HEAD(requests);
  165. /*
  166. * Drain the net's pending list. Requests that have been
  167. * accepted and are in progress will be destroyed when
  168. * the socket is closed.
  169. */
  170. spin_lock(&hn->hn_lock);
  171. set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
  172. list_splice_init(&requests, &hn->hn_requests);
  173. spin_unlock(&hn->hn_lock);
  174. while (!list_empty(&requests)) {
  175. req = list_first_entry(&requests, struct handshake_req, hr_list);
  176. list_del(&req->hr_list);
  177. /*
  178. * Requests on this list have not yet been
  179. * accepted, so they do not have an fd to put.
  180. */
  181. handshake_complete(req, -ETIMEDOUT, NULL);
  182. }
  183. }
  184. static struct pernet_operations handshake_genl_net_ops = {
  185. .init = handshake_net_init,
  186. .exit = handshake_net_exit,
  187. .id = &handshake_net_id,
  188. .size = sizeof(struct handshake_net),
  189. };
  190. /**
  191. * handshake_pernet - Get the handshake private per-net structure
  192. * @net: network namespace
  193. *
  194. * Returns a pointer to the net's private per-net structure for the
  195. * handshake module, or NULL if handshake_init() failed.
  196. */
  197. struct handshake_net *handshake_pernet(struct net *net)
  198. {
  199. return handshake_net_id ?
  200. net_generic(net, handshake_net_id) : NULL;
  201. }
  202. EXPORT_SYMBOL_IF_KUNIT(handshake_pernet);
  203. static int __init handshake_init(void)
  204. {
  205. int ret;
  206. ret = handshake_req_hash_init();
  207. if (ret) {
  208. pr_warn("handshake: hash initialization failed (%d)\n", ret);
  209. return ret;
  210. }
  211. ret = genl_register_family(&handshake_nl_family);
  212. if (ret) {
  213. pr_warn("handshake: netlink registration failed (%d)\n", ret);
  214. handshake_req_hash_destroy();
  215. return ret;
  216. }
  217. /*
  218. * ORDER: register_pernet_subsys must be done last.
  219. *
  220. * If initialization does not make it past pernet_subsys
  221. * registration, then handshake_net_id will remain 0. That
  222. * shunts the handshake consumer API to return ENOTSUPP
  223. * to prevent it from dereferencing something that hasn't
  224. * been allocated.
  225. */
  226. ret = register_pernet_subsys(&handshake_genl_net_ops);
  227. if (ret) {
  228. pr_warn("handshake: pernet registration failed (%d)\n", ret);
  229. genl_unregister_family(&handshake_nl_family);
  230. handshake_req_hash_destroy();
  231. }
  232. return ret;
  233. }
  234. static void __exit handshake_exit(void)
  235. {
  236. unregister_pernet_subsys(&handshake_genl_net_ops);
  237. handshake_net_id = 0;
  238. handshake_req_hash_destroy();
  239. genl_unregister_family(&handshake_nl_family);
  240. }
  241. module_init(handshake_init);
  242. module_exit(handshake_exit);