Browse Source

make this work on python 3...

main
John-Mark Gurney 4 years ago
parent
commit
b635cf31cf
3 changed files with 113 additions and 89 deletions
  1. +5
    -1
      Makefile
  2. +107
    -88
      pasn1.py
  3. +1
    -0
      requirements.txt

+ 5
- 1
Makefile View File

@@ -1,4 +1,8 @@
MODULES=pasn1.py MODULES=pasn1.py
VIRTUALENV?=virtualenv-3.7


test: test:
(echo $(MODULES) | entr sh -c 'python -m coverage run -m unittest $(basename $(MODULES)) && coverage report -m')
(echo $(MODULES) | entr sh -c 'python -m coverage run -m unittest $(basename $(MODULES)) && coverage report --omit=p/\* -m -i')

env:
($(VIRTUALENV) p && . ./p/bin/activate && pip install -r requirements.txt)

+ 107
- 88
pasn1.py View File

@@ -43,18 +43,22 @@ __license__ = '2-clause BSD license'
# either expressed or implied, of the Project. # either expressed or implied, of the Project.


import datetime import datetime
import functools
import math import math
import os import os
import sys import sys
import unittest import unittest


if sys.version_info.major != 3: # pragma: no cover
raise RuntimeError('this module only supports python 3')

__all__ = [ 'dumps', 'loads', 'ASN1Coder' ] __all__ = [ 'dumps', 'loads', 'ASN1Coder' ]


def _numtostr(n):
def _numtobytes(n):
hs = '%x' % n hs = '%x' % n
if len(hs) & 1 == 1: if len(hs) & 1 == 1:
hs = '0' + hs hs = '0' + hs
bs = hs.decode('hex')
bs = bytes.fromhex(hs)


return bs return bs


@@ -63,30 +67,30 @@ def _encodelen(l):
represents l per ASN.1 rules.''' represents l per ASN.1 rules.'''


if l < 128: if l < 128:
return chr(l)
return bytes([l])


bs = _numtostr(l)
return chr(len(bs) | 0x80) + bs
bs = _numtobytes(l)
return bytes([len(bs) | 0x80]) + bs


def _decodelen(d, pos=0): def _decodelen(d, pos=0):
'''Returns the length, and number of bytes required.''' '''Returns the length, and number of bytes required.'''


odp = ord(d[pos])
odp = d[pos]
if odp < 128: if odp < 128:
return ord(d[pos]), 1
return d[pos], 1
else: else:
l = odp & 0x7f l = odp & 0x7f
return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
return int(d[pos + 1:pos + 1 + l].hex(), 16), l + 1


class Test_codelen(unittest.TestCase): class Test_codelen(unittest.TestCase):
_testdata = [ _testdata = [
(2, '\x02'),
(127, '\x7f'),
(128, '\x81\x80'),
(255, '\x81\xff'),
(256, '\x82\x01\x00'),
(65536-1, '\x82\xff\xff'),
(65536, '\x83\x01\x00\x00'),
(2, b'\x02'),
(127, b'\x7f'),
(128, b'\x81\x80'),
(255, b'\x81\xff'),
(256, b'\x82\x01\x00'),
(65536-1, b'\x82\xff\xff'),
(65536, b'\x83\x01\x00\x00'),
] ]


def test_el(self): def test_el(self):
@@ -101,7 +105,7 @@ def _splitfloat(f):
m *= 2 m *= 2
e -= 1 e -= 1


return m, e
return math.trunc(m), e


