p9_fwd.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. #!/usr/bin/env python3
  2. # SPDX-License-Identifier: GPL-2.0
  3. import argparse
  4. import errno
  5. import logging
  6. import socket
  7. import struct
  8. import time
  9. import usb.core
  10. import usb.util
  11. def path_from_usb_dev(dev):
  12. """Takes a pyUSB device as argument and returns a string.
  13. The string is a Path representation of the position of the USB device on the USB bus tree.
  14. This path is used to find a USB device on the bus or all devices connected to a HUB.
  15. The path is made up of the number of the USB controller followed be the ports of the HUB tree."""
  16. if dev.port_numbers:
  17. dev_path = ".".join(str(i) for i in dev.port_numbers)
  18. return f"{dev.bus}-{dev_path}"
  19. return ""
  20. HEXDUMP_FILTER = "".join(chr(x).isprintable() and chr(x) or "." for x in range(128)) + "." * 128
  21. class Forwarder:
  22. @staticmethod
  23. def _log_hexdump(data):
  24. if not logging.root.isEnabledFor(logging.TRACE):
  25. return
  26. L = 16
  27. for c in range(0, len(data), L):
  28. chars = data[c : c + L]
  29. dump = " ".join(f"{x:02x}" for x in chars)
  30. printable = "".join(HEXDUMP_FILTER[x] for x in chars)
  31. line = f"{c:08x} {dump:{L*3}s} |{printable:{L}s}|"
  32. logging.root.log(logging.TRACE, "%s", line)
  33. def __init__(self, server, vid, pid, path):
  34. self.stats = {
  35. "c2s packets": 0,
  36. "c2s bytes": 0,
  37. "s2c packets": 0,
  38. "s2c bytes": 0,
  39. }
  40. self.stats_logged = time.monotonic()
  41. def find_filter(dev):
  42. dev_path = path_from_usb_dev(dev)
  43. if path is not None:
  44. return dev_path == path
  45. return True
  46. dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
  47. if dev is None:
  48. raise ValueError("Device not found")
  49. logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")
  50. # dev.set_configuration() is not necessary since g_multi has only one
  51. usb9pfs = None
  52. # g_multi adds 9pfs as last interface
  53. cfg = dev.get_active_configuration()
  54. for intf in cfg:
  55. # we have to detach the usb-storage driver from multi gadget since
  56. # stall option could be set, which will lead to spontaneous port
  57. # resets and our transfers will run dead
  58. if intf.bInterfaceClass == 0x08:
  59. if dev.is_kernel_driver_active(intf.bInterfaceNumber):
  60. dev.detach_kernel_driver(intf.bInterfaceNumber)
  61. if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
  62. usb9pfs = intf
  63. if usb9pfs is None:
  64. raise ValueError("Interface not found")
  65. logging.info(f"claiming interface:\n{usb9pfs}")
  66. usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
  67. ep_out = usb.util.find_descriptor(
  68. usb9pfs,
  69. custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT,
  70. )
  71. assert ep_out is not None
  72. ep_in = usb.util.find_descriptor(
  73. usb9pfs,
  74. custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN,
  75. )
  76. assert ep_in is not None
  77. logging.info("interface claimed")
  78. self.ep_out = ep_out
  79. self.ep_in = ep_in
  80. self.dev = dev
  81. # create and connect socket
  82. self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  83. self.s.connect(server)
  84. logging.info("connected to server")
  85. def c2s(self):
  86. """forward a request from the USB client to the TCP server"""
  87. data = None
  88. while data is None:
  89. try:
  90. logging.log(logging.TRACE, "c2s: reading")
  91. data = self.ep_in.read(self.ep_in.wMaxPacketSize)
  92. except usb.core.USBTimeoutError:
  93. logging.log(logging.TRACE, "c2s: reading timed out")
  94. continue
  95. except usb.core.USBError as e:
  96. if e.errno == errno.EIO:
  97. logging.debug("c2s: reading failed with %s, retrying", repr(e))
  98. time.sleep(0.5)
  99. continue
  100. logging.error("c2s: reading failed with %s, aborting", repr(e))
  101. raise
  102. size = struct.unpack("<I", data[:4])[0]
  103. while len(data) < size:
  104. data += self.ep_in.read(size - len(data))
  105. logging.log(logging.TRACE, "c2s: writing")
  106. self._log_hexdump(data)
  107. self.s.send(data)
  108. logging.debug("c2s: forwarded %i bytes", size)
  109. self.stats["c2s packets"] += 1
  110. self.stats["c2s bytes"] += size
  111. def s2c(self):
  112. """forward a response from the TCP server to the USB client"""
  113. logging.log(logging.TRACE, "s2c: reading")
  114. data = self.s.recv(4)
  115. size = struct.unpack("<I", data[:4])[0]
  116. while len(data) < size:
  117. data += self.s.recv(size - len(data))
  118. logging.log(logging.TRACE, "s2c: writing")
  119. self._log_hexdump(data)
  120. while data:
  121. written = self.ep_out.write(data)
  122. assert written > 0
  123. data = data[written:]
  124. if size % self.ep_out.wMaxPacketSize == 0:
  125. logging.log(logging.TRACE, "sending zero length packet")
  126. self.ep_out.write(b"")
  127. logging.debug("s2c: forwarded %i bytes", size)
  128. self.stats["s2c packets"] += 1
  129. self.stats["s2c bytes"] += size
  130. def log_stats(self):
  131. logging.info("statistics:")
  132. for k, v in self.stats.items():
  133. logging.info(f" {k+':':14s} {v}")
  134. def log_stats_interval(self, interval=5):
  135. if (time.monotonic() - self.stats_logged) < interval:
  136. return
  137. self.log_stats()
  138. self.stats_logged = time.monotonic()
  139. def try_get_usb_str(dev, name):
  140. try:
  141. with open(f"/sys/bus/usb/devices/{dev.bus}-{dev.address}/{name}") as f:
  142. return f.read().strip()
  143. except FileNotFoundError:
  144. return None
  145. def list_usb(args):
  146. vid, pid = [int(x, 16) for x in args.id.split(":", 1)]
  147. print("Bus | Addr | Manufacturer | Product | ID | Path")
  148. print("--- | ---- | ---------------- | ---------------- | --------- | ----")
  149. for dev in usb.core.find(find_all=True, idVendor=vid, idProduct=pid):
  150. path = path_from_usb_dev(dev) or ""
  151. manufacturer = try_get_usb_str(dev, "manufacturer") or "unknown"
  152. product = try_get_usb_str(dev, "product") or "unknown"
  153. print(
  154. f"{dev.bus:3} | {dev.address:4} | {manufacturer:16} | {product:16} | {dev.idVendor:04x}:{dev.idProduct:04x} | {path:18}"
  155. )
  156. def connect(args):
  157. vid, pid = [int(x, 16) for x in args.id.split(":", 1)]
  158. f = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path)
  159. try:
  160. while True:
  161. f.c2s()
  162. f.s2c()
  163. f.log_stats_interval()
  164. finally:
  165. f.log_stats()
  166. def main():
  167. parser = argparse.ArgumentParser(
  168. description="Forward 9PFS requests from USB to TCP",
  169. )
  170. parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
  171. parser.add_argument("--path", type=str, required=False, help="path of target device")
  172. parser.add_argument("-v", "--verbose", action="count", default=0)
  173. subparsers = parser.add_subparsers()
  174. subparsers.required = True
  175. subparsers.dest = "command"
  176. parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
  177. parser_list.set_defaults(func=list_usb)
  178. parser_connect = subparsers.add_parser(
  179. "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
  180. )
  181. parser_connect.set_defaults(func=connect)
  182. connect_group = parser_connect.add_argument_group()
  183. connect_group.required = True
  184. parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
  185. parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")
  186. args = parser.parse_args()
  187. logging.TRACE = logging.DEBUG - 5
  188. logging.addLevelName(logging.TRACE, "TRACE")
  189. if args.verbose >= 2:
  190. level = logging.TRACE
  191. elif args.verbose:
  192. level = logging.DEBUG
  193. else:
  194. level = logging.INFO
  195. logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")
  196. args.func(args)
  197. if __name__ == "__main__":
  198. main()