alert.c 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Handle the TLS Alert protocol
  4. *
  5. * Author: Chuck Lever <chuck.lever@oracle.com>
  6. *
  7. * Copyright (c) 2023, Oracle and/or its affiliates.
  8. */
  9. #include <linux/types.h>
  10. #include <linux/socket.h>
  11. #include <linux/kernel.h>
  12. #include <linux/module.h>
  13. #include <linux/skbuff.h>
  14. #include <linux/inet.h>
  15. #include <net/sock.h>
  16. #include <net/handshake.h>
  17. #include <net/tls.h>
  18. #include <net/tls_prot.h>
  19. #include "handshake.h"
  20. #include <trace/events/handshake.h>
  21. /**
  22. * tls_alert_send - send a TLS Alert on a kTLS socket
  23. * @sock: open kTLS socket to send on
  24. * @level: TLS Alert level
  25. * @description: TLS Alert description
  26. *
  27. * Returns zero on success or a negative errno.
  28. */
  29. int tls_alert_send(struct socket *sock, u8 level, u8 description)
  30. {
  31. u8 record_type = TLS_RECORD_TYPE_ALERT;
  32. u8 buf[CMSG_SPACE(sizeof(record_type))];
  33. struct msghdr msg = { 0 };
  34. struct cmsghdr *cmsg;
  35. struct kvec iov;
  36. u8 alert[2];
  37. int ret;
  38. trace_tls_alert_send(sock->sk, level, description);
  39. alert[0] = level;
  40. alert[1] = description;
  41. iov.iov_base = alert;
  42. iov.iov_len = sizeof(alert);
  43. memset(buf, 0, sizeof(buf));
  44. msg.msg_control = buf;
  45. msg.msg_controllen = sizeof(buf);
  46. msg.msg_flags = MSG_DONTWAIT;
  47. cmsg = CMSG_FIRSTHDR(&msg);
  48. cmsg->cmsg_level = SOL_TLS;
  49. cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
  50. cmsg->cmsg_len = CMSG_LEN(sizeof(record_type));
  51. memcpy(CMSG_DATA(cmsg), &record_type, sizeof(record_type));
  52. iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, iov.iov_len);
  53. ret = sock_sendmsg(sock, &msg);
  54. return ret < 0 ? ret : 0;
  55. }
  56. /**
  57. * tls_get_record_type - Look for TLS RECORD_TYPE information
  58. * @sk: socket (for IP address information)
  59. * @cmsg: incoming message to be parsed
  60. *
  61. * Returns zero or a TLS_RECORD_TYPE value.
  62. */
  63. u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *cmsg)
  64. {
  65. u8 record_type;
  66. if (cmsg->cmsg_level != SOL_TLS)
  67. return 0;
  68. if (cmsg->cmsg_type != TLS_GET_RECORD_TYPE)
  69. return 0;
  70. record_type = *((u8 *)CMSG_DATA(cmsg));
  71. trace_tls_contenttype(sk, record_type);
  72. return record_type;
  73. }
  74. EXPORT_SYMBOL(tls_get_record_type);
  75. /**
  76. * tls_alert_recv - Parse TLS Alert messages
  77. * @sk: socket (for IP address information)
  78. * @msg: incoming message to be parsed
  79. * @level: OUT - TLS AlertLevel value
  80. * @description: OUT - TLS AlertDescription value
  81. *
  82. */
  83. void tls_alert_recv(const struct sock *sk, const struct msghdr *msg,
  84. u8 *level, u8 *description)
  85. {
  86. const struct kvec *iov;
  87. u8 *data;
  88. iov = msg->msg_iter.kvec;
  89. data = iov->iov_base;
  90. *level = data[0];
  91. *description = data[1];
  92. trace_tls_alert_recv(sk, *level, *description);
  93. }
  94. EXPORT_SYMBOL(tls_alert_recv);