class TestSplitFloat(unittest.TestCase): class TestSplitFloat(unittest.TestCase):
def test_sf(self): def test_sf(self):
@@ -124,7 +128,7 @@ class ASN1Coder(object):
'int': compatible w/ int 'int': compatible w/ int
'list': __iter__ 'list': __iter__
'set': __iter__ 'set': __iter__
'bytes': __str__
'bytes': __str__ # XXX what is correct here
'null': no method needed 'null': no method needed
'unicode': encode method returns UTF-8 encoded bytes 'unicode': encode method returns UTF-8 encoded bytes
'datetime': strftime and microsecond 'datetime': strftime and microsecond
@@ -137,28 +141,27 @@ class ASN1Coder(object):
float: 'float', float: 'float',
int: 'int', int: 'int',
list: 'list', list: 'list',
long: 'int',
set: 'set', set: 'set',
str: 'bytes',
bytes: 'bytes',
type(None): 'null', type(None): 'null',
unicode: 'unicode',
str: 'unicode',
#decimal.Decimal: 'float', #decimal.Decimal: 'float',
datetime.datetime: 'datetime', datetime.datetime: 'datetime',
#datetime.timedelta: 'timedelta', #datetime.timedelta: 'timedelta',
} }
_tagmap = { _tagmap = {
'\x01': 'bool',
'\x02': 'int',
'\x04': 'bytes',
'\x05': 'null',
'\x09': 'float',
'\x0c': 'unicode',
'\x18': 'datetime',
'\x30': 'list',
'\x31': 'set',
b'\x01': 'bool',
b'\x02': 'int',
b'\x04': 'bytes',
b'\x05': 'null',
b'\x09': 'float',
b'\x0c': 'unicode',
b'\x18': 'datetime',
b'\x30': 'list',
b'\x31': 'set',
} }


_typetag = dict((v, k) for k, v in _tagmap.iteritems())
_typetag = dict((v, k) for k, v in _tagmap.items())


@staticmethod @staticmethod
def enc_int(obj, **kwargs): def enc_int(obj, **kwargs):
@@ -170,10 +173,10 @@ class ASN1Coder(object):
if obj < 0: if obj < 0:
obj += 1 << (l * 8) # twos-complement conversion obj += 1 << (l * 8) # twos-complement conversion


v = _numtostr(obj)
v = _numtobytes(obj)
if len(v) != l: if len(v) != l:
# XXX - is this a problem for signed values? # XXX - is this a problem for signed values?
v = '\x00' + v # add sign octect
v = b'\x00' + v # add sign octect


return _encodelen(l) + v return _encodelen(l) + v


@@ -182,7 +185,7 @@ class ASN1Coder(object):
if pos == end: if pos == end:
return 0, end return 0, end


v = int(d[pos:end].encode('hex'), 16)
v = int(bytes.hex(d[pos:end]), 16)
av = 1 << ((end - pos) * 8 - 1) # sign bit av = 1 << ((end - pos) * 8 - 1) # sign bit
if v > av: if v > av:
v -= av * 2 # twos-complement conversion v -= av * 2 # twos-complement conversion
@@ -191,7 +194,7 @@ class ASN1Coder(object):


@staticmethod @staticmethod
def enc_bool(obj, **kwargs): def enc_bool(obj, **kwargs):
return '\x01' + ('\xff' if obj else '\x00')
return b'\x01' + (b'\xff' if obj else b'\x00')


def dec_bool(self, d, pos, end): def dec_bool(self, d, pos, end):
v = self.dec_int(d, pos, end)[0] v = self.dec_int(d, pos, end)[0]
@@ -202,14 +205,14 @@ class ASN1Coder(object):


@staticmethod @staticmethod
def enc_null(obj, **kwargs): def enc_null(obj, **kwargs):
return '\x00'
return b'\x00'


@staticmethod @staticmethod
def dec_null(d, pos, end): def dec_null(d, pos, end):
return None, end return None, end


def enc_list(self, obj, **kwargs): def enc_list(self, obj, **kwargs):
r = ''.join(self.dumps(x, **kwargs) for x in obj)
r = b''.join(self.dumps(x, **kwargs) for x in obj)
return _encodelen(len(r)) + r return _encodelen(len(r)) + r


def dec_list(self, d, pos, end): def dec_list(self, d, pos, end):
@@ -250,17 +253,17 @@ class ASN1Coder(object):
def enc_float(obj, **kwargs): def enc_float(obj, **kwargs):
s = math.copysign(1, obj) s = math.copysign(1, obj)
if math.isnan(obj): if math.isnan(obj):
return _encodelen(1) + chr(0b01000010)
return _encodelen(1) + bytes([0b01000010])
elif math.isinf(obj): elif math.isinf(obj):
if s == 1: if s == 1:
return _encodelen(1) + chr(0b01000000)
return _encodelen(1) + bytes([0b01000000])
else: else:
return _encodelen(1) + chr(0b01000001)
return _encodelen(1) + bytes([0b01000001])
elif obj == 0: elif obj == 0:
if s == 1: if s == 1:
return _encodelen(0) return _encodelen(0)
else: else:
return _encodelen(1) + chr(0b01000011)
return _encodelen(1) + bytes([0b01000011])


