#!/usr/bin/env python
# A pure Python ASN.1 encoder/decoder with a calling interface in the spirit
# of pickle.  It will automatically do the correct thing if possible.
#
# This uses a profile of ASN.1.
#
# All lengths must be specified.  That is, End-of-contents octets
# MUST NOT be used.  The shortest form of length encoding MUST be used.
# A longer length encoding MUST be rejected.
#
# The code runs unchanged on Python 2.7 and Python 3; all encoded data is
# handled as byte strings (str on Python 2, bytes on Python 3).

import binascii
import math
import pdb
import sys
import unittest


def _chr(i):
    '''Return the single-octet byte string whose value is i (0..255).'''

    # bytes(bytearray(...)) yields a one-octet byte string on both
    # Python 2 and Python 3, unlike chr() or bytes([i]).
    return bytes(bytearray((i,)))


def _byteat(d, pos):
    '''Return the octet at d[pos] as an int.

    Raises IndexError if pos is out of range.'''

    b = d[pos]
    if isinstance(b, int):
        # Python 3: indexing bytes already yields an int.
        return b
    return ord(b)


def _numtostr(n):
    '''Return the non-negative integer n as a big-endian byte string of
    the minimum number of octets (one octet for zero).'''

    hs = '%x' % n
    if len(hs) & 1:
        # unhexlify needs an even number of digits
        hs = '0' + hs
    return binascii.unhexlify(hs)


def _bytestonum(bs):
    '''Return the big-endian unsigned integer stored in the byte string
    bs.  Raises ValueError if bs is empty.'''

    return int(binascii.hexlify(bs), 16)


def _decodesigned(bs):
    '''Return the big-endian twos-complement integer stored in the
    non-empty byte string bs.'''

    v = _bytestonum(bs)
    signbit = 1 << (len(bs) * 8 - 1)
    # >= so that the most negative value (only the sign bit set) is
    # treated as negative as well.
    if v >= signbit:
        v -= signbit * 2
    return v


def _encodelen(l):
    '''Takes l as a length value, and returns a byte string that
    represents l per ASN.1 rules.'''

    if l < 128:
        return _chr(l)

    bs = _numtostr(l)
    return _chr(len(bs) | 0x80) + bs


def _decodelen(d, pos=0):
    '''Returns the length, and number of bytes required.

    d is the byte string and pos the offset of the first length octet.

    Raises IndexError if there is no octet at pos, and ValueError if
    the length is indefinite, truncated, or not in the shortest
    possible form.'''

    odp = _byteat(d, pos)
    if odp < 128:
        return odp, 1

    l = odp & 0x7f
    if l == 0:
        raise ValueError('indefinite length is not supported')

    lenbytes = d[pos + 1:pos + 1 + l]
    if len(lenbytes) != l:
        raise ValueError('string not long enough')

    val = _bytestonum(lenbytes)
    # The profile requires the shortest form: values below 128 must use
    # the short form, and the long form must not have leading zeros.
    if val < 128 or _byteat(lenbytes, 0) == 0:
        raise ValueError('length not encoded in shortest form')

    return val, l + 1


class Test_codelen(unittest.TestCase):
    _testdata = [
        (2, b'\x02'),
        (127, b'\x7f'),
        (128, b'\x81\x80'),
        (255, b'\x81\xff'),
        (256, b'\x82\x01\x00'),
        (65536 - 1, b'\x82\xff\xff'),
        (65536, b'\x83\x01\x00\x00'),
    ]

    def test_el(self):
        for i, j in self._testdata:
            self.assertEqual(_encodelen(i), j)
            self.assertEqual(_decodelen(j), (i, len(j)))

    def test_reject(self):
        for bad in [
            b'\x80',            # indefinite length
            b'\x81\x7f',        # should have used the short form
            b'\x82\x00\x80',    # leading zero octet
            b'\x82\x01',        # truncated length octets
        ]:
            self.assertRaises(ValueError, _decodelen, bad)


def _splitfloat(f):
    '''Split the finite float f into (m, e) such that f == m * 2**e and
    m has an integral value (odd unless f is zero).  m is returned as a
    float.'''

    m, e = math.frexp(f)
    # XXX - less than ideal
    while m != math.trunc(m):
        m *= 2
        e -= 1

    return m, e


