You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

572 lines
20 KiB

  1. #
  2. # This file is part of usb_protocol.
  3. #
  4. """ Convenience emitters for simple, standard descriptors. """
  5. import unittest
  6. from contextlib import contextmanager
  7. from .. import emitter_for_format
  8. from ..descriptor import ComplexDescriptorEmitter
  9. from ...types import LanguageIDs
  10. from ...types.descriptors.standard import *
  11. # Create our basic emitters...
  12. DeviceDescriptorEmitter = emitter_for_format(DeviceDescriptor)
  13. StringDescriptorEmitter = emitter_for_format(StringDescriptor)
  14. StringLanguageDescriptorEmitter = emitter_for_format(StringLanguageDescriptor)
  15. DeviceQualifierDescriptor = emitter_for_format(DeviceQualifierDescriptor)
  16. # ... our basic superspeed emitters ...
  17. USB2ExtensionDescriptorEmitter = emitter_for_format(USB2ExtensionDescriptor)
  18. SuperSpeedUSBDeviceCapabilityDescriptorEmitter = emitter_for_format(SuperSpeedUSBDeviceCapabilityDescriptor)
  19. SuperSpeedEndpointCompanionDescriptorEmitter = emitter_for_format(SuperSpeedEndpointCompanionDescriptor)
  20. # ... convenience functions ...
  21. def get_string_descriptor(string):
  22. """ Generates a string descriptor for the relevant string. """
  23. emitter = StringDescriptorEmitter()
  24. emitter.bString = string
  25. return emitter.emit()
  26. # ... and complex emitters.
  27. class EndpointDescriptorEmitter(ComplexDescriptorEmitter):
  28. """ Emitter that creates an InterfaceDescriptor. """
  29. DESCRIPTOR_FORMAT = EndpointDescriptor
  30. @contextmanager
  31. def SuperSpeedCompanion(self):
  32. """ Context manager that allows addition of a SuperSpeed Companion to this endpoint descriptor.
  33. It can be used with a `with` statement; and yields an SuperSpeedEndpointCompanionDescriptorEmitter
  34. that can be populated:
  35. with endpoint.SuperSpeedEndpointCompanion() as d:
  36. d.bMaxBurst = 1
  37. This adds the relevant descriptor, automatically.
  38. """
  39. descriptor = SuperSpeedEndpointCompanionDescriptorEmitter()
  40. yield descriptor
  41. self.add_subordinate_descriptor(descriptor)
  42. class InterfaceDescriptorEmitter(ComplexDescriptorEmitter):
  43. """ Emitter that creates an InterfaceDescriptor. """
  44. DESCRIPTOR_FORMAT = InterfaceDescriptor
  45. @contextmanager
  46. def EndpointDescriptor(self, *, add_default_superspeed=False):
  47. """ Context manager that allows addition of a subordinate endpoint descriptor.
  48. It can be used with a `with` statement; and yields an EndpointDesriptorEmitter
  49. that can be populated:
  50. with interface.EndpointDescriptor() as d:
  51. d.bEndpointAddress = 0x01
  52. d.bmAttributes = 0x80
  53. d.wMaxPacketSize = 64
  54. d.bInterval = 0
  55. This adds the relevant descriptor, automatically.
  56. """
  57. descriptor = EndpointDescriptorEmitter()
  58. yield descriptor
  59. # If we're adding a default SuperSpeed extension, do so.
  60. if add_default_superspeed:
  61. with descriptor.SuperSpeedCompanion():
  62. pass
  63. self.add_subordinate_descriptor(descriptor)
  64. def _pre_emit(self):
  65. # Count our endpoints, and update our internal count.
  66. self.bNumEndpoints = self._type_counts[StandardDescriptorNumbers.ENDPOINT]
  67. # Ensure that our interface string is an index, if we can.
  68. if self._collection and hasattr(self, 'iInterface'):
  69. self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface)
  70. class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter):
  71. """ Emitter that creates a configuration descriptor. """
  72. DESCRIPTOR_FORMAT = ConfigurationDescriptor
  73. @contextmanager
  74. def InterfaceDescriptor(self):
  75. """ Context manager that allows addition of a subordinate interface descriptor.
  76. It can be used with a `with` statement; and yields an InterfaceDescriptorEmitter
  77. that can be populated:
  78. with interface.InterfaceDescriptor() as d:
  79. d.bInterfaceNumber = 0x01
  80. [snip]
  81. This adds the relevant descriptor, automatically. Note that populating derived
  82. fields such as bNumEndpoints aren't necessary; they'll be populated automatically.
  83. """
  84. descriptor = InterfaceDescriptorEmitter(collection=self._collection)
  85. yield descriptor
  86. self.add_subordinate_descriptor(descriptor)
  87. def _pre_emit(self):
  88. # Count our interfaces. Alternate settings of the same interface do not count multiple times.
  89. self.bNumInterfaces = len(set([subordinate[2] for subordinate in self._subordinates if (subordinate[1] == StandardDescriptorNumbers.INTERFACE)]))
  90. # Figure out our total length.
  91. subordinate_length = sum(len(sub) for sub in self._subordinates)
  92. self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()
  93. # Ensure that our configuration string is an index, if we can.
  94. if self._collection and hasattr(self, 'iConfiguration'):
  95. self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration)
  96. class DeviceDescriptorCollection:
  97. """ Object that builds a full collection of descriptors related to a given USB device. """
  98. # Most systems seem happiest with en_US (ugh), so default to that.
  99. DEFAULT_SUPPORTED_LANGUAGES = [LanguageIDs.ENGLISH_US]
  100. def __init__(self, automatic_language_descriptor=True):
  101. """
  102. Parameters:
  103. automatic_language_descriptor -- If set or not provided, a language descriptor will automatically
  104. be added if none exists.
  105. """
  106. self._automatic_language_descriptor = automatic_language_descriptor
  107. # Create our internal descriptor tracker.
  108. # Keys are a tuple of (type, index).
  109. self._descriptors = {}
  110. # Track string descriptors as they're created.
  111. self._next_string_index = 1
  112. self._index_for_string = {}
  113. def ensure_string_field_is_index(self, field_value):
  114. """ Processes the given field value; if it's not an string index, converts it to one.
  115. Non-index-fields are converted to indices using `get_index_for_string`, which automatically
  116. adds the relevant fields to our string descriptor collection.
  117. """
  118. if isinstance(field_value, str):
  119. return self.get_index_for_string(field_value)
  120. else:
  121. return field_value
  122. def get_index_for_string(self, string):
  123. """ Returns an string descriptor index for the given string.
  124. If a string descriptor already exists for the given string, its index is
  125. returned. Otherwise, a string descriptor is created.
  126. """
  127. # If we already have a descriptor for this string, return it.
  128. if string in self._index_for_string:
  129. return self._index_for_string[string]
  130. # Otherwise, create one:
  131. # Allocate an index...
  132. index = self._next_string_index
  133. self._index_for_string[string] = index
  134. self._next_string_index += 1
  135. # ... store our string descriptor with it ...
  136. identifier = StandardDescriptorNumbers.STRING, index
  137. self._descriptors[identifier] = get_string_descriptor(string)
  138. # ... and return our index.
  139. return index
  140. def add_descriptor(self, descriptor, index=0, descriptor_type=None):
  141. """ Adds a descriptor to our collection.
  142. Parameters:
  143. descriptor -- The descriptor to be added.
  144. index -- The index of the relevant descriptor. Defaults to 0.
  145. descriptor_type -- The type of the descriptor to be added. If `None`,
  146. this is automatically derived from the descriptor contents.
  147. """
  148. # If this is an emitter rather than a descriptor itself, convert it.
  149. if hasattr(descriptor, 'emit'):
  150. descriptor = descriptor.emit()
  151. # Figure out the identifier (type + index) for this descriptor...
  152. if (descriptor_type is None):
  153. descriptor_type = descriptor[1]
  154. identifier = descriptor_type, index
  155. # ... and store it.
  156. self._descriptors[identifier] = descriptor
  157. def add_language_descriptor(self, supported_languages=None):
  158. """ Adds a language descriptor to the list of device descriptors.
  159. Parameters:
  160. supported_languages -- A list of languages supported by the device.
  161. """
  162. if supported_languages is None:
  163. supported_languages = self.DEFAULT_SUPPORTED_LANGUAGES
  164. descriptor = StringLanguageDescriptorEmitter()
  165. descriptor.wLANGID = supported_languages
  166. self.add_descriptor(descriptor)
  167. @contextmanager
  168. def DeviceDescriptor(self):
  169. """ Context manager that allows addition of a device descriptor.
  170. It can be used with a `with` statement; and yields an DeviceDescriptorEmitter
  171. that can be populated:
  172. with collection.DeviceDescriptor() as d:
  173. d.idVendor = 0xabcd
  174. d.idProduct = 0x1234
  175. [snip]
  176. This adds the relevant descriptor, automatically.
  177. """
  178. descriptor = DeviceDescriptorEmitter()
  179. yield descriptor
  180. # If we have any string fields, ensure that they're indices before continuing.
  181. for field in ('iManufacturer', 'iProduct', 'iSerialNumber'):
  182. if hasattr(descriptor, field):
  183. value = getattr(descriptor, field)
  184. index = self.ensure_string_field_is_index(value)
  185. setattr(descriptor, field, index)
  186. self.add_descriptor(descriptor)
  187. @contextmanager
  188. def ConfigurationDescriptor(self):
  189. """ Context manager that allows addition of a configuration descriptor.
  190. It can be used with a `with` statement; and yields an ConfigurationDescriptorEmitter
  191. that can be populated:
  192. with collection.ConfigurationDescriptor() as d:
  193. d.bConfigurationValue = 1
  194. [snip]
  195. This adds the relevant descriptor, automatically. Note that populating derived
  196. fields such as bNumInterfaces aren't necessary; they'll be populated automatically.
  197. """
  198. descriptor = ConfigurationDescriptorEmitter(collection=self)
  199. yield descriptor
  200. self.add_descriptor(descriptor)
  201. def _ensure_has_language_descriptor(self):
  202. """ ensures that we have a language descriptor; adding one if necessary."""
  203. # if we're not automatically adding a language descriptor, we shouldn't do anything,
  204. # and we'll just ignore this.
  205. if not self._automatic_language_descriptor:
  206. return
  207. # if we don't have a language descriptor, add our default one.
  208. if (StandardDescriptorNumbers.STRING, 0) not in self._descriptors:
  209. self.add_language_descriptor()
  210. def get_descriptor_bytes(self, type_number: int, index: int = 0):
  211. """ Returns the raw, binary descriptor for a given descriptor type/index.
  212. Parmeters:
  213. type_number -- The descriptor type number.
  214. index -- The index of the relevant descriptor, if relevant.
  215. """
  216. # If this is a request for a language descriptor, return one.
  217. if (type_number, index) == (StandardDescriptorNumbers.STRING, 0):
  218. self._ensure_has_language_descriptor()
  219. return self._descriptors[(type_number, index)]
  220. def __iter__(self):
  221. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  222. self._ensure_has_language_descriptor()
  223. return ((number, index, desc) for ((number, index), desc) in self._descriptors.items())
  224. class BinaryObjectStoreDescriptorEmitter(ComplexDescriptorEmitter):
  225. """ Emitter that creates a BinaryObjectStore descriptor. """
  226. DESCRIPTOR_FORMAT = BinaryObjectStoreDescriptor
  227. @contextmanager
  228. def USB2Extension(self):
  229. """ Context manager that allows addition of a USB 2.0 Extension to this Binary Object Store.
  230. It can be used with a `with` statement; and yields an USB2ExtensionDescriptorEmitter
  231. that can be populated:
  232. with bos.USB2Extension() as e:
  233. e.bmAttributes = 1
  234. This adds the relevant descriptor, automatically.
  235. """
  236. descriptor = USB2ExtensionDescriptorEmitter()
  237. yield descriptor
  238. self.add_subordinate_descriptor(descriptor)
  239. @contextmanager
  240. def SuperSpeedUSBDeviceCapability(self):
  241. """ Context manager that allows addition of a SS Device Capability to this Binary Object Store.
  242. It can be used with a `with` statement; and yields an SuperSpeedUSBDeviceCapabilityDescriptorEmitter
  243. that can be populated:
  244. with bos.SuperSpeedUSBDeviceCapability() as e:
  245. e.wSpeedSupported = 0b1110
  246. e.bFunctionalitySupport = 1
  247. This adds the relevant descriptor, automatically.
  248. """
  249. descriptor = SuperSpeedUSBDeviceCapabilityDescriptorEmitter()
  250. yield descriptor
  251. self.add_subordinate_descriptor(descriptor)
  252. def _pre_emit(self):
  253. # Figure out the total length of our descriptor, including subordinates.
  254. subordinate_length = sum(len(sub) for sub in self._subordinates)
  255. self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()
  256. # Count our subordinate descriptors, and update our internal count.
  257. self.bNumDeviceCaps = len(self._subordinates)
  258. class SuperSpeedDeviceDescriptorCollection(DeviceDescriptorCollection):
  259. """ Object that builds a full collection of descriptors related to a given USB3 device. """
  260. def __init__(self, automatic_descriptors=True):
  261. """
  262. Parameters:
  263. automatic_descriptors -- If set or not provided, certian required descriptors will be
  264. be added if none exists.
  265. """
  266. self._automatic_descriptors = automatic_descriptors
  267. super().__init__(automatic_language_descriptor=automatic_descriptors)
  268. @contextmanager
  269. def BOSDescriptor(self):
  270. """ Context manager that allows addition of a Binary Object Store descriptor.
  271. It can be used with a `with` statement; and yields an BinaryObjectStoreDescriptorEmitter
  272. that can be populated:
  273. with collection.BOSDescriptor() as d:
  274. [snip]
  275. This adds the relevant descriptor, automatically. Note that populating derived
  276. fields such as bNumDeviceCaps aren't necessary; they'll be populated automatically.
  277. """
  278. descriptor = BinaryObjectStoreDescriptorEmitter()
  279. yield descriptor
  280. self.add_descriptor(descriptor)
  281. def add_default_bos_descriptor(self):
  282. """ Adds a default, empty BOS descriptor. """
  283. # Create an empty BOS descriptor...
  284. descriptor = BinaryObjectStoreDescriptorEmitter()
  285. # ... populate our default required descriptors...
  286. descriptor.add_subordinate_descriptor(USB2ExtensionDescriptorEmitter())
  287. descriptor.add_subordinate_descriptor(SuperSpeedUSBDeviceCapabilityDescriptorEmitter())
  288. # ... and add it to our overall BOS descriptor.
  289. self.add_descriptor(descriptor)
  290. def _ensure_has_bos_descriptor(self):
  291. """ Ensures that we have a BOS descriptor; adding one if necessary."""
  292. # If we're not automatically adding a language descriptor, we shouldn't do anything,
  293. # and we'll just ignore this.
  294. if not self._automatic_descriptors:
  295. return
  296. # If we don't have a language descriptor, add our default one.
  297. if (StandardDescriptorNumbers.BOS, 0) not in self._descriptors:
  298. self.add_default_bos_descriptor()
  299. def __iter__(self):
  300. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  301. self._ensure_has_bos_descriptor()
  302. return super().__iter__()
  303. class EmitterTests(unittest.TestCase):
  304. def test_string_emitter(self):
  305. emitter = StringDescriptorEmitter()
  306. emitter.bString = "Hello"
  307. self.assertEqual(emitter.emit(), b"\x0C\x03H\0e\0l\0l\0o\0")
  308. def test_string_emitter_function(self):
  309. self.assertEqual(get_string_descriptor("Hello"), b"\x0C\x03H\0e\0l\0l\0o\0")
  310. def test_configuration_emitter(self):
  311. descriptor = bytes([
  312. # config descriptor
  313. 12, # length
  314. 2, # type
  315. 25, 00, # total length
  316. 1, # num interfaces
  317. 1, # configuration number
  318. 0, # config string
  319. 0x80, # attributes
  320. 250, # max power
  321. # interface descriptor
  322. 9, # length
  323. 4, # type
  324. 0, # number
  325. 0, # alternate
  326. 1, # num endpoints
  327. 0xff, # class
  328. 0xff, # subclass
  329. 0xff, # protocol
  330. 0, # string
  331. # endpoint descriptor
  332. 7, # length
  333. 5, # type
  334. 0x01, # address
  335. 2, # attributes
  336. 64, 0, # max packet size
  337. 255, # interval
  338. ])
  339. # Create a trivial configuration descriptor...
  340. emitter = ConfigurationDescriptorEmitter()
  341. with emitter.InterfaceDescriptor() as interface:
  342. interface.bInterfaceNumber = 0
  343. with interface.EndpointDescriptor() as endpoint:
  344. endpoint.bEndpointAddress = 1
  345. # ... and validate that it maches our reference descriptor.
  346. binary = emitter.emit()
  347. self.assertEqual(len(binary), len(descriptor))
  348. def test_descriptor_collection(self):
  349. collection = DeviceDescriptorCollection()
  350. with collection.DeviceDescriptor() as d:
  351. d.idVendor = 0xdead
  352. d.idProduct = 0xbeef
  353. d.bNumConfigurations = 1
  354. d.iManufacturer = "Test Company"
  355. d.iProduct = "Test Product"
  356. with collection.ConfigurationDescriptor() as c:
  357. c.bConfigurationValue = 1
  358. with c.InterfaceDescriptor() as i:
  359. i.bInterfaceNumber = 1
  360. with i.EndpointDescriptor() as e:
  361. e.bEndpointAddress = 0x81
  362. with i.EndpointDescriptor() as e:
  363. e.bEndpointAddress = 0x01
  364. results = list(collection)
  365. # We should wind up with four descriptor entries, as our endpoint/interface descriptors are
  366. # included in our configuration descriptor.
  367. self.assertEqual(len(results), 5)
  368. # Supported languages string.
  369. self.assertIn((3, 0, b'\x04\x03\x09\x04'), results)
  370. # Manufacturer / product string.
  371. self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results)
  372. self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results)
  373. # Device descriptor.
  374. self.assertIn((1, 0, b'\x12\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results)
  375. # Configuration descriptor, with subordinates.
  376. self.assertIn((2, 0, b'\t\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results)
  377. def test_empty_descriptor_collection(self):
  378. collection = DeviceDescriptorCollection(automatic_language_descriptor=False)
  379. results = list(collection)
  380. self.assertEqual(len(results), 0)
  381. def test_automatic_language_descriptor(self):
  382. collection = DeviceDescriptorCollection(automatic_language_descriptor=True)
  383. results = list(collection)
  384. self.assertEqual(len(results), 1)
  385. if __name__ == "__main__":
  386. unittest.main()