m, e = _splitfloat(obj) m, e = _splitfloat(obj)


@@ -279,17 +282,17 @@ class ASN1Coder(object):
e += 256**el # convert negative to twos-complement e += 256**el # convert negative to twos-complement


v = el - 1 v = el - 1
encexp = _numtostr(e)
encexp = _numtobytes(e)


val |= v val |= v
r = chr(val) + encexp + _numtostr(m)
r = bytes([val]) + encexp + _numtobytes(m)
return _encodelen(len(r)) + r return _encodelen(len(r)) + r


def dec_float(self, d, pos, end): def dec_float(self, d, pos, end):
if pos == end: if pos == end:
return float(0), end return float(0), end


v = ord(d[pos])
v = d[pos]
if v == 0b01000000: if v == 0b01000000:
return float('inf'), end return float('inf'), end
elif v == 0b01000001: elif v == 0b01000001:
@@ -314,7 +317,7 @@ class ASN1Coder(object):


exp = self.dec_int(d, pexp, eexp)[0] exp = self.dec_int(d, pexp, eexp)[0]


n = float(int(d[eexp:end].encode('hex'), 16))
n = float(int(bytes.hex(d[eexp:end]), 16))
r = n * 2 ** exp r = n * 2 ** exp
if v & 0b1000000: if v & 0b1000000:
r = -r r = -r
@@ -340,7 +343,7 @@ class ASN1Coder(object):
pass pass


if self.coerce is None: if self.coerce is None:
raise TypeError('unhandled object: %s' % `obj`)
raise TypeError('unhandled object: %s' % repr(obj))


tf, obj = self.coerce(obj) tf, obj = self.coerce(obj)


@@ -348,7 +351,7 @@ class ASN1Coder(object):
return self._typetag[tf] + fun(obj, default=default) return self._typetag[tf] + fun(obj, default=default)


def _loads(self, data, pos, end): def _loads(self, data, pos, end):
tag = data[pos]
tag = data[pos:pos + 1]
l, b = _decodelen(data, pos + 1) l, b = _decodelen(data, pos + 1)
if len(data) < pos + 1 + b + l: if len(data) < pos + 1 + b + l:
raise ValueError('string not long enough') raise ValueError('string not long enough')
@@ -365,12 +368,12 @@ class ASN1Coder(object):
if obj.microsecond: if obj.microsecond:
ts += ('.%06d' % obj.microsecond).rstrip('0') ts += ('.%06d' % obj.microsecond).rstrip('0')
ts += 'Z' ts += 'Z'
return _encodelen(len(ts)) + ts
return _encodelen(len(ts)) + ts.encode('utf-8')


def dec_datetime(self, data, pos, end): def dec_datetime(self, data, pos, end):
ts = data[pos:end]
if ts[-1] != 'Z':
raise ValueError('last character must be Z')
ts = data[pos:end].decode('ascii')
if ts[-1:] != 'Z':
raise ValueError('last character must be Z, was: %s' % repr(ts[-1]))


# Real bug is in strptime, but work around it here. # Real bug is in strptime, but work around it here.
if ' ' in ts: if ' ' in ts:
@@ -411,14 +414,14 @@ class ASN1DictCoder(ASN1Coder):
_typemap = ASN1Coder._typemap.copy() _typemap = ASN1Coder._typemap.copy()
_typemap[dict] = 'dict' _typemap[dict] = 'dict'
_tagmap = ASN1Coder._tagmap.copy() _tagmap = ASN1Coder._tagmap.copy()
_tagmap['\xe0'] = 'dict'
_typetag = dict((v, k) for k, v in _tagmap.iteritems())
_tagmap[b'\xe0'] = 'dict'
_typetag = dict((v, k) for k, v in _tagmap.items())


def enc_dict(self, obj, **kwargs): def enc_dict(self, obj, **kwargs):
#it = list(obj.iteritems()) #it = list(obj.iteritems())
#it.sort() #it.sort()
r = ''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
obj.iteritems())
r = b''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
obj.items())
return _encodelen(len(r)) + r return _encodelen(len(r)) + r


