diff --git a/usb_protocol/emitters/__init__.py b/usb_protocol/emitters/__init__.py index 2c27afc..ed503e5 100644 --- a/usb_protocol/emitters/__init__.py +++ b/usb_protocol/emitters/__init__.py @@ -64,6 +64,14 @@ class ConstructEmitter: raise KeyError(f"missing necessary field: {e}") + def __getattr__(self, name): + """ Retrieves an emitter field, if possible. """ + + if name in self.fields: + return self.fields[name] + else: + raise AttributeError(f"descriptor emitter has no property {name}") + class ConstructEmitterTest(unittest.TestCase): diff --git a/usb_protocol/emitters/descriptor.py b/usb_protocol/emitters/descriptor.py index 0d74142..a0cc92d 100644 --- a/usb_protocol/emitters/descriptor.py +++ b/usb_protocol/emitters/descriptor.py @@ -12,7 +12,15 @@ class ComplexDescriptorEmitter(ConstructEmitter): # Base classes should override this. DESCRIPTOR_FORMAT = None - def __init__(self): + def __init__(self, collection=None): + """ + Parameters: + collection -- If this descriptor belongs to a collection, it should be + provided here. Using a collection object allows e.g. automatic + assignment of string descriptor indices. + """ + + self._collection = collection # Always create a basic ConstructEmitter from the given format. super().__init__(self.DESCRIPTOR_FORMAT) @@ -49,6 +57,11 @@ class ComplexDescriptorEmitter(ConstructEmitter): self._subordinates.append(subordinate) + def _pre_emit(self): + """ Performs any manipulations needed on this object before emission. """ + pass + + def emit(self, include_subordinates=True): """ Emit our descriptor. @@ -56,9 +69,11 @@ class ComplexDescriptorEmitter(ConstructEmitter): include_subordinates -- If true or not provided, any subordinate descriptors will be included. """ - result = bytearray() + # Run any pre-emit hook code before we perform our emission... + self._pre_emit() - # Add our basic descriptor... + # Start with our core descriptor... + result = bytearray() result.extend(super().emit()) # ... and if descired, add our subordinates... @@ -66,3 +81,5 @@ class ComplexDescriptorEmitter(ConstructEmitter): result.extend(sub) return bytes(result) + + diff --git a/usb_protocol/emitters/descriptors/standard.py b/usb_protocol/emitters/descriptors/standard.py index 8ad4b19..bc9b70a 100644 --- a/usb_protocol/emitters/descriptors/standard.py +++ b/usb_protocol/emitters/descriptors/standard.py @@ -4,7 +4,6 @@ """ Convenience emitters for simple, standard descriptors. """ import unittest -import functools from contextlib import contextmanager @@ -59,11 +58,14 @@ class InterfaceDescriptorEmitter(ComplexDescriptorEmitter): self.add_subordinate_descriptor(descriptor) - def emit(self, include_subordinates=True): + def _pre_emit(self): - # Count our endpoints, and then call our parent emitter. + # Count our endpoints, and update our internal count. self.bNumEndpoints = self._type_counts[StandardDescriptorNumbers.ENDPOINT] - return super().emit(include_subordinates=include_subordinates) + + # Ensure that our interface string is an index, if we can. + if self._collection and hasattr(self, 'iInterface'): + self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface) @@ -86,23 +88,151 @@ class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter): This adds the relevant descriptor, automatically. Note that populating derived fields such as bNumEndpoints aren't necessary; they'll be populated automatically. """ - descriptor = InterfaceDescriptorEmitter() + descriptor = InterfaceDescriptorEmitter(collection=self._collection) yield descriptor self.add_subordinate_descriptor(descriptor) - def emit(self, include_subordinates=True): + def _pre_emit(self): - # Count our interfaces... + # Count our interfaces. self.bNumInterfaces = self._type_counts[StandardDescriptorNumbers.INTERFACE] - # ... and figure out our total length. + # Figure out our total length. subordinate_length = sum(len(sub) for sub in self._subordinates) self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof() - # Finally, emit our whole descriptor. - return super().emit(include_subordinates=include_subordinates) + # Ensure that our configuration string is an index, if we can. + if self._collection and hasattr(self, 'iConfiguration'): + self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration) + + + +class DeviceDescriptorCollection: + """ Object that builds a full collection of descriptors related to a given USB device. """ + + def __init__(self): + + # Create our internal descriptor tracker. + # Keys are a tuple of (type, index). + self._descriptors = {} + + # Track string descriptors as they're created. + self._next_string_index = 1 + self._index_for_string = {} + + + def ensure_string_field_is_index(self, field_value): + """ Processes the given field value; if it's not an string index, converts it to one. + + Non-index-fields are converted to indices using `get_index_for_string`, which automatically + adds the relevant fields to our string descriptor collection. + """ + + if isinstance(field_value, str): + return self.get_index_for_string(field_value) + else: + return field_value + + + def get_index_for_string(self, string): + """ Returns an string descriptor index for the given string. + + If a string descriptor already exists for the given string, its index is + returned. Otherwise, a string descriptor is created. + """ + + # If we already have a descriptor for this string, return it. + if string in self._index_for_string: + return self._index_for_string[string] + + + # Otherwise, create one: + + # Allocate an index... + index = self._next_string_index + self._index_for_string[string] = index + self._next_string_index += 1 + + # ... store our string descriptor with it ... + identifier = StandardDescriptorNumbers.STRING, index + self._descriptors[identifier] = get_string_descriptor(string) + + # ... and return our index. + return index + + + def add_descriptor(self, descriptor, index=0): + """ Adds a descriptor to our collection. + + Parameters: + descriptor -- The descriptor to be added. + index -- The index of the relevant descriptor. Defaults to 0. + """ + + # If this is an emitter rather than a descriptor itself, convert it. + if hasattr(descriptor, 'emit'): + descriptor = descriptor.emit() + + # Figure out the identifier (type + index) for this descriptor... + descriptor_type = descriptor[1] + identifier = descriptor_type, index + + # ... and store it. + self._descriptors[identifier] = descriptor + + + @contextmanager + def DeviceDescriptor(self): + """ Context manager that allows addition of a device descriptor. + + It can be used with a `with` statement; and yields an DeviceDescriptorEmitter + that can be populated: + + with collection.DeviceDescriptor() as d: + d.idVendor = 0xabcd + d.idProduct = 0x1234 + [snip] + + This adds the relevant descriptor, automatically. + """ + descriptor = DeviceDescriptorEmitter() + yield descriptor + + # If we have any string fields, ensure that they're indices before continuing. + for field in ('iManufacturer', 'iProduct', 'iSerialNumber'): + if hasattr(descriptor, field): + value = getattr(descriptor, field) + index = self.ensure_string_field_is_index(value) + setattr(descriptor, field, index) + + self.add_descriptor(descriptor) + + + @contextmanager + def ConfigurationDescriptor(self): + """ Context manager that allows addition of a configuration descriptor. + + It can be used with a `with` statement; and yields an ConfigurationDescriptorEmitter + that can be populated: + + with collection.ConfigurationDescriptor() as d: + d.bConfigurationValue = 1 + [snip] + + This adds the relevant descriptor, automatically. Note that populating derived + fields such as bNumInterfaces aren't necessary; they'll be populated automatically. + """ + descriptor = ConfigurationDescriptorEmitter() + yield descriptor + + self.add_descriptor(descriptor) + + + def __iter__(self): + """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """ + return ((number, index, desc) for ((number, index), desc) in self._descriptors.items()) class EmitterTests(unittest.TestCase): @@ -167,5 +297,46 @@ class EmitterTests(unittest.TestCase): self.assertEqual(len(binary), len(descriptor)) + def test_descriptor_collection(self): + collection = DeviceDescriptorCollection() + + with collection.DeviceDescriptor() as d: + d.idVendor = 0xdead + d.idProduct = 0xbeef + d.bNumConfigurations = 1 + + d.iManufacturer = "Test Company" + d.iProduct = "Test Product" + + + with collection.ConfigurationDescriptor() as c: + c.bConfigurationValue = 1 + + with c.InterfaceDescriptor() as i: + i.bInterfaceNumber = 1 + + with i.EndpointDescriptor() as e: + e.bEndpointAddress = 0x81 + + with i.EndpointDescriptor() as e: + e.bEndpointAddress = 0x01 + + + results = list(collection) + + # We should wind up with four descriptor entries, as our endpoint/interface descriptors are + # included in our configuration descriptor. + self.assertEqual(len(results), 4) + + # Manufacturer / product string. + self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results) + self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results) + + # Device descriptor. + self.assertIn((1, 0, b'\x0f\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results) + + # Configuration descriptor, with subordinates. + self.assertIn((2, 0, b'\r\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results) + if __name__ == "__main__": unittest.main() diff --git a/usb_protocol/types/descriptor.py b/usb_protocol/types/descriptor.py index e3057ca..5698be8 100644 --- a/usb_protocol/types/descriptor.py +++ b/usb_protocol/types/descriptor.py @@ -159,15 +159,13 @@ class DescriptorField(construct.Subconstruct): def __rtruediv__(self, field_name): field_type = self._get_type_for_name(field_name) + if self.default is not None: + field_type = construct.Default(field_type, self.default) + # Build our subconstruct. Construct makes this look super weird, # but this is actually "we have a field with of type ". # In long form, we'll call it "description". - subconstruct = (field_name / field_type) * self.description - - if self.default is not None: - return construct.Default(subconstruct, self.default) - else: - return subconstruct + return (field_name / field_type) * self.description # Convenience type that gets a descriptor's own length.