diff --git a/pasn1.py b/pasn1.py index ef50393..5bfbb5d 100644 --- a/pasn1.py +++ b/pasn1.py @@ -77,17 +77,29 @@ class TestSplitFloat(unittest.TestCase): (0x1fa8c3b094adf1, 971) ]: self.assertEqual(_splitfloat(a * 2**b), (a, b)) -class ASN1Object: - def __init__(self, tag): - self._tag = tag - class ASN1Coder(object): '''A class that contains an PASN.1 encoder/decoder. Exports two methods, loads and dumps.''' - def __init__(self): - pass + def __init__(self, coerce=None): + '''If the arg coerce is provided, when dumping the object, + if the type is not found, the coerce function will be called + with the obj. It is expected to return a tuple of a string + and an object that has the method w/ the string as defined: + 'bool': __nonzero__ + 'dict': iteritems + 'float': compatible w/ float + 'int': compatible w/ int + 'list': __iter__ + 'set': __iter__ + 'bytes': __str__ + 'null': no method needed + 'unicode': encode method returns UTF-8 encoded bytes + 'datetime': strftime and microsecond + ''' + + self.coerce = coerce _typemap = { bool: 'bool', @@ -189,14 +201,6 @@ class ASN1Coder(object): return r, vend - def enc_set(self, obj): - r = ''.join(self.dumps(x) for x in obj) - return _encodelen(len(r)) + r - - def dec_set(self, d, pos, end): - r, end = self.dec_list(d, pos, end) - return set(r), end - def enc_list(self, obj): r = ''.join(self.dumps(x) for x in obj) return _encodelen(len(r)) + r @@ -213,9 +217,15 @@ class ASN1Coder(object): return r, vend + enc_set = enc_list + + def dec_set(self, d, pos, end): + r, end = self.dec_list(d, pos, end) + return set(r), end + @staticmethod def enc_bytes(obj): - return _encodelen(len(obj)) + obj + return _encodelen(len(obj)) + bytes(obj) @staticmethod def dec_bytes(d, pos, end): @@ -311,7 +321,14 @@ class ASN1Coder(object): def dumps(self, obj): '''Convert obj into a string.''' - tf = self._typemap[type(obj)] + try: + tf = self._typemap[type(obj)] + except KeyError: + if self.coerce is None: + raise + + tf, obj = self.coerce(obj) + fun = getattr(self, 'enc_%s' % tf) return self._typetag[tf] + fun(obj) @@ -520,5 +537,39 @@ class TestCode(unittest.TestCase): out = dumps(tobj) self.assertEqual(tobj, loads(out)) + def test_coerce(self): + class Foo: + pass + + class Bar: + pass + + class Baz: + pass + def coerce(obj): + if isinstance(obj, Foo): + return 'list', obj.lst + elif isinstance(obj, Baz): + return 'bytes', obj.s + + raise TypeError('unknown type') + + ac = ASN1Coder(coerce) + + v = [1, 2, 3] + o = Foo() + o.lst = v + + self.assertEqual(ac.loads(ac.dumps(o)), v) + self.assertRaises(TypeError, ac.dumps, Bar()) + + v = u'oiejfd' + o = Baz() + o.s = v + + es = ac.dumps(o) + self.assertEqual(ac.loads(es), v) + self.assertIsInstance(es, bytes) + def test_loads(self): self.assertRaises(ValueError, loads, '\x00\x02\x00')