@@ -1,102 +1,7 @@ | |||||
# | # | ||||
# This file is part of usb-protocol. | # This file is part of usb-protocol. | ||||
# | # | ||||
""" Helpers for creating easy emitters. """ | |||||
""" USB-related emitters. """ | |||||
import unittest | |||||
import construct | |||||
class ConstructEmitter: | |||||
""" Class that creates a simple emitter based on a construct struct. | |||||
For example, if we have a construct format that looks like the following: | |||||
MyStruct = struct( | |||||
"a" / Int8 | |||||
"b" / Int8 | |||||
) | |||||
We could create emit an object like follows: | |||||
emitter = ConstructEmitter(MyStruct) | |||||
emitter.a = 0xab | |||||
emitter.b = 0xcd | |||||
my_bytes = emitter.emit() # "\xab\xcd" | |||||
""" | |||||
def __init__(self, struct): | |||||
""" | |||||
Parmeters: | |||||
construct_format -- The format for which to create an emitter. | |||||
""" | |||||
self.__dict__['format'] = struct | |||||
self.__dict__['fields'] = {} | |||||
def _format_contains_field(self, field_name): | |||||
""" Returns True iff the given format has a field with the provided name. | |||||
Parameters: | |||||
format_object -- The Construct format to work with. This includes e.g. most descriptor types. | |||||
field_name -- The field name to query. | |||||
""" | |||||
return any(f.name == field_name for f in self.format.subcons) | |||||
def __setattr__(self, name, value): | |||||
""" Hook that we used to set our fields. """ | |||||
# If the field starts with a '_', don't handle it, as it's an internal field. | |||||
if name.startswith('_'): | |||||
super().__setattr__(name, value) | |||||
return | |||||
if not self._format_contains_field(name): | |||||
raise AttributeError(f"emitter specification contains no field {name}") | |||||
self.fields[name] = value | |||||
def emit(self): | |||||
""" Emits the stream of bytes associated with this object. """ | |||||
try: | |||||
return self.format.build(self.fields) | |||||
except KeyError as e: | |||||
raise KeyError(f"missing necessary field: {e}") | |||||
def __getattr__(self, name): | |||||
""" Retrieves an emitter field, if possible. """ | |||||
if name in self.fields: | |||||
return self.fields[name] | |||||
else: | |||||
raise AttributeError(f"descriptor emitter has no property {name}") | |||||
class ConstructEmitterTest(unittest.TestCase): | |||||
def test_simple_emitter(self): | |||||
test_struct = construct.Struct( | |||||
"a" / construct.Int8ul, | |||||
"b" / construct.Int8ul | |||||
) | |||||
emitter = ConstructEmitter(test_struct) | |||||
emitter.a = 0xab | |||||
emitter.b = 0xcd | |||||
self.assertEqual(emitter.emit(), b"\xab\xcd") | |||||
def emitter_for_format(construct_format): | |||||
""" Creates a factory method for the relevant construct format. """ | |||||
def _factory(): | |||||
return ConstructEmitter(construct_format) | |||||
return _factory | |||||
if __name__ == "__main__": | |||||
unittest.main() | |||||
from .construct import emitter_for_format, ConstructEmitter | |||||
from .descriptors.standard import DeviceDescriptorCollection |
@@ -0,0 +1,103 @@ | |||||
# | |||||
# This file is part of usb-protocol. | |||||
# | |||||
""" Helpers for creating construct-related emitters. """ | |||||
import unittest | |||||
import construct | |||||
class ConstructEmitter: | |||||
""" Class that creates a simple emitter based on a construct struct. | |||||
For example, if we have a construct format that looks like the following: | |||||
MyStruct = struct( | |||||
"a" / Int8 | |||||
"b" / Int8 | |||||
) | |||||
We could create emit an object like follows: | |||||
emitter = ConstructEmitter(MyStruct) | |||||
emitter.a = 0xab | |||||
emitter.b = 0xcd | |||||
my_bytes = emitter.emit() # "\xab\xcd" | |||||
""" | |||||
def __init__(self, struct): | |||||
""" | |||||
Parmeters: | |||||
construct_format -- The format for which to create an emitter. | |||||
""" | |||||
self.__dict__['format'] = struct | |||||
self.__dict__['fields'] = {} | |||||
def _format_contains_field(self, field_name): | |||||
""" Returns True iff the given format has a field with the provided name. | |||||
Parameters: | |||||
format_object -- The Construct format to work with. This includes e.g. most descriptor types. | |||||
field_name -- The field name to query. | |||||
""" | |||||
return any(f.name == field_name for f in self.format.subcons) | |||||
def __setattr__(self, name, value): | |||||
""" Hook that we used to set our fields. """ | |||||
# If the field starts with a '_', don't handle it, as it's an internal field. | |||||
if name.startswith('_'): | |||||
super().__setattr__(name, value) | |||||
return | |||||
if not self._format_contains_field(name): | |||||
raise AttributeError(f"emitter specification contains no field {name}") | |||||
self.fields[name] = value | |||||
def emit(self): | |||||
""" Emits the stream of bytes associated with this object. """ | |||||
try: | |||||
return self.format.build(self.fields) | |||||
except KeyError as e: | |||||
raise KeyError(f"missing necessary field: {e}") | |||||
def __getattr__(self, name): | |||||
""" Retrieves an emitter field, if possible. """ | |||||
if name in self.fields: | |||||
return self.fields[name] | |||||
else: | |||||
raise AttributeError(f"descriptor emitter has no property {name}") | |||||
class ConstructEmitterTest(unittest.TestCase): | |||||
def test_simple_emitter(self): | |||||
test_struct = construct.Struct( | |||||
"a" / construct.Int8ul, | |||||
"b" / construct.Int8ul | |||||
) | |||||
emitter = ConstructEmitter(test_struct) | |||||
emitter.a = 0xab | |||||
emitter.b = 0xcd | |||||
self.assertEqual(emitter.emit(), b"\xab\xcd") | |||||
def emitter_for_format(construct_format): | |||||
""" Creates a factory method for the relevant construct format. """ | |||||
def _factory(): | |||||
return ConstructEmitter(construct_format) | |||||
return _factory | |||||
if __name__ == "__main__": | |||||
unittest.main() |
@@ -230,11 +230,22 @@ class DeviceDescriptorCollection: | |||||
self.add_descriptor(descriptor) | self.add_descriptor(descriptor) | ||||
def get_descriptor_bytes(self, type_number: int, index: int = 0): | |||||
""" Returns the raw, binary descriptor for a given descriptor type/index. | |||||
Parmeters: | |||||
type_number -- The descriptor type number. | |||||
index -- The index of the relevant descriptor, if relevant. | |||||
""" | |||||
return self._descriptors[(type_number, index)] | |||||
def __iter__(self): | def __iter__(self): | ||||
""" Allow iterating over each of our descriptors; yields (index, value, descriptor). """ | """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """ | ||||
return ((number, index, desc) for ((number, index), desc) in self._descriptors.items()) | return ((number, index, desc) for ((number, index), desc) in self._descriptors.items()) | ||||
class EmitterTests(unittest.TestCase): | class EmitterTests(unittest.TestCase): | ||||
def test_string_emitter(self): | def test_string_emitter(self): | ||||
@@ -29,7 +29,7 @@ class StandardDescriptorNumbers(IntEnum): | |||||
DeviceDescriptor = DescriptorFormat( | DeviceDescriptor = DescriptorFormat( | ||||
"bLength" / DescriptorLength, | |||||
"bLength" / construct.Const(0x12, construct.Int8ul), | |||||
"bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.DEVICE), | "bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.DEVICE), | ||||
"bcdUSB" / DescriptorField("USB Version", default=2.0), | "bcdUSB" / DescriptorField("USB Version", default=2.0), | ||||
"bDeviceClass" / DescriptorField("Class", default=0), | "bDeviceClass" / DescriptorField("Class", default=0), | ||||
@@ -47,7 +47,7 @@ DeviceDescriptor = DescriptorFormat( | |||||
ConfigurationDescriptor = DescriptorFormat( | ConfigurationDescriptor = DescriptorFormat( | ||||
"bLength" / DescriptorLength, | |||||
"bLength" / construct.Const(9, construct.Int8ul), | |||||
"bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.CONFIGURATION), | "bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.CONFIGURATION), | ||||
"wTotalLength" / DescriptorField("Length including subordinates"), | "wTotalLength" / DescriptorField("Length including subordinates"), | ||||
"bNumInterfaces" / DescriptorField("Interface count"), | "bNumInterfaces" / DescriptorField("Interface count"), | ||||
@@ -101,7 +101,7 @@ EndpointDescriptor = DescriptorFormat( | |||||
DeviceQualifierDescriptor = DescriptorFormat( | DeviceQualifierDescriptor = DescriptorFormat( | ||||
"bLength" / DescriptorLength, | |||||
"bLength" / construct.Const(9, construct.Int8ul), | |||||
"bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.DEVICE_QUALIFIER), | "bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.DEVICE_QUALIFIER), | ||||
"bcdUSB" / DescriptorField("USB Version"), | "bcdUSB" / DescriptorField("USB Version"), | ||||
"bDeviceClass" / DescriptorField("Class"), | "bDeviceClass" / DescriptorField("Class"), | ||||