class TestSplitFloat(unittest.TestCase):
    def test_sf(self):
        for a, b in [
            (0x2421, -32),
            (0x5382f, 238),
            (0x1fa8c3b094adf1, 971),
        ]:
            self.assertEqual(_splitfloat(a * 2**b), (a, b))


class ASN1Object:
    '''Holder for an ASN.1 tag value.'''

    def __init__(self, tag):
        self._tag = tag


class ASN1Coder(object):
    '''Encoder/decoder between Python objects and the ASN.1 profile
    described at the top of this file.

    Each supported type name T has an enc_T method, returning the
    encoded length followed by the contents, and a dec_T method, taking
    (d, pos, end) and returning (value, end), where d[pos:end] are the
    contents octets.'''

    def __init__(self):
        pass

    # Python type -> type name.  type(sys.maxsize + 1) is long on
    # Python 2 and int on Python 3; type(u'') is unicode on Python 2
    # and str on Python 3; bytes is str on Python 2.
    _typemap = {
        bool: 'bool',
        dict: 'dict',
        float: 'float',
        int: 'int',
        list: 'list',
        type(sys.maxsize + 1): 'int',
        set: 'set',
        bytes: 'bytes',
        type(None): 'none',
        type(u''): 'unicode',
    }

    # tag octet -> type name
    _tagmap = {
        b'\x01': 'bool',
        b'\x02': 'int',
        b'\x04': 'bytes',
        b'\x05': 'none',
        b'\x09': 'float',
        b'\x0c': 'unicode',
        b'\x30': 'list',
        b'\x31': 'set',
        b'\xc0': 'dict',
    }

    # type name -> tag octet
    _typetag = dict((v, k) for k, v in _tagmap.items())

    @staticmethod
    def enc_int(obj):
        '''Encode the integer obj as twos-complement, big-endian.'''

        # NOTE: for a negative power of two such as -128 this reserves
        # one octet more than strictly needed (ff 80 instead of 80).
        # The output is kept as is so existing encodings do not change.
        l = obj.bit_length()
        l += 1  # space for sign bit

        l = (l + 7) // 8

        if obj < 0:
            obj += 1 << (l * 8)  # twos-complement conversion

        v = _numtostr(obj)
        if len(v) != l:
            # A non-negative value whose top bit is set needs a leading
            # zero octet.  Negative values always fill all l octets.
            v = b'\x00' + v  # add sign octet

        return _encodelen(l) + v

    @staticmethod
    def dec_int(d, pos, end):
        '''Decode a twos-complement integer; empty contents decode as
        zero.'''

        if pos == end:
            return 0, end

        return _decodesigned(d[pos:end]), end

    @staticmethod
    def enc_bool(obj):
        return b'\x01' + _chr(obj)

    def dec_bool(self, d, pos, end):
        return bool(self.dec_int(d, pos, end)[0]), end

    @staticmethod
    def enc_none(obj):
        return b'\x00'

    @staticmethod
    def dec_none(d, pos, end):
        return None, end

    def _loadelem(self, d, pos, end):
        '''Decode one element of a container whose contents stop at
        end.  Raises ValueError if the element runs past end.'''

        v, vend = self._loads(d, pos, end)
        if vend > end:
            raise ValueError('element extends past end of container')

        return v, vend

    def enc_dict(self, obj):
        '''Encode a dict as the concatenation of its encoded key/value
        pairs, in the dict's iteration order.'''

        r = b''.join(self.dumps(k) + self.dumps(v) for k, v in
            obj.items())
        return _encodelen(len(r)) + r

    def dec_dict(self, d, pos, end):
        r = {}
        while pos < end:
            k, pos = self._loadelem(d, pos, end)
            v, pos = self._loadelem(d, pos, end)
            r[k] = v

        # pos, not a loop local, so that an empty dict works
        return r, pos

    def enc_set(self, obj):
        r = b''.join(self.dumps(x) for x in obj)
        return _encodelen(len(r)) + r

    def dec_set(self, d, pos, end):
        r, end = self.dec_list(d, pos, end)
        return set(r), end

    def enc_list(self, obj):
        r = b''.join(self.dumps(x) for x in obj)
        return _encodelen(len(r)) + r

    def dec_list(self, d, pos, end):
        r = []
        while pos < end:
            v, pos = self._loadelem(d, pos, end)
            r.append(v)

        # pos, not a loop local, so that an empty list works
        return r, pos

    @staticmethod
    def enc_bytes(obj):
        return _encodelen(len(obj)) + obj

    @staticmethod
    def dec_bytes(d, pos, end):
        return d[pos:end], end

    @staticmethod
    def enc_unicode(obj):
        encobj = obj.encode('utf-8')
        return _encodelen(len(encobj)) + encobj

    def dec_unicode(self, d, pos, end):
        return d[pos:end].decode('utf-8'), end

    @staticmethod
    def enc_float(obj):
        '''Encode a float.

        Positive zero has empty contents; NaN, the infinities and
        negative zero use one-octet special values.  Every other value
        uses the binary form in base 2: an information octet, the
        twos-complement exponent, then the unsigned mantissa.'''

        s = math.copysign(1, obj)
        if math.isnan(obj):
            return _encodelen(1) + _chr(0b01000010)
        elif math.isinf(obj):
            if s == 1:
                return _encodelen(1) + _chr(0b01000000)
            else:
                return _encodelen(1) + _chr(0b01000001)
        elif obj == 0:
            if s == 1:
                return _encodelen(0)
            else:
                return _encodelen(1) + _chr(0b01000011)

        m, e = _splitfloat(obj)
        m = int(m)  # _splitfloat returns an integral float

        # Binary encoding
        val = 0x80
        if m < 0:
            val |= 0x40  # sign bit
            m = -m

        # Base 2, scaling factor 0: both bit fields stay zero.

        # Octets needed for the exponent in twos-complement, including
        # the sign bit.
        if e < 0:
            el = ((~e).bit_length() + 1 + 7) // 8
            encexp = _numtostr(e + (1 << (el * 8)))
        else:
            el = (e.bit_length() + 1 + 7) // 8
            encexp = _numtostr(e).rjust(el, b'\x00')

        if el > 3:
            # Exponent length goes in a separate octet.
            val |= 0x3
            encexp = _chr(el) + encexp
        else:
            val |= el - 1

        r = _chr(val) + encexp + _numtostr(m)
        return _encodelen(len(r)) + r

    @staticmethod
    def dec_float(d, pos, end):
        '''Decode a float; the inverse of enc_float.

        Binary encodings in base 2, 8 and 16 are accepted.  Raises
        ValueError for a malformed binary encoding, NotImplementedError
        for any other form (such as the decimal one), and OverflowError
        if the value does not fit in a float.'''

        if pos == end:
            return float(0), end

        v = _byteat(d, pos)
        if v == 0b01000000:
            return float('inf'), end
        elif v == 0b01000001:
            return float('-inf'), end
        elif v == 0b01000010:
            return float('nan'), end
        elif v == 0b01000011:
            return float('-0'), end

        if not v & 0x80:
            raise NotImplementedError

        sign = -1 if v & 0x40 else 1

        # Base 2, 8 or 16, expressed as bits of 2 per exponent step.
        try:
            expmult = (1, 3, 4)[(v >> 4) & 0x3]
        except IndexError:
            raise ValueError('reserved base in float encoding')

        scale = (v >> 2) & 0x3

        pos += 1
        el = (v & 0x3) + 1
        if el == 4:
            # Exponent length is in the following octet.
            if pos >= end:
                raise ValueError('truncated float encoding')
            el = _byteat(d, pos)
            pos += 1

        mpos = pos + el
        # Need at least one exponent octet and one mantissa octet.
        if el == 0 or mpos >= end:
            raise ValueError('truncated float encoding')

        e = _decodesigned(d[pos:mpos])
        n = _bytestonum(d[mpos:end])

        return math.ldexp(sign * float(n), e * expmult + scale), end

    def dumps(self, obj):
        '''Return obj encoded as a byte string.

        Raises KeyError if the type of obj is not supported.'''

        tf = self._typemap[type(obj)]
        fun = getattr(self, 'enc_%s' % tf)
        return self._typetag[tf] + fun(obj)

    def _loads(self, data, pos, end):
        '''Decode the single object whose tag is at data[pos].

        Returns (value, position just past the object).  The end
        argument is not used to bound the object; only the length of
        data is checked.

        Raises ValueError if data is shorter than the encoded length,
        IndexError if the tag or length octet is missing, and KeyError
        for an unknown tag.'''

        l, b = _decodelen(data, pos + 1)
        # Slice so the tag is a byte string on Python 3 as well.
        tag = data[pos:pos + 1]
        if len(data) < pos + 1 + b + l:
            raise ValueError('string not long enough')

        # XXX - enforce that len(data) == end?
        end = pos + 1 + b + l

        t = self._tagmap[tag]
        fun = getattr(self, 'dec_%s' % t)
        return fun(data, pos + 1 + b, end)

    def loads(self, data, pos=0, end=None, consume=False):
        '''Decode and return the object encoded at data[pos].

        end defaults to len(data).  If consume is true, ValueError is
        raised unless the object ends exactly at end.'''

        if end is None:
            end = len(data)
        r, e = self._loads(data, pos, end)

        if consume and e != end:
            raise ValueError('entire string not consumed')

        return r


