Browse Source

add complex emitters for building USB descriptors

main
Kate Temkin 4 years ago
parent
commit
7672fb5221
4 changed files with 213 additions and 19 deletions
  1. +8
    -0
      usb_protocol/emitters/__init__.py
  2. +20
    -3
      usb_protocol/emitters/descriptor.py
  3. +181
    -10
      usb_protocol/emitters/descriptors/standard.py
  4. +4
    -6
      usb_protocol/types/descriptor.py

+ 8
- 0
usb_protocol/emitters/__init__.py View File

@@ -64,6 +64,14 @@ class ConstructEmitter:
raise KeyError(f"missing necessary field: {e}")


def __getattr__(self, name):
""" Retrieves an emitter field, if possible. """

if name in self.fields:
return self.fields[name]
else:
raise AttributeError(f"descriptor emitter has no property {name}")


class ConstructEmitterTest(unittest.TestCase):



+ 20
- 3
usb_protocol/emitters/descriptor.py View File

@@ -12,7 +12,15 @@ class ComplexDescriptorEmitter(ConstructEmitter):
# Base classes should override this.
DESCRIPTOR_FORMAT = None

def __init__(self):
def __init__(self, collection=None):
"""
Parameters:
collection -- If this descriptor belongs to a collection, it should be
provided here. Using a collection object allows e.g. automatic
assignment of string descriptor indices.
"""

self._collection = collection

# Always create a basic ConstructEmitter from the given format.
super().__init__(self.DESCRIPTOR_FORMAT)
@@ -49,6 +57,11 @@ class ComplexDescriptorEmitter(ConstructEmitter):
self._subordinates.append(subordinate)


def _pre_emit(self):
""" Performs any manipulations needed on this object before emission. """
pass


def emit(self, include_subordinates=True):
""" Emit our descriptor.

@@ -56,9 +69,11 @@ class ComplexDescriptorEmitter(ConstructEmitter):
include_subordinates -- If true or not provided, any subordinate descriptors will be included.
"""

result = bytearray()
# Run any pre-emit hook code before we perform our emission...
self._pre_emit()

# Add our basic descriptor...
# Start with our core descriptor...
result = bytearray()
result.extend(super().emit())

# ... and if descired, add our subordinates...
@@ -66,3 +81,5 @@ class ComplexDescriptorEmitter(ConstructEmitter):
result.extend(sub)

return bytes(result)



+ 181
- 10
usb_protocol/emitters/descriptors/standard.py View File

@@ -4,7 +4,6 @@
""" Convenience emitters for simple, standard descriptors. """

import unittest
import functools

from contextlib import contextmanager

@@ -59,11 +58,14 @@ class InterfaceDescriptorEmitter(ComplexDescriptorEmitter):
self.add_subordinate_descriptor(descriptor)


def emit(self, include_subordinates=True):
def _pre_emit(self):

# Count our endpoints, and then call our parent emitter.
# Count our endpoints, and update our internal count.
self.bNumEndpoints = self._type_counts[StandardDescriptorNumbers.ENDPOINT]
return super().emit(include_subordinates=include_subordinates)

# Ensure that our interface string is an index, if we can.
if self._collection and hasattr(self, 'iInterface'):
self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface)



@@ -86,23 +88,151 @@ class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter):
This adds the relevant descriptor, automatically. Note that populating derived
fields such as bNumEndpoints aren't necessary; they'll be populated automatically.
"""
descriptor = InterfaceDescriptorEmitter()
descriptor = InterfaceDescriptorEmitter(collection=self._collection)
yield descriptor

self.add_subordinate_descriptor(descriptor)


def emit(self, include_subordinates=True):
def _pre_emit(self):

# Count our interfaces...
# Count our interfaces.
self.bNumInterfaces = self._type_counts[StandardDescriptorNumbers.INTERFACE]

# ... and figure out our total length.
# Figure out our total length.
subordinate_length = sum(len(sub) for sub in self._subordinates)
self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()

# Finally, emit our whole descriptor.
return super().emit(include_subordinates=include_subordinates)
# Ensure that our configuration string is an index, if we can.
if self._collection and hasattr(self, 'iConfiguration'):
self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration)



class DeviceDescriptorCollection:
""" Object that builds a full collection of descriptors related to a given USB device. """

def __init__(self):

# Create our internal descriptor tracker.
# Keys are a tuple of (type, index).
self._descriptors = {}

# Track string descriptors as they're created.
self._next_string_index = 1
self._index_for_string = {}


def ensure_string_field_is_index(self, field_value):
""" Processes the given field value; if it's not an string index, converts it to one.

