diff --git a/usb_protocol/emitters/descriptors/standard.py b/usb_protocol/emitters/descriptors/standard.py index 52bc2c7..e40556c 100644 --- a/usb_protocol/emitters/descriptors/standard.py +++ b/usb_protocol/emitters/descriptors/standard.py @@ -208,12 +208,14 @@ class DeviceDescriptorCollection: return index - def add_descriptor(self, descriptor, index=0): + def add_descriptor(self, descriptor, index=0, descriptor_type=None): """ Adds a descriptor to our collection. Parameters: - descriptor -- The descriptor to be added. - index -- The index of the relevant descriptor. Defaults to 0. + descriptor -- The descriptor to be added. + index -- The index of the relevant descriptor. Defaults to 0. + descriptor_type -- The type of the descriptor to be added. If `None`, + this is automatically derived from the descriptor contents. """ # If this is an emitter rather than a descriptor itself, convert it. @@ -221,7 +223,8 @@ class DeviceDescriptorCollection: descriptor = descriptor.emit() # Figure out the identifier (type + index) for this descriptor... - descriptor_type = descriptor[1] + if (descriptor_type is None): + descriptor_type = descriptor[1] identifier = descriptor_type, index # ... and store it. @@ -566,5 +569,3 @@ class EmitterTests(unittest.TestCase): if __name__ == "__main__": unittest.main() - -