You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

417 lines
14 KiB

  1. #
  2. # This file is part of usb_protocol.
  3. #
  4. """ Convenience emitters for simple, standard descriptors. """
  5. import unittest
  6. from contextlib import contextmanager
  7. from .. import emitter_for_format
  8. from ..descriptor import ComplexDescriptorEmitter
  9. from ...types import LanguageIDs
  10. from ...types.descriptors.standard import \
  11. DeviceDescriptor, StringDescriptor, EndpointDescriptor, DeviceQualifierDescriptor, \
  12. ConfigurationDescriptor, InterfaceDescriptor, StandardDescriptorNumbers, StringLanguageDescriptor
  # Create our basic emitters; each one is built by emitter_for_format from the
  # descriptor format of the same name.
  DeviceDescriptorEmitter = emitter_for_format(DeviceDescriptor)
  StringDescriptorEmitter = emitter_for_format(StringDescriptor)
  StringLanguageDescriptorEmitter = emitter_for_format(StringLanguageDescriptor)
  EndpointDescriptorEmitter = emitter_for_format(EndpointDescriptor)
  DeviceQualifierDescriptorEmitter = emitter_for_format(DeviceQualifierDescriptor)

  # Backward-compatible alias: this module historically published the device qualifier
  # emitter under the name of the descriptor format itself (shadowing the imported format).
  # Prefer DeviceQualifierDescriptorEmitter, which follows the naming of the emitters above;
  # the old name is kept so that existing importers continue to receive the emitter.
  DeviceQualifierDescriptor = DeviceQualifierDescriptorEmitter
  19. # ... convenience functions ...
  def get_string_descriptor(string):
      """ Builds a string descriptor that carries the provided string.

      Parameters:
          string -- The text to place into the descriptor's bString field.

      Returns the result of emitting a freshly-populated StringDescriptorEmitter.
      """
      string_emitter = StringDescriptorEmitter()
      string_emitter.bString = string

      return string_emitter.emit()
  25. # ... and complex emitters.
  class InterfaceDescriptorEmitter(ComplexDescriptorEmitter):
      """ Complex emitter that produces an InterfaceDescriptor, along with its subordinates. """

      DESCRIPTOR_FORMAT = InterfaceDescriptor

      @contextmanager
      def EndpointDescriptor(self):
          """ Context manager used to attach a subordinate endpoint descriptor.

          Intended for use in a `with` statement; it yields an EndpointDescriptorEmitter
          for the caller to fill in:

              with interface.EndpointDescriptor() as d:
                  d.bEndpointAddress = 0x01
                  d.bmAttributes     = 0x80
                  d.wMaxPacketSize   = 64
                  d.bInterval        = 0

          Once the `with` body completes, the populated emitter is added to this
          interface as a subordinate descriptor.
          """
          endpoint = EndpointDescriptorEmitter()
          yield endpoint

          # Only reached when the `with` body finishes without raising.
          self.add_subordinate_descriptor(endpoint)

      def _pre_emit(self):
          """ Fills in the derived fields of this descriptor. """

          # Our endpoint count comes from the per-type subordinate tally.
          endpoint_count = self._type_counts[StandardDescriptorNumbers.ENDPOINT]
          self.bNumEndpoints = endpoint_count

          # Without a collection there is nowhere to register strings, so we're done.
          if not self._collection:
              return

          # Nothing to convert if no interface string field is present.
          if not hasattr(self, 'iInterface'):
              return

          # Let the collection turn a literal string into a string descriptor index.
          self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface)
  class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter):
      """ Complex emitter that produces a ConfigurationDescriptor, along with its subordinates. """

      DESCRIPTOR_FORMAT = ConfigurationDescriptor

      @contextmanager
      def InterfaceDescriptor(self):
          """ Context manager used to attach a subordinate interface descriptor.

          Intended for use in a `with` statement; it yields an InterfaceDescriptorEmitter
          for the caller to fill in:

              with configuration.InterfaceDescriptor() as d:
                  d.bInterfaceNumber = 0x01
                  [snip]

          Once the `with` body completes, the populated emitter is added to this
          configuration as a subordinate descriptor. Derived fields such as bNumEndpoints
          need not be set by the caller; they are computed when the descriptor is emitted.
          """
          # The interface shares our collection, so its string fields can become indices too.
          interface = InterfaceDescriptorEmitter(collection=self._collection)
          yield interface

          # Only reached when the `with` body finishes without raising.
          self.add_subordinate_descriptor(interface)

      def _pre_emit(self):
          """ Fills in the derived fields of this descriptor. """

          # Our interface count comes from the per-type subordinate tally.
          interface_count = self._type_counts[StandardDescriptorNumbers.INTERFACE]
          self.bNumInterfaces = interface_count

          # The total length covers every subordinate, plus our own descriptor.
          subordinate_length = sum(map(len, self._subordinates))
          own_length = self.DESCRIPTOR_FORMAT.sizeof()
          self.wTotalLength = subordinate_length + own_length

          # Without a collection there is nowhere to register strings, so we're done.
          if not self._collection:
              return

          # Nothing to convert if no configuration string field is present.
          if not hasattr(self, 'iConfiguration'):
              return

          # Let the collection turn a literal string into a string descriptor index.
          self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration)
  class DeviceDescriptorCollection:
      """ Object that builds a full collection of descriptors related to a given USB device.

      Descriptors are stored in binary form, keyed by (descriptor type, index). String
      descriptors are allocated on demand, starting at index 1; index 0 of the string
      type holds the supported-languages descriptor.
      """

      # Most systems seem happiest with en_US (ugh), so default to that.
      DEFAULT_SUPPORTED_LANGUAGES = [LanguageIDs.ENGLISH_US]

      def __init__(self, automatic_language_descriptor=True):
          """
          Parameters:
              automatic_language_descriptor -- If set or not provided, a language descriptor will
                                               automatically be added if none exists.
          """
          self._automatic_language_descriptor = automatic_language_descriptor

          # Create our internal descriptor tracker.
          # Keys are a tuple of (type, index).
          self._descriptors = {}

          # Track string descriptors as they're created.
          self._next_string_index = 1
          self._index_for_string = {}

      def ensure_string_field_is_index(self, field_value):
          """ Processes the given field value; if it's not a string index, converts it to one.

          Values that are `str` instances are converted to indices using `get_index_for_string`,
          which automatically adds the relevant string descriptor to our collection. Any other
          value is returned unchanged.
          """
          if isinstance(field_value, str):
              return self.get_index_for_string(field_value)

          return field_value

      def get_index_for_string(self, string):
          """ Returns a string descriptor index for the given string.

          If a string descriptor already exists for the given string, its index is
          returned. Otherwise, a string descriptor is created and stored, and the newly
          allocated index is returned.
          """
          # If we already have a descriptor for this string, return it.
          if string in self._index_for_string:
              return self._index_for_string[string]

          # Otherwise, create one:

          # Allocate an index...
          index = self._next_string_index
          self._index_for_string[string] = index
          self._next_string_index += 1

          # ... store our string descriptor with it ...
          identifier = StandardDescriptorNumbers.STRING, index
          self._descriptors[identifier] = get_string_descriptor(string)

          # ... and return our index.
          return index

      def add_descriptor(self, descriptor, index=0):
          """ Adds a descriptor to our collection.

          An existing descriptor with the same type and index is replaced.

          Parameters:
              descriptor -- The descriptor to be added; either a binary descriptor, or an
                            object with an `emit` method that produces one.
              index      -- The index of the relevant descriptor. Defaults to 0.
          """
          # If this is an emitter rather than a descriptor itself, convert it.
          if hasattr(descriptor, 'emit'):
              descriptor = descriptor.emit()

          # Figure out the identifier (type + index) for this descriptor;
          # the type is read from the second element of the binary descriptor...
          descriptor_type = descriptor[1]
          identifier = descriptor_type, index

          # ... and store it.
          self._descriptors[identifier] = descriptor

      def add_language_descriptor(self, supported_languages=None):
          """ Adds a language descriptor to the list of device descriptors.

          Parameters:
              supported_languages -- A list of languages supported by the device. If omitted,
                                     DEFAULT_SUPPORTED_LANGUAGES is used.
          """
          if supported_languages is None:
              supported_languages = self.DEFAULT_SUPPORTED_LANGUAGES

          descriptor = StringLanguageDescriptorEmitter()
          descriptor.wLANGID = supported_languages
          self.add_descriptor(descriptor)

      @contextmanager
      def DeviceDescriptor(self):
          """ Context manager that allows addition of a device descriptor.

          It can be used with a `with` statement; and yields a DeviceDescriptorEmitter
          that can be populated:

              with collection.DeviceDescriptor() as d:
                  d.idVendor  = 0xabcd
                  d.idProduct = 0x1234
                  [snip]

          This adds the relevant descriptor, automatically. String-valued iManufacturer,
          iProduct and iSerialNumber fields are replaced with string descriptor indices.
          """
          descriptor = DeviceDescriptorEmitter()
          yield descriptor

          # If we have any string fields, ensure that they're indices before continuing.
          for field in ('iManufacturer', 'iProduct', 'iSerialNumber'):
              if hasattr(descriptor, field):
                  value = getattr(descriptor, field)
                  index = self.ensure_string_field_is_index(value)
                  setattr(descriptor, field, index)

          self.add_descriptor(descriptor)

      @contextmanager
      def ConfigurationDescriptor(self):
          """ Context manager that allows addition of a configuration descriptor.

          It can be used with a `with` statement; and yields a ConfigurationDescriptorEmitter
          that can be populated:

              with collection.ConfigurationDescriptor() as d:
                  d.bConfigurationValue = 1
                  [snip]

          This adds the relevant descriptor, automatically. Note that populating derived
          fields such as bNumInterfaces isn't necessary; they'll be populated automatically.
          """
          # Hand the emitter a reference to this collection. The configuration emitter (and the
          # interface emitters it creates, which inherit its collection) only convert string
          # fields into string descriptor indices when they have a collection to register with.
          descriptor = ConfigurationDescriptorEmitter(collection=self)
          yield descriptor

          self.add_descriptor(descriptor)

      def _ensure_has_language_descriptor(self):
          """ Ensures that we have a language descriptor; adding one if necessary."""

          # If we're not automatically adding a language descriptor, we shouldn't do anything,
          # and we'll just ignore this.
          if not self._automatic_language_descriptor:
              return

          # If we don't have a language descriptor, add our default one.
          if (StandardDescriptorNumbers.STRING, 0) not in self._descriptors:
              self.add_language_descriptor()

      def get_descriptor_bytes(self, type_number: int, index: int = 0):
          """ Returns the raw, binary descriptor for a given descriptor type/index.

          Parameters:
              type_number -- The descriptor type number.
              index       -- The index of the relevant descriptor, if relevant.

          Raises a KeyError if no descriptor is stored for the given type/index.
          """
          # If this is a request for a language descriptor, make sure one exists first.
          if (type_number, index) == (StandardDescriptorNumbers.STRING, 0):
              self._ensure_has_language_descriptor()

          return self._descriptors[(type_number, index)]

      def __iter__(self):
          """ Allow iterating over each of our descriptors; yields (type number, index, descriptor). """
          self._ensure_has_language_descriptor()
          return ((number, index, desc) for ((number, index), desc) in self._descriptors.items())
  class EmitterTests(unittest.TestCase):
      """ Unit tests for the standard descriptor emitters and the descriptor collection. """

      def test_string_emitter(self):
          """ A populated string emitter should produce a length/type header plus UTF-16LE text. """
          emitter = StringDescriptorEmitter()
          emitter.bString = "Hello"

          self.assertEqual(emitter.emit(), b"\x0C\x03H\0e\0l\0l\0o\0")

      def test_string_emitter_function(self):
          """ The convenience function should match the output of the string emitter. """
          self.assertEqual(get_string_descriptor("Hello"), b"\x0C\x03H\0e\0l\0l\0o\0")

      def test_configuration_emitter(self):
          """ A configuration with one interface and one endpoint should emit our reference bytes. """
          descriptor = bytes([

              # config descriptor
              9,      # length
              2,      # type
              25, 0,  # total length
              1,      # num interfaces
              1,      # configuration number
              0,      # config string
              0x80,   # attributes
              250,    # max power

              # interface descriptor
              9,      # length
              4,      # type
              0,      # number
              0,      # alternate
              1,      # num endpoints
              0xff,   # class
              0xff,   # subclass
              0xff,   # protocol
              0,      # string

              # endpoint descriptor
              7,      # length
              5,      # type
              0x01,   # address
              2,      # attributes
              64, 0,  # max packet size
              255,    # interval
          ])

          # Create a trivial configuration descriptor...
          emitter = ConfigurationDescriptorEmitter()

          # (set explicitly, so the comparison below doesn't depend on the field's default)
          emitter.bConfigurationValue = 1

          with emitter.InterfaceDescriptor() as interface:
              interface.bInterfaceNumber = 0

              with interface.EndpointDescriptor() as endpoint:
                  endpoint.bEndpointAddress = 1

          # ... and validate that it matches our reference descriptor, in both size and content.
          binary = emitter.emit()
          self.assertEqual(len(binary), len(descriptor))
          self.assertEqual(binary, descriptor)

      def test_descriptor_collection(self):
          """ A populated collection should yield each of its top-level descriptors. """
          collection = DeviceDescriptorCollection()

          with collection.DeviceDescriptor() as d:
              d.idVendor = 0xdead
              d.idProduct = 0xbeef
              d.bNumConfigurations = 1

              d.iManufacturer = "Test Company"
              d.iProduct = "Test Product"

          with collection.ConfigurationDescriptor() as c:
              c.bConfigurationValue = 1

              with c.InterfaceDescriptor() as i:
                  i.bInterfaceNumber = 1

                  with i.EndpointDescriptor() as e:
                      e.bEndpointAddress = 0x81

                  with i.EndpointDescriptor() as e:
                      e.bEndpointAddress = 0x01

          results = list(collection)

          # We should wind up with five descriptor entries (supported languages, two strings,
          # device, and configuration), as our endpoint/interface descriptors are included
          # in our configuration descriptor.
          self.assertEqual(len(results), 5)

          # Supported languages string.
          self.assertIn((3, 0, b'\x04\x03\x09\x04'), results)

          # Manufacturer / product string.
          self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results)
          self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results)

          # Device descriptor.
          self.assertIn((1, 0, b'\x12\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results)

          # Configuration descriptor, with subordinates.
          self.assertIn((2, 0, b'\t\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results)

      def test_empty_descriptor_collection(self):
          """ With automatic language descriptors disabled, an empty collection yields nothing. """
          collection = DeviceDescriptorCollection(automatic_language_descriptor=False)

          results = list(collection)
          self.assertEqual(len(results), 0)

      def test_automatic_language_descriptor(self):
          """ With automatic language descriptors enabled, iterating adds the language descriptor. """
          collection = DeviceDescriptorCollection(automatic_language_descriptor=True)

          results = list(collection)
          self.assertEqual(len(results), 1)
  # Run the unit tests above when this module is executed directly, rather than imported.
  if __name__ == "__main__":
      unittest.main()