From 67c86d3494e5f6d846bbe7117f46efa6cf89aedb Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Sat, 13 Feb 2016 00:44:43 -0800 Subject: [PATCH] add inital version of this code.. only special float values work.. various time formats not supported.. [git-p4: depot-paths = "//depot/python/pypasn1/main/": change = 1817] --- pasn1.py | 346 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 346 insertions(+) create mode 100644 pasn1.py diff --git a/pasn1.py b/pasn1.py new file mode 100644 index 0000000..94ec31d --- /dev/null +++ b/pasn1.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python + +# A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit +# of pickle. It will automaticly do the correct thing if possible. +# +# This uses a profile of ASN.1. +# +# All lengths must be specified. That is that End-of-contents octets +# MUST not be used. The shorted form of length encoding MUST be used. +# A longer length encoding MUST be rejected. + +import pdb +import math +import sys +import unittest + +def _numtostr(n): + hs = '%x' % n + if len(hs) & 1 == 1: + hs = '0' + hs + bs = hs.decode('hex') + + return bs + +def _encodelen(l): + '''Takes l as a length value, and returns a byte string that + represents l per ASN.1 rules.''' + + if l < 128: + return chr(l) + + bs = _numtostr(l) + return chr(len(bs) | 0x80) + bs + +def _decodelen(d, pos=0): + '''Returns the length, and number of bytes required.''' + + odp = ord(d[pos]) + if odp < 128: + return ord(d[pos]), 1 + else: + l = odp & 0x7f + return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1 + +class Test_codelen(unittest.TestCase): + _testdata = [ + (2, '\x02'), + (127, '\x7f'), + (128, '\x81\x80'), + (255, '\x81\xff'), + (256, '\x82\x01\x00'), + (65536-1, '\x82\xff\xff'), + (65536, '\x83\x01\x00\x00'), + ] + + def test_el(self): + for i, j in self._testdata: + self.assertEqual(_encodelen(i), j) + self.assertEqual(_decodelen(j), (i, len(j))) + +def _splitfloat(f): + m, e = math.frexp(f) + # XXX - less than ideal + while m != math.trunc(m): + m *= 2 + e -= 1 + + return m, e + +class TestSplitFloat(unittest.TestCase): + def test_sf(self): + for a, b in [ (0x2421, -32), (0x5382f, 238), + (0x1fa8c3b094adf1, 971) ]: + self.assertEqual(_splitfloat(a * 2**b), (a, b)) + +class ASN1Object: + def __init__(self, tag): + self._tag = tag + +class ASN1Coder(object): + def __init__(self): + pass + + _typemap = { + bool: 'bool', + dict: 'dict', + float: 'float', + int: 'int', + list: 'list', + long: 'int', + set: 'set', + str: 'bytes', + type(None): 'none', + unicode: 'unicode', + } + _tagmap = { + '\x01': 'bool', + '\x02': 'int', + '\x04': 'bytes', + '\x05': 'none', + '\x09': 'float', + '\x0c': 'unicode', + '\x30': 'list', + '\x31': 'set', + '\xc0': 'dict', + } + + _typetag = dict((v, k) for k, v in _tagmap.iteritems()) + + @staticmethod + def enc_int(obj): + l = obj.bit_length() + l += 1 # space for sign bit + + l = (l + 7) // 8 + + if obj < 0: + obj += 1 << (l * 8) # twos-complement conversion + + v = _numtostr(obj) + if len(v) != l: + # XXX - is this a problem for signed values? + v = '\x00' + v # add sign octect + + return _encodelen(l) + v + + @staticmethod + def dec_int(d, pos, end): + if pos == end: + return 0, end + + v = int(d[pos:end].encode('hex'), 16) + av = 1 << ((end - pos) * 8 - 1) # sign bit + if v > av: + v -= av * 2 # twos-complement conversion + + return v, end + + @staticmethod + def enc_bool(obj): + return '\x01' + chr(obj) + + def dec_bool(self, d, pos, end): + return bool(self.dec_int(d, pos, end)[0]), end + + @staticmethod + def enc_none(obj): + return '\x00' + + @staticmethod + def dec_none(d, pos, end): + return None, end + + def enc_dict(self, obj): + #it = list(obj.iteritems()) + #it.sort() + r = ''.join(self.dumps(k) + self.dumps(v) for k, v in obj.iteritems()) + return _encodelen(len(r)) + r + + def dec_dict(self, d, pos, end): + r = {} + while pos < end: + k, kend = self._loads(d, pos, end) + v, vend = self._loads(d, kend, end) + + r[k] = v + pos = vend + + return r, vend + + def enc_set(self, obj): + r = ''.join(self.dumps(x) for x in obj) + return _encodelen(len(r)) + r + + def dec_set(self, d, pos, end): + r, end = self.dec_list(d, pos, end) + return set(r), end + + def enc_list(self, obj): + r = ''.join(self.dumps(x) for x in obj) + return _encodelen(len(r)) + r + + def dec_list(self, d, pos, end): + r = [] + while pos < end: + v, vend = self._loads(d, pos, end) + r.append(v) + pos = vend + + return r, vend + + @staticmethod + def enc_bytes(obj): + return _encodelen(len(obj)) + obj + + @staticmethod + def dec_bytes(d, pos, end): + return d[pos:end], end + + @staticmethod + def enc_unicode(obj): + encobj = obj.encode('utf-8') + return _encodelen(len(encobj)) + encobj + + def dec_unicode(self, d, pos, end): + return d[pos:end].decode('utf-8'), end + + @staticmethod + def enc_float(obj): + s = math.copysign(1, obj) + if math.isnan(obj): + return _encodelen(1) + chr(0b01000010) + elif math.isinf(obj): + if s == 1: + return _encodelen(1) + chr(0b01000000) + else: + return _encodelen(1) + chr(0b01000001) + elif obj == 0: + if s == 1: + return _encodelen(0) + else: + return _encodelen(1) + chr(0b01000011) + + m, e = _splitfloat(obj) + + # Binary encoding + val = 0x80 + if m < 0: + val |= 0x40 + m = -m + # Base 2 + # XXX - negative e + el = (e.bit_length() + 7) // 8 + if el > 3: + v = 0x3 + encexp = _encodelen(el) + _numtostr(e) + else: + v = el - 1 + encexp = _numtostr(e) + + return chr(val) + encexp + _numtostr(m) + + @staticmethod + def dec_float(d, pos, end): + if pos == end: + return float(0), end + + v = ord(d[pos]) + if v == 0b01000000: + return float('inf'), end + elif v == 0b01000001: + return float('-inf'), end + elif v == 0b01000010: + return float('nan'), end + elif v == 0b01000011: + return float('-0'), end + #elif v & 0b11000000 == 0b01000000: + # raise ValueError('invalid encoding') + + print 'df:', `d, pos, end` + + raise NotImplementedError + + def dumps(self, obj): + tf = self._typemap[type(obj)] + fun = getattr(self, 'enc_%s' % tf) + return self._typetag[tf] + fun(obj) + + def _loads(self, data, pos, end): + tag = data[pos] + l, b = _decodelen(data, pos + 1) + if len(data) < pos + 1 + b + l: + print `data, pos, end, l, b` + raise ValueError('string not long enough') + + # XXX - enforce that len(data) == end? + end = pos + 1 + b + l + + t = self._tagmap[tag] + fun = getattr(self, 'dec_%s' % t) + return fun(data, pos + 1 + b, end) + + def loads(self, data, pos=0, end=None, consume=False): + if end is None: + end = len(data) + r, e = self._loads(data, pos, end) + + if consume and e != end: + raise ValueError('entire string not consumed') + return r + +_coder = ASN1Coder() +dumps = _coder.dumps +loads = _coder.loads + +class TestCode(unittest.TestCase): + def test_primv(self): + self.assertEqual(dumps(-257), '0202feff'.decode('hex')) + self.assertEqual(dumps(-256), '0202ff00'.decode('hex')) + self.assertEqual(dumps(-255), '0202ff01'.decode('hex')) + self.assertEqual(dumps(-1), '0201ff'.decode('hex')) + self.assertEqual(dumps(5), '020105'.decode('hex')) + self.assertEqual(dumps(128), '02020080'.decode('hex')) + self.assertEqual(dumps(256), '02020100'.decode('hex')) + + self.assertEqual(dumps(False), '010100'.decode('hex')) + self.assertEqual(dumps(True), '010101'.decode('hex')) + + self.assertEqual(dumps(None), '0500'.decode('hex')) + + def test_consume(self): + b = dumps(5) + self.assertRaises(ValueError, loads, b + '398473', consume=True) + + # XXX - still possible that an internal data member + # doesn't consume all + + def test_nan(self): + s = dumps(float('nan')) + v = loads(s) + self.assertTrue(math.isnan(v)) + + def test_dumps(self): + for i in [ None, + True, False, + -1, 0, 1, 255, 256, -255, -256, 23498732498723, -2398729387234, (1<<2383) + 23984734, (-1<<1983) + 23984723984, + float(0), float('-0'), float('inf'), float('-inf'), + 'weoifjwef', + u'\U0001f4a9', + set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))), + ]: + s = dumps(i) + o = loads(s) + self.assertEqual(i, o) + + print 'done' + tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] } + + out = dumps(tobj) + self.assertEqual(tobj, loads(out)) + + def test_loads(self): + self.assertRaises(ValueError, loads, '\x00\x02\x00') + +if __name__ == '__main__': + pass