vhost_task.c 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Copyright (C) 2021 Oracle Corporation
  4. */
  5. #include <linux/slab.h>
  6. #include <linux/completion.h>
  7. #include <linux/sched/task.h>
  8. #include <linux/sched/vhost_task.h>
  9. #include <linux/sched/signal.h>
  /* Bit numbers used in vhost_task::flags with set_bit()/test_bit(). */
  enum vhost_task_flags {
  	/* Set by vhost_task_stop() to ask the worker loop to exit. */
  	VHOST_TASK_FLAGS_STOP = 0,
  	/* Set by the worker when it exits its loop without STOP being set. */
  	VHOST_TASK_FLAGS_KILLED = 1,
  };
  14. struct vhost_task {
  15. bool (*fn)(void *data);
  16. void (*handle_sigkill)(void *data);
  17. void *data;
  18. struct completion exited;
  19. unsigned long flags;
  20. struct task_struct *task;
  21. /* serialize SIGKILL and vhost_task_stop calls */
  22. struct mutex exit_mutex;
  23. };
  24. static int vhost_task_fn(void *data)
  25. {
  26. struct vhost_task *vtsk = data;
  27. for (;;) {
  28. bool did_work;
  29. if (signal_pending(current)) {
  30. struct ksignal ksig;
  31. if (get_signal(&ksig))
  32. break;
  33. }
  34. /* mb paired w/ vhost_task_stop */
  35. set_current_state(TASK_INTERRUPTIBLE);
  36. if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
  37. __set_current_state(TASK_RUNNING);
  38. break;
  39. }
  40. did_work = vtsk->fn(vtsk->data);
  41. if (!did_work)
  42. schedule();
  43. }
  44. mutex_lock(&vtsk->exit_mutex);
  45. /*
  46. * If a vhost_task_stop and SIGKILL race, we can ignore the SIGKILL.
  47. * When the vhost layer has called vhost_task_stop it's already stopped
  48. * new work and flushed.
  49. */
  50. if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
  51. set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags);
  52. vtsk->handle_sigkill(vtsk->data);
  53. }
  54. mutex_unlock(&vtsk->exit_mutex);
  55. complete(&vtsk->exited);
  56. do_exit(0);
  57. }
  58. /**
  59. * vhost_task_wake - wakeup the vhost_task
  60. * @vtsk: vhost_task to wake
  61. *
  62. * wake up the vhost_task worker thread
  63. */
  64. void vhost_task_wake(struct vhost_task *vtsk)
  65. {
  66. wake_up_process(vtsk->task);
  67. }
  68. EXPORT_SYMBOL_GPL(vhost_task_wake);
  69. /**
  70. * vhost_task_stop - stop a vhost_task
  71. * @vtsk: vhost_task to stop
  72. *
  73. * vhost_task_fn ensures the worker thread exits after
  74. * VHOST_TASK_FLAGS_STOP becomes true.
  75. */
  76. void vhost_task_stop(struct vhost_task *vtsk)
  77. {
  78. mutex_lock(&vtsk->exit_mutex);
  79. if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) {
  80. set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
  81. vhost_task_wake(vtsk);
  82. }
  83. mutex_unlock(&vtsk->exit_mutex);
  84. /*
  85. * Make sure vhost_task_fn is no longer accessing the vhost_task before
  86. * freeing it below.
  87. */
  88. wait_for_completion(&vtsk->exited);
  89. put_task_struct(vtsk->task);
  90. kfree(vtsk);
  91. }
  92. EXPORT_SYMBOL_GPL(vhost_task_stop);
  93. /**
  94. * vhost_task_create - create a copy of a task to be used by the kernel
  95. * @fn: vhost worker function
  96. * @handle_sigkill: vhost function to handle when we are killed
  97. * @arg: data to be passed to fn and handled_kill
  98. * @name: the thread's name
  99. *
  100. * This returns a specialized task for use by the vhost layer or ERR_PTR() on
  101. * failure. The returned task is inactive, and the caller must fire it up
  102. * through vhost_task_start().
  103. */
  104. struct vhost_task *vhost_task_create(bool (*fn)(void *),
  105. void (*handle_sigkill)(void *), void *arg,
  106. const char *name)
  107. {
  108. struct kernel_clone_args args = {
  109. .flags = CLONE_FS | CLONE_UNTRACED | CLONE_VM |
  110. CLONE_THREAD | CLONE_SIGHAND,
  111. .exit_signal = 0,
  112. .fn = vhost_task_fn,
  113. .name = name,
  114. .user_worker = 1,
  115. .no_files = 1,
  116. };
  117. struct vhost_task *vtsk;
  118. struct task_struct *tsk;
  119. vtsk = kzalloc(sizeof(*vtsk), GFP_KERNEL);
  120. if (!vtsk)
  121. return ERR_PTR(-ENOMEM);
  122. init_completion(&vtsk->exited);
  123. mutex_init(&vtsk->exit_mutex);
  124. vtsk->data = arg;
  125. vtsk->fn = fn;
  126. vtsk->handle_sigkill = handle_sigkill;
  127. args.fn_arg = vtsk;
  128. tsk = copy_process(NULL, 0, NUMA_NO_NODE, &args);
  129. if (IS_ERR(tsk)) {
  130. kfree(vtsk);
  131. return ERR_PTR(PTR_ERR(tsk));
  132. }
  133. vtsk->task = get_task_struct(tsk);
  134. return vtsk;
  135. }
  136. EXPORT_SYMBOL_GPL(vhost_task_create);
  137. /**
  138. * vhost_task_start - start a vhost_task created with vhost_task_create
  139. * @vtsk: vhost_task to wake up
  140. */
  141. void vhost_task_start(struct vhost_task *vtsk)
  142. {
  143. wake_up_new_task(vtsk->task);
  144. }
  145. EXPORT_SYMBOL_GPL(vhost_task_start);