Browse Source

make the dumps default arg work for all container args...

This should also allow expansion if needed..
python2
John-Mark Gurney 5 years ago
parent
commit
f5fd88e7b3
1 changed files with 25 additions and 14 deletions
  1. +25
    -14
      pasn1.py

+ 25
- 14
pasn1.py View File

@@ -161,7 +161,7 @@ class ASN1Coder(object):
_typetag = dict((v, k) for k, v in _tagmap.iteritems()) _typetag = dict((v, k) for k, v in _tagmap.iteritems())


@staticmethod @staticmethod
def enc_int(obj):
def enc_int(obj, **kwargs):
l = obj.bit_length() l = obj.bit_length()
l += 1 # space for sign bit l += 1 # space for sign bit


@@ -190,7 +190,7 @@ class ASN1Coder(object):
return v, end return v, end


@staticmethod @staticmethod
def enc_bool(obj):
def enc_bool(obj, **kwargs):
return '\x01' + ('\xff' if obj else '\x00') return '\x01' + ('\xff' if obj else '\x00')


def dec_bool(self, d, pos, end): def dec_bool(self, d, pos, end):
@@ -201,15 +201,15 @@ class ASN1Coder(object):
return bool(v), end return bool(v), end


@staticmethod @staticmethod
def enc_null(obj):
def enc_null(obj, **kwargs):
return '\x00' return '\x00'


@staticmethod @staticmethod
def dec_null(d, pos, end): def dec_null(d, pos, end):
return None, end return None, end


def enc_list(self, obj):
r = ''.join(self.dumps(x) for x in obj)
def enc_list(self, obj, **kwargs):
r = ''.join(self.dumps(x, **kwargs) for x in obj)
return _encodelen(len(r)) + r return _encodelen(len(r)) + r


def dec_list(self, d, pos, end): def dec_list(self, d, pos, end):
@@ -231,7 +231,7 @@ class ASN1Coder(object):
return set(r), end return set(r), end


@staticmethod @staticmethod
def enc_bytes(obj):
def enc_bytes(obj, **kwargs):
return _encodelen(len(obj)) + bytes(obj) return _encodelen(len(obj)) + bytes(obj)


@staticmethod @staticmethod
@@ -239,7 +239,7 @@ class ASN1Coder(object):
return d[pos:end], end return d[pos:end], end


@staticmethod @staticmethod
def enc_unicode(obj):
def enc_unicode(obj, **kwargs):
encobj = obj.encode('utf-8') encobj = obj.encode('utf-8')
return _encodelen(len(encobj)) + encobj return _encodelen(len(encobj)) + encobj


@@ -247,7 +247,7 @@ class ASN1Coder(object):
return d[pos:end].decode('utf-8'), end return d[pos:end].decode('utf-8'), end


@staticmethod @staticmethod
def enc_float(obj):
def enc_float(obj, **kwargs):
s = math.copysign(1, obj) s = math.copysign(1, obj)
if math.isnan(obj): if math.isnan(obj):
return _encodelen(1) + chr(0b01000010) return _encodelen(1) + chr(0b01000010)
@@ -335,7 +335,7 @@ class ASN1Coder(object):
except KeyError: except KeyError:
if default is not None: if default is not None:
try: try:
return self.dumps(default(obj))
return self.dumps(default(obj), default=default)
except TypeError: except TypeError:
pass pass


@@ -345,7 +345,7 @@ class ASN1Coder(object):
tf, obj = self.coerce(obj) tf, obj = self.coerce(obj)


fun = getattr(self, 'enc_%s' % tf) fun = getattr(self, 'enc_%s' % tf)
return self._typetag[tf] + fun(obj)
return self._typetag[tf] + fun(obj, default=default)


def _loads(self, data, pos, end): def _loads(self, data, pos, end):
tag = data[pos] tag = data[pos]
@@ -360,7 +360,7 @@ class ASN1Coder(object):
fun = getattr(self, 'dec_%s' % t) fun = getattr(self, 'dec_%s' % t)
return fun(data, pos + 1 + b, end) return fun(data, pos + 1 + b, end)


def enc_datetime(self, obj):
def enc_datetime(self, obj, **kwargs):
ts = obj.strftime('%Y%m%d%H%M%S') ts = obj.strftime('%Y%m%d%H%M%S')
if obj.microsecond: if obj.microsecond:
ts += ('.%06d' % obj.microsecond).rstrip('0') ts += ('.%06d' % obj.microsecond).rstrip('0')
@@ -414,10 +414,10 @@ class ASN1DictCoder(ASN1Coder):
_tagmap['\xe0'] = 'dict' _tagmap['\xe0'] = 'dict'
_typetag = dict((v, k) for k, v in _tagmap.iteritems()) _typetag = dict((v, k) for k, v in _tagmap.iteritems())


def enc_dict(self, obj):
def enc_dict(self, obj, **kwargs):
#it = list(obj.iteritems()) #it = list(obj.iteritems())
#it.sort() #it.sort()
r = ''.join(self.dumps(k) + self.dumps(v) for k, v in
r = ''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
obj.iteritems()) obj.iteritems())
return _encodelen(len(r)) + r return _encodelen(len(r)) + r


@@ -667,6 +667,10 @@ class TestCode(unittest.TestCase):
def somefun(self): def somefun(self):
return 5 return 5


class Dummy2(object):
def somefun(self):
return [ Dummy() ]

def deffun(obj): def deffun(obj):
try: try:
return obj.somefun() return obj.somefun()
@@ -674,5 +678,12 @@ class TestCode(unittest.TestCase):
raise TypeError raise TypeError


self.assertEqual(dumps(5), dumps(Dummy(), default=deffun)) self.assertEqual(dumps(5), dumps(Dummy(), default=deffun))
# XXX make subobjects work

# Make sure it works for the various containers
self.assertEqual(dumps([5]), dumps([Dummy()], default=deffun)) self.assertEqual(dumps([5]), dumps([Dummy()], default=deffun))
self.assertEqual(dumps({ 5: 5 }), dumps({ Dummy(): Dummy() },
default=deffun))
self.assertEqual(dumps([5]), dumps(Dummy2(), default=deffun))

# Make sure that an error is raised when the function doesn't work
self.assertRaises(TypeError, dumps, object(), default=deffun)

Loading…
Cancel
Save