diff --git a/pasn1.py b/pasn1.py index baef577..6066a0a 100644 --- a/pasn1.py +++ b/pasn1.py @@ -161,7 +161,7 @@ class ASN1Coder(object): _typetag = dict((v, k) for k, v in _tagmap.iteritems()) @staticmethod - def enc_int(obj): + def enc_int(obj, **kwargs): l = obj.bit_length() l += 1 # space for sign bit @@ -190,7 +190,7 @@ class ASN1Coder(object): return v, end @staticmethod - def enc_bool(obj): + def enc_bool(obj, **kwargs): return '\x01' + ('\xff' if obj else '\x00') def dec_bool(self, d, pos, end): @@ -201,15 +201,15 @@ class ASN1Coder(object): return bool(v), end @staticmethod - def enc_null(obj): + def enc_null(obj, **kwargs): return '\x00' @staticmethod def dec_null(d, pos, end): return None, end - def enc_list(self, obj): - r = ''.join(self.dumps(x) for x in obj) + def enc_list(self, obj, **kwargs): + r = ''.join(self.dumps(x, **kwargs) for x in obj) return _encodelen(len(r)) + r def dec_list(self, d, pos, end): @@ -231,7 +231,7 @@ class ASN1Coder(object): return set(r), end @staticmethod - def enc_bytes(obj): + def enc_bytes(obj, **kwargs): return _encodelen(len(obj)) + bytes(obj) @staticmethod @@ -239,7 +239,7 @@ class ASN1Coder(object): return d[pos:end], end @staticmethod - def enc_unicode(obj): + def enc_unicode(obj, **kwargs): encobj = obj.encode('utf-8') return _encodelen(len(encobj)) + encobj @@ -247,7 +247,7 @@ class ASN1Coder(object): return d[pos:end].decode('utf-8'), end @staticmethod - def enc_float(obj): + def enc_float(obj, **kwargs): s = math.copysign(1, obj) if math.isnan(obj): return _encodelen(1) + chr(0b01000010) @@ -335,7 +335,7 @@ class ASN1Coder(object): except KeyError: if default is not None: try: - return self.dumps(default(obj)) + return self.dumps(default(obj), default=default) except TypeError: pass @@ -345,7 +345,7 @@ class ASN1Coder(object): tf, obj = self.coerce(obj) fun = getattr(self, 'enc_%s' % tf) - return self._typetag[tf] + fun(obj) + return self._typetag[tf] + fun(obj, default=default) def _loads(self, data, pos, end): tag = data[pos] @@ -360,7 +360,7 @@ class ASN1Coder(object): fun = getattr(self, 'dec_%s' % t) return fun(data, pos + 1 + b, end) - def enc_datetime(self, obj): + def enc_datetime(self, obj, **kwargs): ts = obj.strftime('%Y%m%d%H%M%S') if obj.microsecond: ts += ('.%06d' % obj.microsecond).rstrip('0') @@ -414,10 +414,10 @@ class ASN1DictCoder(ASN1Coder): _tagmap['\xe0'] = 'dict' _typetag = dict((v, k) for k, v in _tagmap.iteritems()) - def enc_dict(self, obj): + def enc_dict(self, obj, **kwargs): #it = list(obj.iteritems()) #it.sort() - r = ''.join(self.dumps(k) + self.dumps(v) for k, v in + r = ''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in obj.iteritems()) return _encodelen(len(r)) + r @@ -667,6 +667,10 @@ class TestCode(unittest.TestCase): def somefun(self): return 5 + class Dummy2(object): + def somefun(self): + return [ Dummy() ] + def deffun(obj): try: return obj.somefun() @@ -674,5 +678,12 @@ class TestCode(unittest.TestCase): raise TypeError self.assertEqual(dumps(5), dumps(Dummy(), default=deffun)) - # XXX make subobjects work + + # Make sure it works for the various containers self.assertEqual(dumps([5]), dumps([Dummy()], default=deffun)) + self.assertEqual(dumps({ 5: 5 }), dumps({ Dummy(): Dummy() }, + default=deffun)) + self.assertEqual(dumps([5]), dumps(Dummy2(), default=deffun)) + + # Make sure that an error is raised when the function doesn't work + self.assertRaises(TypeError, dumps, object(), default=deffun)