From 21ba0397f7af943860690f0e5113077cb10661c9 Mon Sep 17 00:00:00 2001 From: Kate Temkin Date: Sun, 5 Apr 2020 04:12:07 -0600 Subject: [PATCH] fix issues with descriptor length field & make emitters easier to access --- usb_protocol/emitters/__init__.py | 101 +---------------- usb_protocol/emitters/construct.py | 103 ++++++++++++++++++ usb_protocol/emitters/descriptors/standard.py | 11 ++ usb_protocol/types/descriptors/standard.py | 6 +- 4 files changed, 120 insertions(+), 101 deletions(-) create mode 100644 usb_protocol/emitters/construct.py diff --git a/usb_protocol/emitters/__init__.py b/usb_protocol/emitters/__init__.py index ed503e5..62639c9 100644 --- a/usb_protocol/emitters/__init__.py +++ b/usb_protocol/emitters/__init__.py @@ -1,102 +1,7 @@ # # This file is part of usb-protocol. # -""" Helpers for creating easy emitters. """ +""" USB-related emitters. """ -import unittest -import construct - -class ConstructEmitter: - """ Class that creates a simple emitter based on a construct struct. - - For example, if we have a construct format that looks like the following: - MyStruct = struct( - "a" / Int8 - "b" / Int8 - ) - - We could create emit an object like follows: - emitter = ConstructEmitter(MyStruct) - emitter.a = 0xab - emitter.b = 0xcd - my_bytes = emitter.emit() # "\xab\xcd" - """ - - def __init__(self, struct): - """ - Parmeters: - construct_format -- The format for which to create an emitter. - """ - self.__dict__['format'] = struct - self.__dict__['fields'] = {} - - - def _format_contains_field(self, field_name): - """ Returns True iff the given format has a field with the provided name. - - Parameters: - format_object -- The Construct format to work with. This includes e.g. most descriptor types. - field_name -- The field name to query. - """ - return any(f.name == field_name for f in self.format.subcons) - - - def __setattr__(self, name, value): - """ Hook that we used to set our fields. """ - - # If the field starts with a '_', don't handle it, as it's an internal field. - if name.startswith('_'): - super().__setattr__(name, value) - return - - if not self._format_contains_field(name): - raise AttributeError(f"emitter specification contains no field {name}") - - self.fields[name] = value - - - def emit(self): - """ Emits the stream of bytes associated with this object. """ - - try: - return self.format.build(self.fields) - except KeyError as e: - raise KeyError(f"missing necessary field: {e}") - - - def __getattr__(self, name): - """ Retrieves an emitter field, if possible. """ - - if name in self.fields: - return self.fields[name] - else: - raise AttributeError(f"descriptor emitter has no property {name}") - - -class ConstructEmitterTest(unittest.TestCase): - - def test_simple_emitter(self): - - test_struct = construct.Struct( - "a" / construct.Int8ul, - "b" / construct.Int8ul - ) - - emitter = ConstructEmitter(test_struct) - emitter.a = 0xab - emitter.b = 0xcd - - self.assertEqual(emitter.emit(), b"\xab\xcd") - - -def emitter_for_format(construct_format): - """ Creates a factory method for the relevant construct format. """ - - def _factory(): - return ConstructEmitter(construct_format) - - return _factory - - -if __name__ == "__main__": - unittest.main() +from .construct import emitter_for_format, ConstructEmitter +from .descriptors.standard import DeviceDescriptorCollection diff --git a/usb_protocol/emitters/construct.py b/usb_protocol/emitters/construct.py new file mode 100644 index 0000000..f2d995d --- /dev/null +++ b/usb_protocol/emitters/construct.py @@ -0,0 +1,103 @@ +# +# This file is part of usb-protocol. +# +""" Helpers for creating construct-related emitters. """ + +import unittest +import construct + + +class ConstructEmitter: + """ Class that creates a simple emitter based on a construct struct. + + For example, if we have a construct format that looks like the following: + MyStruct = struct( + "a" / Int8 + "b" / Int8 + ) + + We could create emit an object like follows: + emitter = ConstructEmitter(MyStruct) + emitter.a = 0xab + emitter.b = 0xcd + my_bytes = emitter.emit() # "\xab\xcd" + """ + + def __init__(self, struct): + """ + Parmeters: + construct_format -- The format for which to create an emitter. + """ + self.__dict__['format'] = struct + self.__dict__['fields'] = {} + + + def _format_contains_field(self, field_name): + """ Returns True iff the given format has a field with the provided name. + + Parameters: + format_object -- The Construct format to work with. This includes e.g. most descriptor types. + field_name -- The field name to query. + """ + return any(f.name == field_name for f in self.format.subcons) + + + def __setattr__(self, name, value): + """ Hook that we used to set our fields. """ + + # If the field starts with a '_', don't handle it, as it's an internal field. + if name.startswith('_'): + super().__setattr__(name, value) + return + + if not self._format_contains_field(name): + raise AttributeError(f"emitter specification contains no field {name}") + + self.fields[name] = value + + + def emit(self): + """ Emits the stream of bytes associated with this object. """ + + try: + return self.format.build(self.fields) + except KeyError as e: + raise KeyError(f"missing necessary field: {e}") + + + def __getattr__(self, name): + """ Retrieves an emitter field, if possible. """ + + if name in self.fields: + return self.fields[name] + else: + raise AttributeError(f"descriptor emitter has no property {name}") + + +class ConstructEmitterTest(unittest.TestCase): + + def test_simple_emitter(self): + + test_struct = construct.Struct( + "a" / construct.Int8ul, + "b" / construct.Int8ul + ) + + emitter = ConstructEmitter(test_struct) + emitter.a = 0xab + emitter.b = 0xcd + + self.assertEqual(emitter.emit(), b"\xab\xcd") + + +def emitter_for_format(construct_format): + """ Creates a factory method for the relevant construct format. """ + + def _factory(): + return ConstructEmitter(construct_format) + + return _factory + + +if __name__ == "__main__": + unittest.main() diff --git a/usb_protocol/emitters/descriptors/standard.py b/usb_protocol/emitters/descriptors/standard.py index bc9b70a..6f2ca35 100644 --- a/usb_protocol/emitters/descriptors/standard.py +++ b/usb_protocol/emitters/descriptors/standard.py @@ -230,11 +230,22 @@ class DeviceDescriptorCollection: self.add_descriptor(descriptor) + def get_descriptor_bytes(self, type_number: int, index: int = 0): + """ Returns the raw, binary descriptor for a given descriptor type/index. + + Parmeters: + type_number -- The descriptor type number. + index -- The index of the relevant descriptor, if relevant. + """ + return self._descriptors[(type_number, index)] + + def __iter__(self): """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """ return ((number, index, desc) for ((number, index), desc) in self._descriptors.items()) + class EmitterTests(unittest.TestCase): def test_string_emitter(self): diff --git a/usb_protocol/types/descriptors/standard.py b/usb_protocol/types/descriptors/standard.py index ea11c27..cde91a8 100644 --- a/usb_protocol/types/descriptors/standard.py +++ b/usb_protocol/types/descriptors/standard.py @@ -29,7 +29,7 @@ class StandardDescriptorNumbers(IntEnum): DeviceDescriptor = DescriptorFormat( - "bLength" / DescriptorLength, + "bLength" / construct.Const(0x12, construct.Int8ul), "bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.DEVICE), "bcdUSB" / DescriptorField("USB Version", default=2.0), "bDeviceClass" / DescriptorField("Class", default=0), @@ -47,7 +47,7 @@ DeviceDescriptor = DescriptorFormat( ConfigurationDescriptor = DescriptorFormat( - "bLength" / DescriptorLength, + "bLength" / construct.Const(9, construct.Int8ul), "bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.CONFIGURATION), "wTotalLength" / DescriptorField("Length including subordinates"), "bNumInterfaces" / DescriptorField("Interface count"), @@ -101,7 +101,7 @@ EndpointDescriptor = DescriptorFormat( DeviceQualifierDescriptor = DescriptorFormat( - "bLength" / DescriptorLength, + "bLength" / construct.Const(9, construct.Int8ul), "bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.DEVICE_QUALIFIER), "bcdUSB" / DescriptorField("USB Version"), "bDeviceClass" / DescriptorField("Class"),