_coder = ASN1Coder()
dumps = _coder.dumps
loads = _coder.loads


class TestCode(unittest.TestCase):
    def test_primv(self):
        unhex = binascii.unhexlify

        self.assertEqual(dumps(-257), unhex('0202feff'))
        self.assertEqual(dumps(-256), unhex('0202ff00'))
        self.assertEqual(dumps(-255), unhex('0202ff01'))
        self.assertEqual(dumps(-1), unhex('0201ff'))
        self.assertEqual(dumps(5), unhex('020105'))
        self.assertEqual(dumps(128), unhex('02020080'))
        self.assertEqual(dumps(256), unhex('02020100'))

        self.assertEqual(dumps(False), unhex('010100'))
        self.assertEqual(dumps(True), unhex('010101'))

        self.assertEqual(dumps(None), unhex('0500'))

        self.assertEqual(dumps(1.0), unhex('0903800001'))

    def test_consume(self):
        b = dumps(5)
        self.assertRaises(ValueError, loads, b + b'398473', consume=True)

    def test_nan(self):
        s = dumps(float('nan'))
        v = loads(s)
        self.assertTrue(math.isnan(v))

    def test_minint(self):
        # the most negative value of a width has only the sign bit set
        self.assertEqual(loads(b'\x02\x01\x80'), -128)
        self.assertEqual(loads(b'\x02\x02\x80\x00'), -32768)
        self.assertEqual(loads(b'\x02\x01\x7f'), 127)

    def test_float(self):
        for i in [1.0, -1.0, 1.5, -2.0, 0.1, 3.141592653589793,
                5e-324, 2.2250738585072014e-308,
                1.7976931348623157e308, -1.7976931348623157e308]:
            self.assertEqual(loads(dumps(i), consume=True), i)

        # truncated binary encoding
        self.assertRaises(ValueError, loads, b'\x09\x02\x80\x00')

    def test_empty(self):
        for i in [[], {}, set(), b'', u'']:
            o = loads(dumps(i), consume=True)
            self.assertEqual(i, o)
            self.assertEqual(type(i), type(o))

    def test_overrun(self):
        # the int claims two octets, but the list only holds two in total
        self.assertRaises(ValueError, loads, b'\x30\x02\x02\x02\x00\x01')

    def test_dumps(self):
        for i in [None,
                True, False,
                -1, 0, 1, 255, 256, -255, -256, 23498732498723,
                -2398729387234, (1 << 2383) + 23984734,
                (-1 << 1983) + 23984723984,
                float(0), float('-0'), float('inf'), float('-inf'),
                b'weoifjwef',
                u'\U0001f4a9',
                set((1, 2, 3)),
                set((1, b'sjlfdkj', None, float('inf'))),
                ]:
            s = dumps(i)
            o = loads(s)
            self.assertEqual(i, o)

        tobj = {
            1: b'dflkj',
            5: u'sdlkfj',
            b'float': 1,
            b'largeint': 1 << 342,
            b'list': [1, 2, u'str', b'str'],
        }

        out = dumps(tobj)
        self.assertEqual(tobj, loads(out))

    def test_loads(self):
        self.assertRaises(ValueError, loads, b'\x00\x02\x00')


if __name__ == '__main__':
    pass