You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

420 lines
14 KiB

  1. #
  2. # This file is part of usb_protocol.
  3. #
  4. """ Convenience emitters for simple, standard descriptors. """
  5. import unittest
  6. from contextlib import contextmanager
  7. from .. import emitter_for_format
  8. from ..descriptor import ComplexDescriptorEmitter
  9. from ...types import LanguageIDs
  10. from ...types.descriptors.standard import \
  11. DeviceDescriptor, StringDescriptor, EndpointDescriptor, DeviceQualifierDescriptor, \
  12. ConfigurationDescriptor, InterfaceDescriptor, StandardDescriptorNumbers, StringLanguageDescriptor
  13. # Create our basic emitters...
  14. DeviceDescriptorEmitter = emitter_for_format(DeviceDescriptor)
  15. StringDescriptorEmitter = emitter_for_format(StringDescriptor)
  16. StringLanguageDescriptorEmitter = emitter_for_format(StringLanguageDescriptor)
  17. EndpointDescriptorEmitter = emitter_for_format(EndpointDescriptor)
  18. DeviceQualifierDescriptor = emitter_for_format(DeviceQualifierDescriptor)
  19. # ... convenience functions ...
  20. def get_string_descriptor(string):
  21. """ Generates a string descriptor for the relevant string. """
  22. emitter = StringDescriptorEmitter()
  23. emitter.bString = string
  24. return emitter.emit()
  25. # ... and complex emitters.
  26. class InterfaceDescriptorEmitter(ComplexDescriptorEmitter):
  27. """ Emitter that creates an InterfaceDescriptor. """
  28. DESCRIPTOR_FORMAT = InterfaceDescriptor
  29. @contextmanager
  30. def EndpointDescriptor(self):
  31. """ Context manager that allows addition of a subordinate endpoint descriptor.
  32. It can be used with a `with` statement; and yields an EndpointDesriptorEmitter
  33. that can be populated:
  34. with interface.EndpointDescriptor() as d:
  35. d.bEndpointAddress = 0x01
  36. d.bmAttributes = 0x80
  37. d.wMaxPacketSize = 64
  38. d.bInterval = 0
  39. This adds the relevant descriptor, automatically.
  40. """
  41. descriptor = EndpointDescriptorEmitter()
  42. yield descriptor
  43. self.add_subordinate_descriptor(descriptor)
  44. def _pre_emit(self):
  45. # Count our endpoints, and update our internal count.
  46. self.bNumEndpoints = self._type_counts[StandardDescriptorNumbers.ENDPOINT]
  47. # Ensure that our interface string is an index, if we can.
  48. if self._collection and hasattr(self, 'iInterface'):
  49. self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface)
  50. class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter):
  51. """ Emitter that creates a configuration descriptor. """
  52. DESCRIPTOR_FORMAT = ConfigurationDescriptor
  53. @contextmanager
  54. def InterfaceDescriptor(self):
  55. """ Context manager that allows addition of a subordinate interface descriptor.
  56. It can be used with a `with` statement; and yields an InterfaceDescriptorEmitter
  57. that can be populated:
  58. with interface.InterfaceDescriptor() as d:
  59. d.bInterfaceNumber = 0x01
  60. [snip]
  61. This adds the relevant descriptor, automatically. Note that populating derived
  62. fields such as bNumEndpoints aren't necessary; they'll be populated automatically.
  63. """
  64. descriptor = InterfaceDescriptorEmitter(collection=self._collection)
  65. yield descriptor
  66. self.add_subordinate_descriptor(descriptor)
  67. def _pre_emit(self):
  68. # Count our interfaces.
  69. self.bNumInterfaces = self._type_counts[StandardDescriptorNumbers.INTERFACE]
  70. # Figure out our total length.
  71. subordinate_length = sum(len(sub) for sub in self._subordinates)
  72. self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()
  73. # Ensure that our configuration string is an index, if we can.
  74. if self._collection and hasattr(self, 'iConfiguration'):
  75. self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration)
  76. class DeviceDescriptorCollection:
  77. """ Object that builds a full collection of descriptors related to a given USB device. """
  78. # Most systems seem happiest with en_US (ugh), so default to that.
  79. DEFAULT_SUPPORTED_LANGUAGES = [LanguageIDs.ENGLISH_US]
  80. def __init__(self, automatic_language_descriptor=True):
  81. """
  82. Parameters:
  83. automatic_language_descriptor -- If set or not provided, a language descriptor will automatically
  84. be added if none exists.
  85. """
  86. self._automatic_language_descriptor = automatic_language_descriptor
  87. # Create our internal descriptor tracker.
  88. # Keys are a tuple of (type, index).
  89. self._descriptors = {}
  90. # Track string descriptors as they're created.
  91. self._next_string_index = 1
  92. self._index_for_string = {}
  93. def ensure_string_field_is_index(self, field_value):
  94. """ Processes the given field value; if it's not an string index, converts it to one.
  95. Non-index-fields are converted to indices using `get_index_for_string`, which automatically
  96. adds the relevant fields to our string descriptor collection.
  97. """
  98. if isinstance(field_value, str):
  99. return self.get_index_for_string(field_value)
  100. else:
  101. return field_value
  102. def get_index_for_string(self, string):
  103. """ Returns an string descriptor index for the given string.
  104. If a string descriptor already exists for the given string, its index is
  105. returned. Otherwise, a string descriptor is created.
  106. """
  107. # If we already have a descriptor for this string, return it.
  108. if string in self._index_for_string:
  109. return self._index_for_string[string]
  110. # Otherwise, create one:
  111. # Allocate an index...
  112. index = self._next_string_index
  113. self._index_for_string[string] = index
  114. self._next_string_index += 1
  115. # ... store our string descriptor with it ...
  116. identifier = StandardDescriptorNumbers.STRING, index
  117. self._descriptors[identifier] = get_string_descriptor(string)
  118. # ... and return our index.
  119. return index
  120. def add_descriptor(self, descriptor, index=0, descriptor_type=None):
  121. """ Adds a descriptor to our collection.
  122. Parameters:
  123. descriptor -- The descriptor to be added.
  124. index -- The index of the relevant descriptor. Defaults to 0.
  125. descriptor_type -- The type of the descriptor to be added. If `None`,
  126. this is automatically derived from the descriptor contents.
  127. """
  128. # If this is an emitter rather than a descriptor itself, convert it.
  129. if hasattr(descriptor, 'emit'):
  130. descriptor = descriptor.emit()
  131. # Figure out the identifier (type + index) for this descriptor...
  132. if(descriptor_type is None):
  133. descriptor_type = descriptor[1]
  134. identifier = descriptor_type, index
  135. # ... and store it.
  136. self._descriptors[identifier] = descriptor
  137. def add_language_descriptor(self, supported_languages=None):
  138. """ Adds a language descriptor to the list of device descriptors.
  139. Parameters:
  140. supported_languages -- A list of languages supported by the device.
  141. """
  142. if supported_languages is None:
  143. supported_languages = self.DEFAULT_SUPPORTED_LANGUAGES
  144. descriptor = StringLanguageDescriptorEmitter()
  145. descriptor.wLANGID = supported_languages
  146. self.add_descriptor(descriptor)
  147. @contextmanager
  148. def DeviceDescriptor(self):
  149. """ Context manager that allows addition of a device descriptor.
  150. It can be used with a `with` statement; and yields an DeviceDescriptorEmitter
  151. that can be populated:
  152. with collection.DeviceDescriptor() as d:
  153. d.idVendor = 0xabcd
  154. d.idProduct = 0x1234
  155. [snip]
  156. This adds the relevant descriptor, automatically.
  157. """
  158. descriptor = DeviceDescriptorEmitter()
  159. yield descriptor
  160. # If we have any string fields, ensure that they're indices before continuing.
  161. for field in ('iManufacturer', 'iProduct', 'iSerialNumber'):
  162. if hasattr(descriptor, field):
  163. value = getattr(descriptor, field)
  164. index = self.ensure_string_field_is_index(value)
  165. setattr(descriptor, field, index)
  166. self.add_descriptor(descriptor)
  167. @contextmanager
  168. def ConfigurationDescriptor(self):
  169. """ Context manager that allows addition of a configuration descriptor.
  170. It can be used with a `with` statement; and yields an ConfigurationDescriptorEmitter
  171. that can be populated:
  172. with collection.ConfigurationDescriptor() as d:
  173. d.bConfigurationValue = 1
  174. [snip]
  175. This adds the relevant descriptor, automatically. Note that populating derived
  176. fields such as bNumInterfaces aren't necessary; they'll be populated automatically.
  177. """
  178. descriptor = ConfigurationDescriptorEmitter()
  179. yield descriptor
  180. self.add_descriptor(descriptor)
  181. def _ensure_has_language_descriptor(self):
  182. """ Ensures that we have a language descriptor; adding one if necessary."""
  183. # If we're not automatically adding a language descriptor, we shouldn't do anything,
  184. # and we'll just ignore this.
  185. if not self._automatic_language_descriptor:
  186. return
  187. # If we don't have a language descriptor, add our default one.
  188. if not (StandardDescriptorNumbers.STRING, 0) in self._descriptors:
  189. self.add_language_descriptor()
  190. def get_descriptor_bytes(self, type_number: int, index: int = 0):
  191. """ Returns the raw, binary descriptor for a given descriptor type/index.
  192. Parmeters:
  193. type_number -- The descriptor type number.
  194. index -- The index of the relevant descriptor, if relevant.
  195. """
  196. # If this is a request for a language descriptor, return one.
  197. if (type_number, index) == (StandardDescriptorNumbers.STRING, 0):
  198. self._ensure_has_language_descriptor()
  199. return self._descriptors[(type_number, index)]
  200. def __iter__(self):
  201. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  202. self._ensure_has_language_descriptor()
  203. return ((number, index, desc) for ((number, index), desc) in self._descriptors.items())
  204. class EmitterTests(unittest.TestCase):
  205. def test_string_emitter(self):
  206. emitter = StringDescriptorEmitter()
  207. emitter.bString = "Hello"
  208. self.assertEqual(emitter.emit(), b"\x0C\x03H\0e\0l\0l\0o\0")
  209. def test_string_emitter_function(self):
  210. self.assertEqual(get_string_descriptor("Hello"), b"\x0C\x03H\0e\0l\0l\0o\0")
  211. def test_configuration_emitter(self):
  212. descriptor = bytes([
  213. # config descriptor
  214. 12, # length
  215. 2, # type
  216. 25, 00, # total length
  217. 1, # num interfaces
  218. 1, # configuration number
  219. 0, # config string
  220. 0x80, # attributes
  221. 250, # max power
  222. # interface descriptor
  223. 9, # length
  224. 4, # type
  225. 0, # number
  226. 0, # alternate
  227. 1, # num endpoints
  228. 0xff, # class
  229. 0xff, # subclass
  230. 0xff, # protocol
  231. 0, # string
  232. # endpoint descriptor
  233. 7, # length
  234. 5, # type
  235. 0x01, # address
  236. 2, # attributes
  237. 64, 0, # max packet size
  238. 255, # interval
  239. ])
  240. # Create a trivial configuration descriptor...
  241. emitter = ConfigurationDescriptorEmitter()
  242. with emitter.InterfaceDescriptor() as interface:
  243. interface.bInterfaceNumber = 0
  244. with interface.EndpointDescriptor() as endpoint:
  245. endpoint.bEndpointAddress = 1
  246. # ... and validate that it maches our reference descriptor.
  247. binary = emitter.emit()
  248. self.assertEqual(len(binary), len(descriptor))
  249. def test_descriptor_collection(self):
  250. collection = DeviceDescriptorCollection()
  251. with collection.DeviceDescriptor() as d:
  252. d.idVendor = 0xdead
  253. d.idProduct = 0xbeef
  254. d.bNumConfigurations = 1
  255. d.iManufacturer = "Test Company"
  256. d.iProduct = "Test Product"
  257. with collection.ConfigurationDescriptor() as c:
  258. c.bConfigurationValue = 1
  259. with c.InterfaceDescriptor() as i:
  260. i.bInterfaceNumber = 1
  261. with i.EndpointDescriptor() as e:
  262. e.bEndpointAddress = 0x81
  263. with i.EndpointDescriptor() as e:
  264. e.bEndpointAddress = 0x01
  265. results = list(collection)
  266. # We should wind up with four descriptor entries, as our endpoint/interface descriptors are
  267. # included in our configuration descriptor.
  268. self.assertEqual(len(results), 5)
  269. # Supported languages string.
  270. self.assertIn((3, 0, b'\x04\x03\x09\x04'), results)
  271. # Manufacturer / product string.
  272. self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results)
  273. self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results)
  274. # Device descriptor.
  275. self.assertIn((1, 0, b'\x12\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results)
  276. # Configuration descriptor, with subordinates.
  277. self.assertIn((2, 0, b'\t\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results)
  278. def test_empty_descriptor_collection(self):
  279. collection = DeviceDescriptorCollection(automatic_language_descriptor=False)
  280. results = list(collection)
  281. self.assertEqual(len(results), 0)
  282. def test_automatic_language_descriptor(self):
  283. collection = DeviceDescriptorCollection(automatic_language_descriptor=True)
  284. results = list(collection)
  285. self.assertEqual(len(results), 1)
  286. if __name__ == "__main__":
  287. unittest.main()