You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

509 lines
17 KiB

  1. #
  2. # This file is part of usb_protocol.
  3. #
  4. """ Convenience emitters for simple, standard descriptors. """
  5. import unittest
  6. from contextlib import contextmanager
  7. from .. import emitter_for_format
  8. from ..descriptor import ComplexDescriptorEmitter
  9. from ...types import LanguageIDs
  10. from ...types.descriptors.standard import \
  11. DeviceDescriptor, StringDescriptor, EndpointDescriptor, DeviceQualifierDescriptor, \
  12. ConfigurationDescriptor, InterfaceDescriptor, StandardDescriptorNumbers, StringLanguageDescriptor, \
  13. BinaryObjectStoreDescriptor, SuperSpeedEndpointCompanionDescriptor
  14. # Create our basic emitters...
  15. DeviceDescriptorEmitter = emitter_for_format(DeviceDescriptor)
  16. StringDescriptorEmitter = emitter_for_format(StringDescriptor)
  17. StringLanguageDescriptorEmitter = emitter_for_format(StringLanguageDescriptor)
  18. DeviceQualifierDescriptor = emitter_for_format(DeviceQualifierDescriptor)
  19. # ... our basic superspeed emitters ...
  20. BinaryObjectStoreDescriptorEmitter = emitter_for_format(BinaryObjectStoreDescriptor)
  21. SuperSpeedEndpointCompanionDescriptorEmitter = emitter_for_format(SuperSpeedEndpointCompanionDescriptor)
  22. # ... convenience functions ...
  23. def get_string_descriptor(string):
  24. """ Generates a string descriptor for the relevant string. """
  25. emitter = StringDescriptorEmitter()
  26. emitter.bString = string
  27. return emitter.emit()
  28. # ... and complex emitters.
  29. class EndpointDescriptorEmitter(ComplexDescriptorEmitter):
  30. """ Emitter that creates an InterfaceDescriptor. """
  31. DESCRIPTOR_FORMAT = EndpointDescriptor
  32. @contextmanager
  33. def SuperSpeedCompanion(self):
  34. """ Context manager that allows addition of a SuperSpeed Companion to this endpoint descriptor.
  35. It can be used with a `with` statement; and yields an SuperSpeedEndpointCompanionDescriptorEmitter
  36. that can be populated:
  37. with endpoint.SuperSpeedEndpointCompanion() as d:
  38. d.bMaxBurst = 1
  39. This adds the relevant descriptor, automatically.
  40. """
  41. descriptor = SuperSpeedEndpointCompanionDescriptorEmitter()
  42. yield descriptor
  43. self.add_subordinate_descriptor(descriptor)
  44. class InterfaceDescriptorEmitter(ComplexDescriptorEmitter):
  45. """ Emitter that creates an InterfaceDescriptor. """
  46. DESCRIPTOR_FORMAT = InterfaceDescriptor
  47. @contextmanager
  48. def EndpointDescriptor(self, *, add_default_superspeed=False):
  49. """ Context manager that allows addition of a subordinate endpoint descriptor.
  50. It can be used with a `with` statement; and yields an EndpointDesriptorEmitter
  51. that can be populated:
  52. with interface.EndpointDescriptor() as d:
  53. d.bEndpointAddress = 0x01
  54. d.bmAttributes = 0x80
  55. d.wMaxPacketSize = 64
  56. d.bInterval = 0
  57. This adds the relevant descriptor, automatically.
  58. """
  59. descriptor = EndpointDescriptorEmitter()
  60. yield descriptor
  61. # If we're adding a default SuperSpeed extension, do so.
  62. if add_default_superspeed:
  63. with descriptor.SuperSpeedCompanion():
  64. pass
  65. self.add_subordinate_descriptor(descriptor)
  66. def _pre_emit(self):
  67. # Count our endpoints, and update our internal count.
  68. self.bNumEndpoints = self._type_counts[StandardDescriptorNumbers.ENDPOINT]
  69. # Ensure that our interface string is an index, if we can.
  70. if self._collection and hasattr(self, 'iInterface'):
  71. self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface)
  72. class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter):
  73. """ Emitter that creates a configuration descriptor. """
  74. DESCRIPTOR_FORMAT = ConfigurationDescriptor
  75. @contextmanager
  76. def InterfaceDescriptor(self):
  77. """ Context manager that allows addition of a subordinate interface descriptor.
  78. It can be used with a `with` statement; and yields an InterfaceDescriptorEmitter
  79. that can be populated:
  80. with interface.InterfaceDescriptor() as d:
  81. d.bInterfaceNumber = 0x01
  82. [snip]
  83. This adds the relevant descriptor, automatically. Note that populating derived
  84. fields such as bNumEndpoints aren't necessary; they'll be populated automatically.
  85. """
  86. descriptor = InterfaceDescriptorEmitter(collection=self._collection)
  87. yield descriptor
  88. self.add_subordinate_descriptor(descriptor)
  89. def _pre_emit(self):
  90. # Count our interfaces.
  91. self.bNumInterfaces = self._type_counts[StandardDescriptorNumbers.INTERFACE]
  92. # Figure out our total length.
  93. subordinate_length = sum(len(sub) for sub in self._subordinates)
  94. self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()
  95. # Ensure that our configuration string is an index, if we can.
  96. if self._collection and hasattr(self, 'iConfiguration'):
  97. self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration)
  98. class DeviceDescriptorCollection:
  99. """ Object that builds a full collection of descriptors related to a given USB device. """
  100. # Most systems seem happiest with en_US (ugh), so default to that.
  101. DEFAULT_SUPPORTED_LANGUAGES = [LanguageIDs.ENGLISH_US]
  102. def __init__(self, automatic_language_descriptor=True):
  103. """
  104. Parameters:
  105. automatic_language_descriptor -- If set or not provided, a language descriptor will automatically
  106. be added if none exists.
  107. """
  108. self._automatic_language_descriptor = automatic_language_descriptor
  109. # Create our internal descriptor tracker.
  110. # Keys are a tuple of (type, index).
  111. self._descriptors = {}
  112. # Track string descriptors as they're created.
  113. self._next_string_index = 1
  114. self._index_for_string = {}
  115. def ensure_string_field_is_index(self, field_value):
  116. """ Processes the given field value; if it's not an string index, converts it to one.
  117. Non-index-fields are converted to indices using `get_index_for_string`, which automatically
  118. adds the relevant fields to our string descriptor collection.
  119. """
  120. if isinstance(field_value, str):
  121. return self.get_index_for_string(field_value)
  122. else:
  123. return field_value
  124. def get_index_for_string(self, string):
  125. """ Returns an string descriptor index for the given string.
  126. If a string descriptor already exists for the given string, its index is
  127. returned. Otherwise, a string descriptor is created.
  128. """
  129. # If we already have a descriptor for this string, return it.
  130. if string in self._index_for_string:
  131. return self._index_for_string[string]
  132. # Otherwise, create one:
  133. # Allocate an index...
  134. index = self._next_string_index
  135. self._index_for_string[string] = index
  136. self._next_string_index += 1
  137. # ... store our string descriptor with it ...
  138. identifier = StandardDescriptorNumbers.STRING, index
  139. self._descriptors[identifier] = get_string_descriptor(string)
  140. # ... and return our index.
  141. return index
  142. def add_descriptor(self, descriptor, index=0):
  143. """ Adds a descriptor to our collection.
  144. Parameters:
  145. descriptor -- The descriptor to be added.
  146. index -- The index of the relevant descriptor. Defaults to 0.
  147. """
  148. # If this is an emitter rather than a descriptor itself, convert it.
  149. if hasattr(descriptor, 'emit'):
  150. descriptor = descriptor.emit()
  151. # Figure out the identifier (type + index) for this descriptor...
  152. descriptor_type = descriptor[1]
  153. identifier = descriptor_type, index
  154. # ... and store it.
  155. self._descriptors[identifier] = descriptor
  156. def add_language_descriptor(self, supported_languages=None):
  157. """ Adds a language descriptor to the list of device descriptors.
  158. Parameters:
  159. supported_languages -- A list of languages supported by the device.
  160. """
  161. if supported_languages is None:
  162. supported_languages = self.DEFAULT_SUPPORTED_LANGUAGES
  163. descriptor = StringLanguageDescriptorEmitter()
  164. descriptor.wLANGID = supported_languages
  165. self.add_descriptor(descriptor)
  166. @contextmanager
  167. def DeviceDescriptor(self):
  168. """ Context manager that allows addition of a device descriptor.
  169. It can be used with a `with` statement; and yields an DeviceDescriptorEmitter
  170. that can be populated:
  171. with collection.DeviceDescriptor() as d:
  172. d.idVendor = 0xabcd
  173. d.idProduct = 0x1234
  174. [snip]
  175. This adds the relevant descriptor, automatically.
  176. """
  177. descriptor = DeviceDescriptorEmitter()
  178. yield descriptor
  179. # If we have any string fields, ensure that they're indices before continuing.
  180. for field in ('iManufacturer', 'iProduct', 'iSerialNumber'):
  181. if hasattr(descriptor, field):
  182. value = getattr(descriptor, field)
  183. index = self.ensure_string_field_is_index(value)
  184. setattr(descriptor, field, index)
  185. self.add_descriptor(descriptor)
  186. @contextmanager
  187. def ConfigurationDescriptor(self):
  188. """ Context manager that allows addition of a configuration descriptor.
  189. It can be used with a `with` statement; and yields an ConfigurationDescriptorEmitter
  190. that can be populated:
  191. with collection.ConfigurationDescriptor() as d:
  192. d.bConfigurationValue = 1
  193. [snip]
  194. This adds the relevant descriptor, automatically. Note that populating derived
  195. fields such as bNumInterfaces aren't necessary; they'll be populated automatically.
  196. """
  197. descriptor = ConfigurationDescriptorEmitter()
  198. yield descriptor
  199. self.add_descriptor(descriptor)
  200. def _ensure_has_language_descriptor(self):
  201. """ ensures that we have a language descriptor; adding one if necessary."""
  202. # if we're not automatically adding a language descriptor, we shouldn't do anything,
  203. # and we'll just ignore this.
  204. if not self._automatic_language_descriptor:
  205. return
  206. # if we don't have a language descriptor, add our default one.
  207. if (StandardDescriptorNumbers.STRING, 0) not in self._descriptors:
  208. self.add_language_descriptor()
  209. def get_descriptor_bytes(self, type_number: int, index: int = 0):
  210. """ Returns the raw, binary descriptor for a given descriptor type/index.
  211. Parmeters:
  212. type_number -- The descriptor type number.
  213. index -- The index of the relevant descriptor, if relevant.
  214. """
  215. # If this is a request for a language descriptor, return one.
  216. if (type_number, index) == (StandardDescriptorNumbers.STRING, 0):
  217. self._ensure_has_language_descriptor()
  218. return self._descriptors[(type_number, index)]
  219. def __iter__(self):
  220. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  221. self._ensure_has_language_descriptor()
  222. return ((number, index, desc) for ((number, index), desc) in self._descriptors.items())
  223. class SuperSpeedDeviceDescriptorCollection(DeviceDescriptorCollection):
  224. """ Object that builds a full collection of descriptors related to a given USB3 device. """
  225. def __init__(self, automatic_descriptors=True):
  226. """
  227. Parameters:
  228. automatic_descriptors -- If set or not provided, certian required descriptors will be
  229. be added if none exists.
  230. """
  231. self._automatic_descriptors = automatic_descriptors
  232. super().__init__(automatic_language_descriptor=automatic_descriptors)
  233. @contextmanager
  234. def BOSDescriptor(self):
  235. """ Context manager that allows addition of a Binary Object Store descriptor.
  236. It can be used with a `with` statement; and yields an BinaryObjectStoreDescriptorEmitter
  237. that can be populated:
  238. with collection.BOSDescriptor() as d:
  239. [snip]
  240. This adds the relevant descriptor, automatically. Note that populating derived
  241. fields such as bNumDeviceCaps aren't necessary; they'll be populated automatically.
  242. """
  243. descriptor = BinaryObjectStoreDescriptorEmitter()
  244. yield descriptor
  245. self.add_descriptor(descriptor)
  246. def add_default_bos_descriptor(self):
  247. """ Adds a default, empty BOS descriptor. """
  248. descriptor = BinaryObjectStoreDescriptorEmitter()
  249. self.add_descriptor(descriptor)
  250. def _ensure_has_bos_descriptor(self):
  251. """ Ensures that we have a BOS descriptor; adding one if necessary."""
  252. # If we're not automatically adding a language descriptor, we shouldn't do anything,
  253. # and we'll just ignore this.
  254. if not self._automatic_descriptors:
  255. return
  256. # If we don't have a language descriptor, add our default one.
  257. if (StandardDescriptorNumbers.BOS, 0) not in self._descriptors:
  258. self.add_default_bos_descriptor()
  259. def __iter__(self):
  260. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  261. self._ensure_has_bos_descriptor()
  262. return super().__iter__()
  263. class EmitterTests(unittest.TestCase):
  264. def test_string_emitter(self):
  265. emitter = StringDescriptorEmitter()
  266. emitter.bString = "Hello"
  267. self.assertEqual(emitter.emit(), b"\x0C\x03H\0e\0l\0l\0o\0")
  268. def test_string_emitter_function(self):
  269. self.assertEqual(get_string_descriptor("Hello"), b"\x0C\x03H\0e\0l\0l\0o\0")
  270. def test_configuration_emitter(self):
  271. descriptor = bytes([
  272. # config descriptor
  273. 12, # length
  274. 2, # type
  275. 25, 00, # total length
  276. 1, # num interfaces
  277. 1, # configuration number
  278. 0, # config string
  279. 0x80, # attributes
  280. 250, # max power
  281. # interface descriptor
  282. 9, # length
  283. 4, # type
  284. 0, # number
  285. 0, # alternate
  286. 1, # num endpoints
  287. 0xff, # class
  288. 0xff, # subclass
  289. 0xff, # protocol
  290. 0, # string
  291. # endpoint descriptor
  292. 7, # length
  293. 5, # type
  294. 0x01, # address
  295. 2, # attributes
  296. 64, 0, # max packet size
  297. 255, # interval
  298. ])
  299. # Create a trivial configuration descriptor...
  300. emitter = ConfigurationDescriptorEmitter()
  301. with emitter.InterfaceDescriptor() as interface:
  302. interface.bInterfaceNumber = 0
  303. with interface.EndpointDescriptor() as endpoint:
  304. endpoint.bEndpointAddress = 1
  305. # ... and validate that it maches our reference descriptor.
  306. binary = emitter.emit()
  307. self.assertEqual(len(binary), len(descriptor))
  308. def test_descriptor_collection(self):
  309. collection = DeviceDescriptorCollection()
  310. with collection.DeviceDescriptor() as d:
  311. d.idVendor = 0xdead
  312. d.idProduct = 0xbeef
  313. d.bNumConfigurations = 1
  314. d.iManufacturer = "Test Company"
  315. d.iProduct = "Test Product"
  316. with collection.ConfigurationDescriptor() as c:
  317. c.bConfigurationValue = 1
  318. with c.InterfaceDescriptor() as i:
  319. i.bInterfaceNumber = 1
  320. with i.EndpointDescriptor() as e:
  321. e.bEndpointAddress = 0x81
  322. with i.EndpointDescriptor() as e:
  323. e.bEndpointAddress = 0x01
  324. results = list(collection)
  325. # We should wind up with four descriptor entries, as our endpoint/interface descriptors are
  326. # included in our configuration descriptor.
  327. self.assertEqual(len(results), 5)
  328. # Supported languages string.
  329. self.assertIn((3, 0, b'\x04\x03\x09\x04'), results)
  330. # Manufacturer / product string.
  331. self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results)
  332. self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results)
  333. # Device descriptor.
  334. self.assertIn((1, 0, b'\x12\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results)
  335. # Configuration descriptor, with subordinates.
  336. self.assertIn((2, 0, b'\t\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results)
  337. def test_empty_descriptor_collection(self):
  338. collection = DeviceDescriptorCollection(automatic_language_descriptor=False)
  339. results = list(collection)
  340. self.assertEqual(len(results), 0)
  341. def test_automatic_language_descriptor(self):
  342. collection = DeviceDescriptorCollection(automatic_language_descriptor=True)
  343. results = list(collection)
  344. self.assertEqual(len(results), 1)
  345. if __name__ == "__main__":
  346. unittest.main()