def dec_dict(self, d, pos, end): def dec_dict(self, d, pos, end):
@@ -441,11 +444,29 @@ _coder = ASN1DictCoder()
dumps = _coder.dumps dumps = _coder.dumps
loads = _coder.loads loads = _coder.loads


def cmp(a, b):
return (a > b) - (a < b)

def universal_cmp(a, b):
# Because Python 3 sucks, make this function that
# orders first based upon type, then on value.
if type(a) == type(b):
if isinstance(a, (tuple, list)):
for a, b in zip(a, b):
if a != b:
return universal_cmp(a, b)

#return cmp(len(a), len(b))
else:
return cmp(a, b)
else:
return id(type(a)) < id(type(b))

def deeptypecmp(obj, o): def deeptypecmp(obj, o):
#print 'dtc:', `obj`, `o`
#print('dtc:', repr(obj), repr(o))
if type(obj) != type(o): if type(obj) != type(o):
return False return False
if type(obj) in (str, unicode):
if type(obj) is str:
return True return True


if type(obj) in (list, set): if type(obj) in (list, set):
@@ -454,10 +475,8 @@ def deeptypecmp(obj, o):
return False return False


if type(obj) in (dict,): if type(obj) in (dict,):
itms = obj.items()
itms.sort()
nitms = o.items()
nitms.sort()
itms = sorted(obj.items(), key=functools.cmp_to_key(universal_cmp))
nitms = sorted(o.items(), key=functools.cmp_to_key(universal_cmp))
for (k, v), (nk, nv) in zip(itms, nitms): for (k, v), (nk, nv) in zip(itms, nitms):
if not deeptypecmp(k, nk): if not deeptypecmp(k, nk):
return False return False
@@ -474,15 +493,15 @@ class Test_deeptypecmp(unittest.TestCase):


def test_false(self): def test_false(self):
for i in (([[]], [{}]), ([1], ['str']), ([], set()), for i in (([[]], [{}]), ([1], ['str']), ([], set()),
({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
({1: 2, 5: 'sdlkfj'}, {1: 2, 5: b'sdlkfj'}),
({1: 2, 'sdlkfj': 5}, {1: 2, b'sdlkfj': 5}),
): ):
self.assertFalse(deeptypecmp(*i))
self.assertFalse(deeptypecmp(*i), '%s != %s' % (i[0], i[1]))


def genfailures(obj): def genfailures(obj):
s = dumps(obj) s = dumps(obj)
for i in xrange(len(s)):
for j in (chr(x) for x in xrange(256)):
for i in range(len(s)):
for j in (bytes([x]) for x in range(256)):
ts = s[:i] + j + s[i + 1:] ts = s[:i] + j + s[i + 1:]
if ts == s: if ts == s:
continue continue
@@ -493,24 +512,24 @@ def genfailures(obj):
except (ValueError, KeyError, IndexError, TypeError): except (ValueError, KeyError, IndexError, TypeError):
pass pass
else: # pragma: no cover else: # pragma: no cover
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, s[i]))


class TestCode(unittest.TestCase): class TestCode(unittest.TestCase):
def test_primv(self): def test_primv(self):
self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
self.assertEqual(dumps(5), '020105'.decode('hex'))
self.assertEqual(dumps(128), '02020080'.decode('hex'))
self.assertEqual(dumps(256), '02020100'.decode('hex'))
self.assertEqual(dumps(-257), bytes.fromhex('0202feff'))
self.assertEqual(dumps(-256), bytes.fromhex('0202ff00'))
self.assertEqual(dumps(-255), bytes.fromhex('0202ff01'))
self.assertEqual(dumps(-1), bytes.fromhex('0201ff'))
self.assertEqual(dumps(5), bytes.fromhex('020105'))
self.assertEqual(dumps(128), bytes.fromhex('02020080'))
self.assertEqual(dumps(256), bytes.fromhex('02020100'))


self.assertEqual(dumps(False), '010100'.decode('hex'))
self.assertEqual(dumps(True), '0101ff'.decode('hex'))
self.assertEqual(dumps(False), bytes.fromhex('010100'))
self.assertEqual(dumps(True), bytes.fromhex('0101ff'))