Non-index-fields are converted to indices using `get_index_for_string`, which automatically
adds the relevant fields to our string descriptor collection.
"""

if isinstance(field_value, str):
return self.get_index_for_string(field_value)
else:
return field_value


def get_index_for_string(self, string):
""" Returns an string descriptor index for the given string.

If a string descriptor already exists for the given string, its index is
returned. Otherwise, a string descriptor is created.
"""

# If we already have a descriptor for this string, return it.
if string in self._index_for_string:
return self._index_for_string[string]


# Otherwise, create one:

# Allocate an index...
index = self._next_string_index
self._index_for_string[string] = index
self._next_string_index += 1

# ... store our string descriptor with it ...
identifier = StandardDescriptorNumbers.STRING, index
self._descriptors[identifier] = get_string_descriptor(string)

# ... and return our index.
return index


def add_descriptor(self, descriptor, index=0):
""" Adds a descriptor to our collection.

Parameters:
descriptor -- The descriptor to be added.
index -- The index of the relevant descriptor. Defaults to 0.
"""

# If this is an emitter rather than a descriptor itself, convert it.
if hasattr(descriptor, 'emit'):
descriptor = descriptor.emit()

# Figure out the identifier (type + index) for this descriptor...
descriptor_type = descriptor[1]
identifier = descriptor_type, index

# ... and store it.
self._descriptors[identifier] = descriptor


@contextmanager
def DeviceDescriptor(self):
""" Context manager that allows addition of a device descriptor.

It can be used with a `with` statement; and yields an DeviceDescriptorEmitter
that can be populated:

with collection.DeviceDescriptor() as d:
d.idVendor = 0xabcd
d.idProduct = 0x1234
[snip]

This adds the relevant descriptor, automatically.
"""
descriptor = DeviceDescriptorEmitter()
yield descriptor

# If we have any string fields, ensure that they're indices before continuing.
for field in ('iManufacturer', 'iProduct', 'iSerialNumber'):
if hasattr(descriptor, field):
value = getattr(descriptor, field)
index = self.ensure_string_field_is_index(value)
setattr(descriptor, field, index)

self.add_descriptor(descriptor)


@contextmanager
def ConfigurationDescriptor(self):
""" Context manager that allows addition of a configuration descriptor.

It can be used with a `with` statement; and yields an ConfigurationDescriptorEmitter
that can be populated:

with collection.ConfigurationDescriptor() as d:
d.bConfigurationValue = 1
[snip]

This adds the relevant descriptor, automatically. Note that populating derived
fields such as bNumInterfaces aren't necessary; they'll be populated automatically.
"""
descriptor = ConfigurationDescriptorEmitter()
yield descriptor

self.add_descriptor(descriptor)


def __iter__(self):
""" Allow iterating over each of our descriptors; yields (index, value, descriptor). """
return ((number, index, desc) for ((number, index), desc) in self._descriptors.items())


class EmitterTests(unittest.TestCase):
@@ -167,5 +297,46 @@ class EmitterTests(unittest.TestCase):
self.assertEqual(len(binary), len(descriptor))


def test_descriptor_collection(self):
collection = DeviceDescriptorCollection()

with collection.DeviceDescriptor() as d:
d.idVendor = 0xdead
d.idProduct = 0xbeef
d.bNumConfigurations = 1

d.iManufacturer = "Test Company"
d.iProduct = "Test Product"


with collection.ConfigurationDescriptor() as c:
c.bConfigurationValue = 1

with c.InterfaceDescriptor() as i:
i.bInterfaceNumber = 1

with i.EndpointDescriptor() as e:
e.bEndpointAddress = 0x81

with i.EndpointDescriptor() as e:
e.bEndpointAddress = 0x01


results = list(collection)

# We should wind up with four descriptor entries, as our endpoint/interface descriptors are
# included in our configuration descriptor.
self.assertEqual(len(results), 4)

# Manufacturer / product string.
self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results)
self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results)

# Device descriptor.
self.assertIn((1, 0, b'\x0f\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results)

# Configuration descriptor, with subordinates.
self.assertIn((2, 0, b'\r\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results)

if __name__ == "__main__":
unittest.main()

+ 4
- 6
usb_protocol/types/descriptor.py View File

@@ -159,15 +159,13 @@ class DescriptorField(construct.Subconstruct):
def __rtruediv__(self, field_name):
field_type = self._get_type_for_name(field_name)

if self.default is not None:
field_type = construct.Default(field_type, self.default)

# Build our subconstruct. Construct makes this look super weird,
# but this is actually "we have a field with <field_name> of type <field_type>".
# In long form, we'll call it "description".
subconstruct = (field_name / field_type) * self.description

if self.default is not None:
return construct.Default(subconstruct, self.default)
else:
return subconstruct
return (field_name / field_type) * self.description


# Convenience type that gets a descriptor's own length.


Loading…
Cancel
Save