vhost_task.c 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Copyright (C) 2021 Oracle Corporation
  4. */
  5. #include <linux/slab.h>
  6. #include <linux/completion.h>
  7. #include <linux/sched/task.h>
  8. #include <linux/sched/vhost_task.h>
  9. #include <linux/sched/signal.h>
  10. enum vhost_task_flags {
  11. VHOST_TASK_FLAGS_STOP,
  12. VHOST_TASK_FLAGS_KILLED,
  13. };
  14. struct vhost_task {
  15. bool (*fn)(void *data);
  16. void (*handle_sigkill)(void *data);
  17. void *data;
  18. struct completion exited;
  19. unsigned long flags;
  20. struct task_struct *task;
  21. /* serialize SIGKILL and vhost_task_stop calls */
  22. struct mutex exit_mutex;
  23. };
/*
 * vhost_task_fn - main loop of a vhost worker thread
 * @data: the struct vhost_task passed as args.fn_arg by vhost_task_create()
 *
 * Repeatedly calls vtsk->fn() until either vhost_task_stop() sets
 * VHOST_TASK_FLAGS_STOP or get_signal() returns nonzero. If the loop ends
 * without a stop request, VHOST_TASK_FLAGS_KILLED is set and the vhost
 * layer's handle_sigkill callback runs. The exited completion is then
 * signalled and the thread calls do_exit(); this function does not return.
 */
static int vhost_task_fn(void *data)
{
	struct vhost_task *vtsk = data;

	for (;;) {
		bool did_work;

		/*
		 * A nonzero get_signal() return is treated as a request to
		 * exit (presumably a fatal signal such as SIGKILL).
		 */
		if (signal_pending(current)) {
			struct ksignal ksig;

			if (get_signal(&ksig))
				break;
		}

		/* mb paired w/ vhost_task_stop */
		/*
		 * Mark ourselves sleeping *before* testing STOP. If
		 * vhost_task_stop() sets the bit after this test, its wakeup
		 * puts us back into TASK_RUNNING, so schedule() below does not
		 * sleep past the stop request.
		 */
		set_current_state(TASK_INTERRUPTIBLE);

		if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
			__set_current_state(TASK_RUNNING);
			break;
		}

		did_work = vtsk->fn(vtsk->data);
		/* Only sleep when fn() reported nothing to do; otherwise poll again. */
		if (!did_work)
			schedule();
	}

	/* exit_mutex serializes this exit path against vhost_task_stop(). */
	mutex_lock(&vtsk->exit_mutex);
	/*
	 * If a vhost_task_stop and SIGKILL race, we can ignore the SIGKILL.
	 * When the vhost layer has called vhost_task_stop it's already stopped
	 * new work and flushed.
	 */
	if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
		/* Tells vhost_task_stop() not to set STOP or wake us. */
		set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags);
		vtsk->handle_sigkill(vtsk->data);
	}
	mutex_unlock(&vtsk->exit_mutex);
	/* vhost_task_stop() waits on this before freeing vtsk. */
	complete(&vtsk->exited);

	do_exit(0);
}
  58. /**
  59. * vhost_task_wake - wakeup the vhost_task
  60. * @vtsk: vhost_task to wake
  61. *
  62. * wake up the vhost_task worker thread
  63. */
  64. void vhost_task_wake(struct vhost_task *vtsk)
  65. {
  66. wake_up_process(vtsk->task);
  67. }
  68. EXPORT_SYMBOL_GPL(vhost_task_wake);
  69. /**
  70. * vhost_task_stop - stop a vhost_task
  71. * @vtsk: vhost_task to stop
  72. *
  73. * vhost_task_fn ensures the worker thread exits after
  74. * VHOST_TASK_FLAGS_STOP becomes true.
  75. */
  76. void vhost_task_stop(struct vhost_task *vtsk)
  77. {
  78. mutex_lock(&vtsk->exit_mutex);
  79. if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) {
  80. set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
  81. vhost_task_wake(vtsk);
  82. }
  83. mutex_unlock(&vtsk->exit_mutex);
  84. /*
  85. * Make sure vhost_task_fn is no longer accessing the vhost_task before
  86. * freeing it below.
  87. */
  88. wait_for_completion(&vtsk->exited);
  89. kfree(vtsk);
  90. }
  91. EXPORT_SYMBOL_GPL(vhost_task_stop);
  92. /**
  93. * vhost_task_create - create a copy of a task to be used by the kernel
  94. * @fn: vhost worker function
  95. * @handle_sigkill: vhost function to handle when we are killed
  96. * @arg: data to be passed to fn and handled_kill
  97. * @name: the thread's name
  98. *
  99. * This returns a specialized task for use by the vhost layer or NULL on
  100. * failure. The returned task is inactive, and the caller must fire it up
  101. * through vhost_task_start().
  102. */
  103. struct vhost_task *vhost_task_create(bool (*fn)(void *),
  104. void (*handle_sigkill)(void *), void *arg,
  105. const char *name)
  106. {
  107. struct kernel_clone_args args = {
  108. .flags = CLONE_FS | CLONE_UNTRACED | CLONE_VM |
  109. CLONE_THREAD | CLONE_SIGHAND,
  110. .exit_signal = 0,
  111. .fn = vhost_task_fn,
  112. .name = name,
  113. .user_worker = 1,
  114. .no_files = 1,
  115. };
  116. struct vhost_task *vtsk;
  117. struct task_struct *tsk;
  118. vtsk = kzalloc(sizeof(*vtsk), GFP_KERNEL);
  119. if (!vtsk)
  120. return NULL;
  121. init_completion(&vtsk->exited);
  122. mutex_init(&vtsk->exit_mutex);
  123. vtsk->data = arg;
  124. vtsk->fn = fn;
  125. vtsk->handle_sigkill = handle_sigkill;
  126. args.fn_arg = vtsk;
  127. tsk = copy_process(NULL, 0, NUMA_NO_NODE, &args);
  128. if (IS_ERR(tsk)) {
  129. kfree(vtsk);
  130. return NULL;
  131. }
  132. vtsk->task = tsk;
  133. return vtsk;
  134. }
  135. EXPORT_SYMBOL_GPL(vhost_task_create);
  136. /**
  137. * vhost_task_start - start a vhost_task created with vhost_task_create
  138. * @vtsk: vhost_task to wake up
  139. */
  140. void vhost_task_start(struct vhost_task *vtsk)
  141. {
  142. wake_up_new_task(vtsk->task);
  143. }
  144. EXPORT_SYMBOL_GPL(vhost_task_start);