self.assertEqual(dumps(None), '0500'.decode('hex'))
self.assertEqual(dumps(None), bytes.fromhex('0500'))


self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))
self.assertEqual(dumps(.15625), bytes.fromhex('090380fb05'))


def test_fuzzing(self): def test_fuzzing(self):
# Make sure that when a failure is detected here, that it # Make sure that when a failure is detected here, that it
@@ -539,7 +558,7 @@ class TestCode(unittest.TestCase):
'181632303136303231373136343034372e3035343433367a', #datetime w/ lower z '181632303136303231373136343034372e3035343433367a', #datetime w/ lower z
'181632303136313220383031303933302e3931353133385a', #datetime w/ space '181632303136313220383031303933302e3931353133385a', #datetime w/ space
]: ]:
self.assertRaises(ValueError, loads, v.decode('hex'))
self.assertRaises(ValueError, loads, bytes.fromhex(v))


def test_invalid_floats(self): def test_invalid_floats(self):
import mock import mock
@@ -548,7 +567,7 @@ class TestCode(unittest.TestCase):


def test_consume(self): def test_consume(self):
b = dumps(5) b = dumps(5)
self.assertRaises(ValueError, loads, b + '398473',
self.assertRaises(ValueError, loads, b + b'398473',
consume=True) consume=True)


# XXX - still possible that an internal data member # XXX - still possible that an internal data member
@@ -565,10 +584,10 @@ class TestCode(unittest.TestCase):
def test_cryptoutilasn1(self): def test_cryptoutilasn1(self):
'''Test DER sequences generated by Crypto.Util.asn1.''' '''Test DER sequences generated by Crypto.Util.asn1.'''


for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
('\x05\x00', None),
('\x02\x03\x00\x96I', 38473),
('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
for s, v in [ (b'\x02\x03$\x8a\xf9', 2394873),
(b'\x05\x00', None),
(b'\x02\x03\x00\x96I', 38473),
(b'\x04\x81\xc8' + b'\x00' * 200, b'\x00' * 200),
]: ]:
self.assertEqual(loads(s), v) self.assertEqual(loads(s), v)


@@ -595,7 +614,7 @@ class TestCode(unittest.TestCase):
sys.float_info.max, sys.float_info.min, sys.float_info.max, sys.float_info.min,
float('.15625'), float('.15625'),
'weoifjwef', 'weoifjwef',
u'\U0001f4a9',
'\U0001f4a9',
[], [ 1,2,3 ], [], [ 1,2,3 ],
{}, { 5: 10, 'adfkj': 34 }, {}, { 5: 10, 'adfkj': 34 },
set(), set((1,2,3)), set(), set((1,2,3)),
@@ -609,8 +628,8 @@ class TestCode(unittest.TestCase):
o = loads(s) o = loads(s)
self.assertEqual(i, o) self.assertEqual(i, o)


tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1,
'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
tobj = { 1: 'dflkj', 5: 'sdlkfj', 'float': 1,
'largeint': 1<<342, 'list': [ 1, 2, 'str', 'str' ] }


out = dumps(tobj) out = dumps(tobj)
self.assertEqual(tobj, loads(out)) self.assertEqual(tobj, loads(out))
@@ -641,7 +660,7 @@ class TestCode(unittest.TestCase):
self.assertEqual(ac.loads(ac.dumps(o)), v) self.assertEqual(ac.loads(ac.dumps(o)), v)
self.assertRaises(TypeError, ac.dumps, Bar()) self.assertRaises(TypeError, ac.dumps, Bar())


v = u'oiejfd'
v = b'oiejfd'
o = Baz() o = Baz()
o.s = v o.s = v


@@ -652,7 +671,7 @@ class TestCode(unittest.TestCase):
self.assertRaises(TypeError, dumps, o) self.assertRaises(TypeError, dumps, o)


def test_loads(self): def test_loads(self):
self.assertRaises(ValueError, loads, '\x00\x02\x00')
self.assertRaises(ValueError, loads, b'\x00\x02\x00')


def test_nodict(self): def test_nodict(self):
'''Verify that ASN1Coder does not support dict.''' '''Verify that ASN1Coder does not support dict.'''


+ 1
- 0
requirements.txt View File

@@ -1 +1,2 @@
coverage
mock mock

Loading…
Cancel
Save