request.c 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Handshake request lifetime events
  4. *
  5. * Author: Chuck Lever <chuck.lever@oracle.com>
  6. *
  7. * Copyright (c) 2023, Oracle and/or its affiliates.
  8. */
  9. #include <linux/types.h>
  10. #include <linux/socket.h>
  11. #include <linux/kernel.h>
  12. #include <linux/module.h>
  13. #include <linux/skbuff.h>
  14. #include <linux/inet.h>
  15. #include <linux/fdtable.h>
  16. #include <linux/rhashtable.h>
  17. #include <net/sock.h>
  18. #include <net/genetlink.h>
  19. #include <net/netns/generic.h>
  20. #include <kunit/visibility.h>
  21. #include <uapi/linux/handshake.h>
  22. #include "handshake.h"
  23. #include <trace/events/handshake.h>
  24. /*
  25. * We need both a handshake_req -> sock mapping, and a sock ->
  26. * handshake_req mapping. Both are one-to-one.
  27. *
  28. * To avoid adding another pointer field to struct sock, net/handshake
  29. * maintains a hash table, indexed by the memory address of @sock, to
  30. * find the struct handshake_req outstanding for that socket. The
  31. * reverse direction uses a simple pointer field in the handshake_req
  32. * struct.
  33. */
  34. static struct rhashtable handshake_rhashtbl ____cacheline_aligned_in_smp;
  35. static const struct rhashtable_params handshake_rhash_params = {
  36. .key_len = sizeof_field(struct handshake_req, hr_sk),
  37. .key_offset = offsetof(struct handshake_req, hr_sk),
  38. .head_offset = offsetof(struct handshake_req, hr_rhash),
  39. .automatic_shrinking = true,
  40. };
  41. int handshake_req_hash_init(void)
  42. {
  43. return rhashtable_init(&handshake_rhashtbl, &handshake_rhash_params);
  44. }
  45. void handshake_req_hash_destroy(void)
  46. {
  47. rhashtable_destroy(&handshake_rhashtbl);
  48. }
  49. struct handshake_req *handshake_req_hash_lookup(struct sock *sk)
  50. {
  51. return rhashtable_lookup_fast(&handshake_rhashtbl, &sk,
  52. handshake_rhash_params);
  53. }
  54. EXPORT_SYMBOL_IF_KUNIT(handshake_req_hash_lookup);
  55. static bool handshake_req_hash_add(struct handshake_req *req)
  56. {
  57. int ret;
  58. ret = rhashtable_lookup_insert_fast(&handshake_rhashtbl,
  59. &req->hr_rhash,
  60. handshake_rhash_params);
  61. return ret == 0;
  62. }
  63. static void handshake_req_destroy(struct handshake_req *req)
  64. {
  65. if (req->hr_proto->hp_destroy)
  66. req->hr_proto->hp_destroy(req);
  67. rhashtable_remove_fast(&handshake_rhashtbl, &req->hr_rhash,
  68. handshake_rhash_params);
  69. kfree(req);
  70. }
  71. static void handshake_sk_destruct(struct sock *sk)
  72. {
  73. void (*sk_destruct)(struct sock *sk);
  74. struct handshake_req *req;
  75. req = handshake_req_hash_lookup(sk);
  76. if (!req)
  77. return;
  78. trace_handshake_destruct(sock_net(sk), req, sk);
  79. sk_destruct = req->hr_odestruct;
  80. handshake_req_destroy(req);
  81. if (sk_destruct)
  82. sk_destruct(sk);
  83. }
  84. /**
  85. * handshake_req_alloc - Allocate a handshake request
  86. * @proto: security protocol
  87. * @flags: memory allocation flags
  88. *
  89. * Returns an initialized handshake_req or NULL.
  90. */
  91. struct handshake_req *handshake_req_alloc(const struct handshake_proto *proto,
  92. gfp_t flags)
  93. {
  94. struct handshake_req *req;
  95. if (!proto)
  96. return NULL;
  97. if (proto->hp_handler_class <= HANDSHAKE_HANDLER_CLASS_NONE)
  98. return NULL;
  99. if (proto->hp_handler_class >= HANDSHAKE_HANDLER_CLASS_MAX)
  100. return NULL;
  101. if (!proto->hp_accept || !proto->hp_done)
  102. return NULL;
  103. req = kzalloc(struct_size(req, hr_priv, proto->hp_privsize), flags);
  104. if (!req)
  105. return NULL;
  106. INIT_LIST_HEAD(&req->hr_list);
  107. req->hr_proto = proto;
  108. return req;
  109. }
  110. EXPORT_SYMBOL(handshake_req_alloc);
  111. /**
  112. * handshake_req_private - Get per-handshake private data
  113. * @req: handshake arguments
  114. *
  115. */
  116. void *handshake_req_private(struct handshake_req *req)
  117. {
  118. return (void *)&req->hr_priv;
  119. }
  120. EXPORT_SYMBOL(handshake_req_private);
  121. static bool __add_pending_locked(struct handshake_net *hn,
  122. struct handshake_req *req)
  123. {
  124. if (WARN_ON_ONCE(!list_empty(&req->hr_list)))
  125. return false;
  126. hn->hn_pending++;
  127. list_add_tail(&req->hr_list, &hn->hn_requests);
  128. return true;
  129. }
  130. static void __remove_pending_locked(struct handshake_net *hn,
  131. struct handshake_req *req)
  132. {
  133. hn->hn_pending--;
  134. list_del_init(&req->hr_list);
  135. }
  136. /*
  137. * Returns %true if the request was found on @net's pending list,
  138. * otherwise %false.
  139. *
  140. * If @req was on a pending list, it has not yet been accepted.
  141. */
  142. static bool remove_pending(struct handshake_net *hn, struct handshake_req *req)
  143. {
  144. bool ret = false;
  145. spin_lock(&hn->hn_lock);
  146. if (!list_empty(&req->hr_list)) {
  147. __remove_pending_locked(hn, req);
  148. ret = true;
  149. }
  150. spin_unlock(&hn->hn_lock);
  151. return ret;
  152. }
  153. struct handshake_req *handshake_req_next(struct handshake_net *hn, int class)
  154. {
  155. struct handshake_req *req, *pos;
  156. req = NULL;
  157. spin_lock(&hn->hn_lock);
  158. list_for_each_entry(pos, &hn->hn_requests, hr_list) {
  159. if (pos->hr_proto->hp_handler_class != class)
  160. continue;
  161. __remove_pending_locked(hn, pos);
  162. req = pos;
  163. break;
  164. }
  165. spin_unlock(&hn->hn_lock);
  166. return req;
  167. }
  168. EXPORT_SYMBOL_IF_KUNIT(handshake_req_next);
  169. /**
  170. * handshake_req_submit - Submit a handshake request
  171. * @sock: open socket on which to perform the handshake
  172. * @req: handshake arguments
  173. * @flags: memory allocation flags
  174. *
  175. * Return values:
  176. * %0: Request queued
  177. * %-EINVAL: Invalid argument
  178. * %-EBUSY: A handshake is already under way for this socket
  179. * %-ESRCH: No handshake agent is available
  180. * %-EAGAIN: Too many pending handshake requests
  181. * %-ENOMEM: Failed to allocate memory
  182. * %-EMSGSIZE: Failed to construct notification message
  183. * %-EOPNOTSUPP: Handshake module not initialized
  184. *
  185. * A zero return value from handshake_req_submit() means that
  186. * exactly one subsequent completion callback is guaranteed.
  187. *
  188. * A negative return value from handshake_req_submit() means that
  189. * no completion callback will be done and that @req has been
  190. * destroyed.
  191. */
  192. int handshake_req_submit(struct socket *sock, struct handshake_req *req,
  193. gfp_t flags)
  194. {
  195. struct handshake_net *hn;
  196. struct net *net;
  197. int ret;
  198. if (!sock || !req || !sock->file) {
  199. kfree(req);
  200. return -EINVAL;
  201. }
  202. req->hr_sk = sock->sk;
  203. if (!req->hr_sk) {
  204. kfree(req);
  205. return -EINVAL;
  206. }
  207. req->hr_odestruct = req->hr_sk->sk_destruct;
  208. req->hr_sk->sk_destruct = handshake_sk_destruct;
  209. ret = -EOPNOTSUPP;
  210. net = sock_net(req->hr_sk);
  211. hn = handshake_pernet(net);
  212. if (!hn)
  213. goto out_err;
  214. ret = -EAGAIN;
  215. if (READ_ONCE(hn->hn_pending) >= hn->hn_pending_max)
  216. goto out_err;
  217. spin_lock(&hn->hn_lock);
  218. ret = -EOPNOTSUPP;
  219. if (test_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags))
  220. goto out_unlock;
  221. ret = -EBUSY;
  222. if (!handshake_req_hash_add(req))
  223. goto out_unlock;
  224. if (!__add_pending_locked(hn, req))
  225. goto out_unlock;
  226. spin_unlock(&hn->hn_lock);
  227. ret = handshake_genl_notify(net, req->hr_proto, flags);
  228. if (ret) {
  229. trace_handshake_notify_err(net, req, req->hr_sk, ret);
  230. if (remove_pending(hn, req))
  231. goto out_err;
  232. }
  233. /* Prevent socket release while a handshake request is pending */
  234. sock_hold(req->hr_sk);
  235. trace_handshake_submit(net, req, req->hr_sk);
  236. return 0;
  237. out_unlock:
  238. spin_unlock(&hn->hn_lock);
  239. out_err:
  240. trace_handshake_submit_err(net, req, req->hr_sk, ret);
  241. handshake_req_destroy(req);
  242. return ret;
  243. }
  244. EXPORT_SYMBOL(handshake_req_submit);
  245. void handshake_complete(struct handshake_req *req, unsigned int status,
  246. struct genl_info *info)
  247. {
  248. struct sock *sk = req->hr_sk;
  249. struct net *net = sock_net(sk);
  250. if (!test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
  251. trace_handshake_complete(net, req, sk, status);
  252. req->hr_proto->hp_done(req, status, info);
  253. /* Handshake request is no longer pending */
  254. sock_put(sk);
  255. }
  256. }
  257. EXPORT_SYMBOL_IF_KUNIT(handshake_complete);
  258. /**
  259. * handshake_req_cancel - Cancel an in-progress handshake
  260. * @sk: socket on which there is an ongoing handshake
  261. *
  262. * Request cancellation races with request completion. To determine
  263. * who won, callers examine the return value from this function.
  264. *
  265. * Return values:
  266. * %true - Uncompleted handshake request was canceled
  267. * %false - Handshake request already completed or not found
  268. */
  269. bool handshake_req_cancel(struct sock *sk)
  270. {
  271. struct handshake_req *req;
  272. struct handshake_net *hn;
  273. struct net *net;
  274. net = sock_net(sk);
  275. req = handshake_req_hash_lookup(sk);
  276. if (!req) {
  277. trace_handshake_cancel_none(net, req, sk);
  278. return false;
  279. }
  280. hn = handshake_pernet(net);
  281. if (hn && remove_pending(hn, req)) {
  282. /* Request hadn't been accepted */
  283. goto out_true;
  284. }
  285. if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
  286. /* Request already completed */
  287. trace_handshake_cancel_busy(net, req, sk);
  288. return false;
  289. }
  290. out_true:
  291. trace_handshake_cancel(net, req, sk);
  292. /* Handshake request is no longer pending */
  293. sock_put(sk);
  294. return true;
  295. }
  296. EXPORT_SYMBOL(handshake_req_cancel);