diff --git a/Makefile b/Makefile index 9ee85c2..eb8cf8b 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,8 @@ MODULES=pasn1.py +VIRTUALENV?=virtualenv-3.7 test: - (echo $(MODULES) | entr sh -c 'python -m coverage run -m unittest $(basename $(MODULES)) && coverage report -m') + (echo $(MODULES) | entr sh -c 'python -m coverage run -m unittest $(basename $(MODULES)) && coverage report --omit=p/\* -m -i') + +env: + ($(VIRTUALENV) p && . ./p/bin/activate && pip install -r requirements.txt) diff --git a/pasn1.py b/pasn1.py index 6066a0a..5a40b09 100644 --- a/pasn1.py +++ b/pasn1.py @@ -43,18 +43,22 @@ __license__ = '2-clause BSD license' # either expressed or implied, of the Project. import datetime +import functools import math import os import sys import unittest +if sys.version_info.major != 3: # pragma: no cover + raise RuntimeError('this module only supports python 3') + __all__ = [ 'dumps', 'loads', 'ASN1Coder' ] -def _numtostr(n): +def _numtobytes(n): hs = '%x' % n if len(hs) & 1 == 1: hs = '0' + hs - bs = hs.decode('hex') + bs = bytes.fromhex(hs) return bs @@ -63,30 +67,30 @@ def _encodelen(l): represents l per ASN.1 rules.''' if l < 128: - return chr(l) + return bytes([l]) - bs = _numtostr(l) - return chr(len(bs) | 0x80) + bs + bs = _numtobytes(l) + return bytes([len(bs) | 0x80]) + bs def _decodelen(d, pos=0): '''Returns the length, and number of bytes required.''' - odp = ord(d[pos]) + odp = d[pos] if odp < 128: - return ord(d[pos]), 1 + return d[pos], 1 else: l = odp & 0x7f - return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1 + return int(d[pos + 1:pos + 1 + l].hex(), 16), l + 1 class Test_codelen(unittest.TestCase): _testdata = [ - (2, '\x02'), - (127, '\x7f'), - (128, '\x81\x80'), - (255, '\x81\xff'), - (256, '\x82\x01\x00'), - (65536-1, '\x82\xff\xff'), - (65536, '\x83\x01\x00\x00'), + (2, b'\x02'), + (127, b'\x7f'), + (128, b'\x81\x80'), + (255, b'\x81\xff'), + (256, b'\x82\x01\x00'), + (65536-1, b'\x82\xff\xff'), + (65536, b'\x83\x01\x00\x00'), ] def test_el(self): @@ -101,7 +105,7 @@ def _splitfloat(f): m *= 2 e -= 1 - return m, e + return math.trunc(m), e class TestSplitFloat(unittest.TestCase): def test_sf(self): @@ -124,7 +128,7 @@ class ASN1Coder(object): 'int': compatible w/ int 'list': __iter__ 'set': __iter__ - 'bytes': __str__ + 'bytes': __str__ # XXX what is correct here 'null': no method needed 'unicode': encode method returns UTF-8 encoded bytes 'datetime': strftime and microsecond @@ -137,28 +141,27 @@ class ASN1Coder(object): float: 'float', int: 'int', list: 'list', - long: 'int', set: 'set', - str: 'bytes', + bytes: 'bytes', type(None): 'null', - unicode: 'unicode', + str: 'unicode', #decimal.Decimal: 'float', datetime.datetime: 'datetime', #datetime.timedelta: 'timedelta', } _tagmap = { - '\x01': 'bool', - '\x02': 'int', - '\x04': 'bytes', - '\x05': 'null', - '\x09': 'float', - '\x0c': 'unicode', - '\x18': 'datetime', - '\x30': 'list', - '\x31': 'set', + b'\x01': 'bool', + b'\x02': 'int', + b'\x04': 'bytes', + b'\x05': 'null', + b'\x09': 'float', + b'\x0c': 'unicode', + b'\x18': 'datetime', + b'\x30': 'list', + b'\x31': 'set', } - _typetag = dict((v, k) for k, v in _tagmap.iteritems()) + _typetag = dict((v, k) for k, v in _tagmap.items()) @staticmethod def enc_int(obj, **kwargs): @@ -170,10 +173,10 @@ class ASN1Coder(object): if obj < 0: obj += 1 << (l * 8) # twos-complement conversion - v = _numtostr(obj) + v = _numtobytes(obj) if len(v) != l: # XXX - is this a problem for signed values? - v = '\x00' + v # add sign octect + v = b'\x00' + v # add sign octect return _encodelen(l) + v @@ -182,7 +185,7 @@ class ASN1Coder(object): if pos == end: return 0, end - v = int(d[pos:end].encode('hex'), 16) + v = int(bytes.hex(d[pos:end]), 16) av = 1 << ((end - pos) * 8 - 1) # sign bit if v > av: v -= av * 2 # twos-complement conversion @@ -191,7 +194,7 @@ class ASN1Coder(object): @staticmethod def enc_bool(obj, **kwargs): - return '\x01' + ('\xff' if obj else '\x00') + return b'\x01' + (b'\xff' if obj else b'\x00') def dec_bool(self, d, pos, end): v = self.dec_int(d, pos, end)[0] @@ -202,14 +205,14 @@ class ASN1Coder(object): @staticmethod def enc_null(obj, **kwargs): - return '\x00' + return b'\x00' @staticmethod def dec_null(d, pos, end): return None, end def enc_list(self, obj, **kwargs): - r = ''.join(self.dumps(x, **kwargs) for x in obj) + r = b''.join(self.dumps(x, **kwargs) for x in obj) return _encodelen(len(r)) + r def dec_list(self, d, pos, end): @@ -250,17 +253,17 @@ class ASN1Coder(object): def enc_float(obj, **kwargs): s = math.copysign(1, obj) if math.isnan(obj): - return _encodelen(1) + chr(0b01000010) + return _encodelen(1) + bytes([0b01000010]) elif math.isinf(obj): if s == 1: - return _encodelen(1) + chr(0b01000000) + return _encodelen(1) + bytes([0b01000000]) else: - return _encodelen(1) + chr(0b01000001) + return _encodelen(1) + bytes([0b01000001]) elif obj == 0: if s == 1: return _encodelen(0) else: - return _encodelen(1) + chr(0b01000011) + return _encodelen(1) + bytes([0b01000011]) m, e = _splitfloat(obj) @@ -279,17 +282,17 @@ class ASN1Coder(object): e += 256**el # convert negative to twos-complement v = el - 1 - encexp = _numtostr(e) + encexp = _numtobytes(e) val |= v - r = chr(val) + encexp + _numtostr(m) + r = bytes([val]) + encexp + _numtobytes(m) return _encodelen(len(r)) + r def dec_float(self, d, pos, end): if pos == end: return float(0), end - v = ord(d[pos]) + v = d[pos] if v == 0b01000000: return float('inf'), end elif v == 0b01000001: @@ -314,7 +317,7 @@ class ASN1Coder(object): exp = self.dec_int(d, pexp, eexp)[0] - n = float(int(d[eexp:end].encode('hex'), 16)) + n = float(int(bytes.hex(d[eexp:end]), 16)) r = n * 2 ** exp if v & 0b1000000: r = -r @@ -340,7 +343,7 @@ class ASN1Coder(object): pass if self.coerce is None: - raise TypeError('unhandled object: %s' % `obj`) + raise TypeError('unhandled object: %s' % repr(obj)) tf, obj = self.coerce(obj) @@ -348,7 +351,7 @@ class ASN1Coder(object): return self._typetag[tf] + fun(obj, default=default) def _loads(self, data, pos, end): - tag = data[pos] + tag = data[pos:pos + 1] l, b = _decodelen(data, pos + 1) if len(data) < pos + 1 + b + l: raise ValueError('string not long enough') @@ -365,12 +368,12 @@ class ASN1Coder(object): if obj.microsecond: ts += ('.%06d' % obj.microsecond).rstrip('0') ts += 'Z' - return _encodelen(len(ts)) + ts + return _encodelen(len(ts)) + ts.encode('utf-8') def dec_datetime(self, data, pos, end): - ts = data[pos:end] - if ts[-1] != 'Z': - raise ValueError('last character must be Z') + ts = data[pos:end].decode('ascii') + if ts[-1:] != 'Z': + raise ValueError('last character must be Z, was: %s' % repr(ts[-1])) # Real bug is in strptime, but work around it here. if ' ' in ts: @@ -411,14 +414,14 @@ class ASN1DictCoder(ASN1Coder): _typemap = ASN1Coder._typemap.copy() _typemap[dict] = 'dict' _tagmap = ASN1Coder._tagmap.copy() - _tagmap['\xe0'] = 'dict' - _typetag = dict((v, k) for k, v in _tagmap.iteritems()) + _tagmap[b'\xe0'] = 'dict' + _typetag = dict((v, k) for k, v in _tagmap.items()) def enc_dict(self, obj, **kwargs): #it = list(obj.iteritems()) #it.sort() - r = ''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in - obj.iteritems()) + r = b''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in + obj.items()) return _encodelen(len(r)) + r def dec_dict(self, d, pos, end): @@ -441,11 +444,29 @@ _coder = ASN1DictCoder() dumps = _coder.dumps loads = _coder.loads +def cmp(a, b): + return (a > b) - (a < b) + +def universal_cmp(a, b): + # Because Python 3 sucks, make this function that + # orders first based upon type, then on value. + if type(a) == type(b): + if isinstance(a, (tuple, list)): + for a, b in zip(a, b): + if a != b: + return universal_cmp(a, b) + + #return cmp(len(a), len(b)) + else: + return cmp(a, b) + else: + return id(type(a)) < id(type(b)) + def deeptypecmp(obj, o): - #print 'dtc:', `obj`, `o` + #print('dtc:', repr(obj), repr(o)) if type(obj) != type(o): return False - if type(obj) in (str, unicode): + if type(obj) is str: return True if type(obj) in (list, set): @@ -454,10 +475,8 @@ def deeptypecmp(obj, o): return False if type(obj) in (dict,): - itms = obj.items() - itms.sort() - nitms = o.items() - nitms.sort() + itms = sorted(obj.items(), key=functools.cmp_to_key(universal_cmp)) + nitms = sorted(o.items(), key=functools.cmp_to_key(universal_cmp)) for (k, v), (nk, nv) in zip(itms, nitms): if not deeptypecmp(k, nk): return False @@ -474,15 +493,15 @@ class Test_deeptypecmp(unittest.TestCase): def test_false(self): for i in (([[]], [{}]), ([1], ['str']), ([], set()), - ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}), - ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}), + ({1: 2, 5: 'sdlkfj'}, {1: 2, 5: b'sdlkfj'}), + ({1: 2, 'sdlkfj': 5}, {1: 2, b'sdlkfj': 5}), ): - self.assertFalse(deeptypecmp(*i)) + self.assertFalse(deeptypecmp(*i), '%s != %s' % (i[0], i[1])) def genfailures(obj): s = dumps(obj) - for i in xrange(len(s)): - for j in (chr(x) for x in xrange(256)): + for i in range(len(s)): + for j in (bytes([x]) for x in range(256)): ts = s[:i] + j + s[i + 1:] if ts == s: continue @@ -493,24 +512,24 @@ def genfailures(obj): except (ValueError, KeyError, IndexError, TypeError): pass else: # pragma: no cover - raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i]))) + raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, s[i])) class TestCode(unittest.TestCase): def test_primv(self): - self.assertEqual(dumps(-257), '0202feff'.decode('hex')) - self.assertEqual(dumps(-256), '0202ff00'.decode('hex')) - self.assertEqual(dumps(-255), '0202ff01'.decode('hex')) - self.assertEqual(dumps(-1), '0201ff'.decode('hex')) - self.assertEqual(dumps(5), '020105'.decode('hex')) - self.assertEqual(dumps(128), '02020080'.decode('hex')) - self.assertEqual(dumps(256), '02020100'.decode('hex')) + self.assertEqual(dumps(-257), bytes.fromhex('0202feff')) + self.assertEqual(dumps(-256), bytes.fromhex('0202ff00')) + self.assertEqual(dumps(-255), bytes.fromhex('0202ff01')) + self.assertEqual(dumps(-1), bytes.fromhex('0201ff')) + self.assertEqual(dumps(5), bytes.fromhex('020105')) + self.assertEqual(dumps(128), bytes.fromhex('02020080')) + self.assertEqual(dumps(256), bytes.fromhex('02020100')) - self.assertEqual(dumps(False), '010100'.decode('hex')) - self.assertEqual(dumps(True), '0101ff'.decode('hex')) + self.assertEqual(dumps(False), bytes.fromhex('010100')) + self.assertEqual(dumps(True), bytes.fromhex('0101ff')) - self.assertEqual(dumps(None), '0500'.decode('hex')) + self.assertEqual(dumps(None), bytes.fromhex('0500')) - self.assertEqual(dumps(.15625), '090380fb05'.decode('hex')) + self.assertEqual(dumps(.15625), bytes.fromhex('090380fb05')) def test_fuzzing(self): # Make sure that when a failure is detected here, that it @@ -539,7 +558,7 @@ class TestCode(unittest.TestCase): '181632303136303231373136343034372e3035343433367a', #datetime w/ lower z '181632303136313220383031303933302e3931353133385a', #datetime w/ space ]: - self.assertRaises(ValueError, loads, v.decode('hex')) + self.assertRaises(ValueError, loads, bytes.fromhex(v)) def test_invalid_floats(self): import mock @@ -548,7 +567,7 @@ class TestCode(unittest.TestCase): def test_consume(self): b = dumps(5) - self.assertRaises(ValueError, loads, b + '398473', + self.assertRaises(ValueError, loads, b + b'398473', consume=True) # XXX - still possible that an internal data member @@ -565,10 +584,10 @@ class TestCode(unittest.TestCase): def test_cryptoutilasn1(self): '''Test DER sequences generated by Crypto.Util.asn1.''' - for s, v in [ ('\x02\x03$\x8a\xf9', 2394873), - ('\x05\x00', None), - ('\x02\x03\x00\x96I', 38473), - ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200), + for s, v in [ (b'\x02\x03$\x8a\xf9', 2394873), + (b'\x05\x00', None), + (b'\x02\x03\x00\x96I', 38473), + (b'\x04\x81\xc8' + b'\x00' * 200, b'\x00' * 200), ]: self.assertEqual(loads(s), v) @@ -595,7 +614,7 @@ class TestCode(unittest.TestCase): sys.float_info.max, sys.float_info.min, float('.15625'), 'weoifjwef', - u'\U0001f4a9', + '\U0001f4a9', [], [ 1,2,3 ], {}, { 5: 10, 'adfkj': 34 }, set(), set((1,2,3)), @@ -609,8 +628,8 @@ class TestCode(unittest.TestCase): o = loads(s) self.assertEqual(i, o) - tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, - 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] } + tobj = { 1: 'dflkj', 5: 'sdlkfj', 'float': 1, + 'largeint': 1<<342, 'list': [ 1, 2, 'str', 'str' ] } out = dumps(tobj) self.assertEqual(tobj, loads(out)) @@ -641,7 +660,7 @@ class TestCode(unittest.TestCase): self.assertEqual(ac.loads(ac.dumps(o)), v) self.assertRaises(TypeError, ac.dumps, Bar()) - v = u'oiejfd' + v = b'oiejfd' o = Baz() o.s = v @@ -652,7 +671,7 @@ class TestCode(unittest.TestCase): self.assertRaises(TypeError, dumps, o) def test_loads(self): - self.assertRaises(ValueError, loads, '\x00\x02\x00') + self.assertRaises(ValueError, loads, b'\x00\x02\x00') def test_nodict(self): '''Verify that ASN1Coder does not support dict.''' diff --git a/requirements.txt b/requirements.txt index 932a895..451b8c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ +coverage mock