Browse Source

make this work on python 3...

main
John-Mark Gurney 4 years ago
parent
commit
b635cf31cf
3 changed files with 113 additions and 89 deletions
  1. +5
    -1
      Makefile
  2. +107
    -88
      pasn1.py
  3. +1
    -0
      requirements.txt

+ 5
- 1
Makefile View File

@@ -1,4 +1,8 @@
MODULES=pasn1.py
VIRTUALENV?=virtualenv-3.7

test:
(echo $(MODULES) | entr sh -c 'python -m coverage run -m unittest $(basename $(MODULES)) && coverage report -m')
(echo $(MODULES) | entr sh -c 'python -m coverage run -m unittest $(basename $(MODULES)) && coverage report --omit=p/\* -m -i')

env:
($(VIRTUALENV) p && . ./p/bin/activate && pip install -r requirements.txt)

+ 107
- 88
pasn1.py View File

@@ -43,18 +43,22 @@ __license__ = '2-clause BSD license'
# either expressed or implied, of the Project.

import datetime
import functools
import math
import os
import sys
import unittest

if sys.version_info.major != 3: # pragma: no cover
raise RuntimeError('this module only supports python 3')

__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]

def _numtostr(n):
def _numtobytes(n):
hs = '%x' % n
if len(hs) & 1 == 1:
hs = '0' + hs
bs = hs.decode('hex')
bs = bytes.fromhex(hs)

return bs

@@ -63,30 +67,30 @@ def _encodelen(l):
represents l per ASN.1 rules.'''

if l < 128:
return chr(l)
return bytes([l])

bs = _numtostr(l)
return chr(len(bs) | 0x80) + bs
bs = _numtobytes(l)
return bytes([len(bs) | 0x80]) + bs

def _decodelen(d, pos=0):
'''Returns the length, and number of bytes required.'''

odp = ord(d[pos])
odp = d[pos]
if odp < 128:
return ord(d[pos]), 1
return d[pos], 1
else:
l = odp & 0x7f
return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
return int(d[pos + 1:pos + 1 + l].hex(), 16), l + 1

class Test_codelen(unittest.TestCase):
_testdata = [
(2, '\x02'),
(127, '\x7f'),
(128, '\x81\x80'),
(255, '\x81\xff'),
(256, '\x82\x01\x00'),
(65536-1, '\x82\xff\xff'),
(65536, '\x83\x01\x00\x00'),
(2, b'\x02'),
(127, b'\x7f'),
(128, b'\x81\x80'),
(255, b'\x81\xff'),
(256, b'\x82\x01\x00'),
(65536-1, b'\x82\xff\xff'),
(65536, b'\x83\x01\x00\x00'),
]

def test_el(self):
@@ -101,7 +105,7 @@ def _splitfloat(f):
m *= 2
e -= 1

return m, e
return math.trunc(m), e

class TestSplitFloat(unittest.TestCase):
def test_sf(self):
@@ -124,7 +128,7 @@ class ASN1Coder(object):
'int': compatible w/ int
'list': __iter__
'set': __iter__
'bytes': __str__
'bytes': __str__ # XXX what is correct here
'null': no method needed
'unicode': encode method returns UTF-8 encoded bytes
'datetime': strftime and microsecond
@@ -137,28 +141,27 @@ class ASN1Coder(object):
float: 'float',
int: 'int',
list: 'list',
long: 'int',
set: 'set',
str: 'bytes',
bytes: 'bytes',
type(None): 'null',
unicode: 'unicode',
str: 'unicode',
#decimal.Decimal: 'float',
datetime.datetime: 'datetime',
#datetime.timedelta: 'timedelta',
}
_tagmap = {
'\x01': 'bool',
'\x02': 'int',
'\x04': 'bytes',
'\x05': 'null',
'\x09': 'float',
'\x0c': 'unicode',
'\x18': 'datetime',
'\x30': 'list',
'\x31': 'set',
b'\x01': 'bool',
b'\x02': 'int',
b'\x04': 'bytes',
b'\x05': 'null',
b'\x09': 'float',
b'\x0c': 'unicode',
b'\x18': 'datetime',
b'\x30': 'list',
b'\x31': 'set',
}

_typetag = dict((v, k) for k, v in _tagmap.iteritems())
_typetag = dict((v, k) for k, v in _tagmap.items())

@staticmethod
def enc_int(obj, **kwargs):
@@ -170,10 +173,10 @@ class ASN1Coder(object):
if obj < 0:
obj += 1 << (l * 8) # twos-complement conversion

v = _numtostr(obj)
v = _numtobytes(obj)
if len(v) != l:
# XXX - is this a problem for signed values?
v = '\x00' + v # add sign octect
v = b'\x00' + v # add sign octect

return _encodelen(l) + v

@@ -182,7 +185,7 @@ class ASN1Coder(object):
if pos == end:
return 0, end

v = int(d[pos:end].encode('hex'), 16)
v = int(bytes.hex(d[pos:end]), 16)
av = 1 << ((end - pos) * 8 - 1) # sign bit
if v > av:
v -= av * 2 # twos-complement conversion
@@ -191,7 +194,7 @@ class ASN1Coder(object):

@staticmethod
def enc_bool(obj, **kwargs):
return '\x01' + ('\xff' if obj else '\x00')
return b'\x01' + (b'\xff' if obj else b'\x00')

def dec_bool(self, d, pos, end):
v = self.dec_int(d, pos, end)[0]
@@ -202,14 +205,14 @@ class ASN1Coder(object):

@staticmethod
def enc_null(obj, **kwargs):
return '\x00'
return b'\x00'

@staticmethod
def dec_null(d, pos, end):
return None, end

def enc_list(self, obj, **kwargs):
r = ''.join(self.dumps(x, **kwargs) for x in obj)
r = b''.join(self.dumps(x, **kwargs) for x in obj)
return _encodelen(len(r)) + r

def dec_list(self, d, pos, end):
@@ -250,17 +253,17 @@ class ASN1Coder(object):
def enc_float(obj, **kwargs):
s = math.copysign(1, obj)
if math.isnan(obj):
return _encodelen(1) + chr(0b01000010)
return _encodelen(1) + bytes([0b01000010])
elif math.isinf(obj):
if s == 1:
return _encodelen(1) + chr(0b01000000)
return _encodelen(1) + bytes([0b01000000])
else:
return _encodelen(1) + chr(0b01000001)
return _encodelen(1) + bytes([0b01000001])
elif obj == 0:
if s == 1:
return _encodelen(0)
else:
return _encodelen(1) + chr(0b01000011)
return _encodelen(1) + bytes([0b01000011])

m, e = _splitfloat(obj)

@@ -279,17 +282,17 @@ class ASN1Coder(object):
e += 256**el # convert negative to twos-complement

v = el - 1
encexp = _numtostr(e)
encexp = _numtobytes(e)

val |= v
r = chr(val) + encexp + _numtostr(m)
r = bytes([val]) + encexp + _numtobytes(m)
return _encodelen(len(r)) + r

def dec_float(self, d, pos, end):
if pos == end:
return float(0), end

v = ord(d[pos])
v = d[pos]
if v == 0b01000000:
return float('inf'), end
elif v == 0b01000001:
@@ -314,7 +317,7 @@ class ASN1Coder(object):

exp = self.dec_int(d, pexp, eexp)[0]

n = float(int(d[eexp:end].encode('hex'), 16))
n = float(int(bytes.hex(d[eexp:end]), 16))
r = n * 2 ** exp
if v & 0b1000000:
r = -r
@@ -340,7 +343,7 @@ class ASN1Coder(object):
pass

if self.coerce is None:
raise TypeError('unhandled object: %s' % `obj`)
raise TypeError('unhandled object: %s' % repr(obj))

tf, obj = self.coerce(obj)

@@ -348,7 +351,7 @@ class ASN1Coder(object):
return self._typetag[tf] + fun(obj, default=default)

def _loads(self, data, pos, end):
tag = data[pos]
tag = data[pos:pos + 1]
l, b = _decodelen(data, pos + 1)
if len(data) < pos + 1 + b + l:
raise ValueError('string not long enough')
@@ -365,12 +368,12 @@ class ASN1Coder(object):
if obj.microsecond:
ts += ('.%06d' % obj.microsecond).rstrip('0')
ts += 'Z'
return _encodelen(len(ts)) + ts
return _encodelen(len(ts)) + ts.encode('utf-8')

def dec_datetime(self, data, pos, end):
ts = data[pos:end]
if ts[-1] != 'Z':
raise ValueError('last character must be Z')
ts = data[pos:end].decode('ascii')
if ts[-1:] != 'Z':
raise ValueError('last character must be Z, was: %s' % repr(ts[-1]))

# Real bug is in strptime, but work around it here.
if ' ' in ts:
@@ -411,14 +414,14 @@ class ASN1DictCoder(ASN1Coder):
_typemap = ASN1Coder._typemap.copy()
_typemap[dict] = 'dict'
_tagmap = ASN1Coder._tagmap.copy()
_tagmap['\xe0'] = 'dict'
_typetag = dict((v, k) for k, v in _tagmap.iteritems())
_tagmap[b'\xe0'] = 'dict'
_typetag = dict((v, k) for k, v in _tagmap.items())

def enc_dict(self, obj, **kwargs):
#it = list(obj.iteritems())
#it.sort()
r = ''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
obj.iteritems())
r = b''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
obj.items())
return _encodelen(len(r)) + r

def dec_dict(self, d, pos, end):
@@ -441,11 +444,29 @@ _coder = ASN1DictCoder()
dumps = _coder.dumps
loads = _coder.loads

def cmp(a, b):
return (a > b) - (a < b)

def universal_cmp(a, b):
# Because Python 3 sucks, make this function that
# orders first based upon type, then on value.
if type(a) == type(b):
if isinstance(a, (tuple, list)):
for a, b in zip(a, b):
if a != b:
return universal_cmp(a, b)

#return cmp(len(a), len(b))
else:
return cmp(a, b)
else:
return id(type(a)) < id(type(b))

def deeptypecmp(obj, o):
#print 'dtc:', `obj`, `o`
#print('dtc:', repr(obj), repr(o))
if type(obj) != type(o):
return False
if type(obj) in (str, unicode):
if type(obj) is str:
return True

if type(obj) in (list, set):
@@ -454,10 +475,8 @@ def deeptypecmp(obj, o):
return False

if type(obj) in (dict,):
itms = obj.items()
itms.sort()
nitms = o.items()
nitms.sort()
itms = sorted(obj.items(), key=functools.cmp_to_key(universal_cmp))
nitms = sorted(o.items(), key=functools.cmp_to_key(universal_cmp))
for (k, v), (nk, nv) in zip(itms, nitms):
if not deeptypecmp(k, nk):
return False
@@ -474,15 +493,15 @@ class Test_deeptypecmp(unittest.TestCase):

def test_false(self):
for i in (([[]], [{}]), ([1], ['str']), ([], set()),
({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
({1: 2, 5: 'sdlkfj'}, {1: 2, 5: b'sdlkfj'}),
({1: 2, 'sdlkfj': 5}, {1: 2, b'sdlkfj': 5}),
):
self.assertFalse(deeptypecmp(*i))
self.assertFalse(deeptypecmp(*i), '%s != %s' % (i[0], i[1]))

def genfailures(obj):
s = dumps(obj)
for i in xrange(len(s)):
for j in (chr(x) for x in xrange(256)):
for i in range(len(s)):
for j in (bytes([x]) for x in range(256)):
ts = s[:i] + j + s[i + 1:]
if ts == s:
continue
@@ -493,24 +512,24 @@ def genfailures(obj):
except (ValueError, KeyError, IndexError, TypeError):
pass
else: # pragma: no cover
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, s[i]))

class TestCode(unittest.TestCase):
def test_primv(self):
self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
self.assertEqual(dumps(5), '020105'.decode('hex'))
self.assertEqual(dumps(128), '02020080'.decode('hex'))
self.assertEqual(dumps(256), '02020100'.decode('hex'))
self.assertEqual(dumps(-257), bytes.fromhex('0202feff'))
self.assertEqual(dumps(-256), bytes.fromhex('0202ff00'))
self.assertEqual(dumps(-255), bytes.fromhex('0202ff01'))
self.assertEqual(dumps(-1), bytes.fromhex('0201ff'))
self.assertEqual(dumps(5), bytes.fromhex('020105'))
self.assertEqual(dumps(128), bytes.fromhex('02020080'))
self.assertEqual(dumps(256), bytes.fromhex('02020100'))

self.assertEqual(dumps(False), '010100'.decode('hex'))
self.assertEqual(dumps(True), '0101ff'.decode('hex'))
self.assertEqual(dumps(False), bytes.fromhex('010100'))
self.assertEqual(dumps(True), bytes.fromhex('0101ff'))

self.assertEqual(dumps(None), '0500'.decode('hex'))
self.assertEqual(dumps(None), bytes.fromhex('0500'))

self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))
self.assertEqual(dumps(.15625), bytes.fromhex('090380fb05'))

def test_fuzzing(self):
# Make sure that when a failure is detected here, that it
@@ -539,7 +558,7 @@ class TestCode(unittest.TestCase):
'181632303136303231373136343034372e3035343433367a', #datetime w/ lower z
'181632303136313220383031303933302e3931353133385a', #datetime w/ space
]:
self.assertRaises(ValueError, loads, v.decode('hex'))
self.assertRaises(ValueError, loads, bytes.fromhex(v))

def test_invalid_floats(self):
import mock
@@ -548,7 +567,7 @@ class TestCode(unittest.TestCase):

def test_consume(self):
b = dumps(5)
self.assertRaises(ValueError, loads, b + '398473',
self.assertRaises(ValueError, loads, b + b'398473',
consume=True)

# XXX - still possible that an internal data member
@@ -565,10 +584,10 @@ class TestCode(unittest.TestCase):
def test_cryptoutilasn1(self):
'''Test DER sequences generated by Crypto.Util.asn1.'''

for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
('\x05\x00', None),
('\x02\x03\x00\x96I', 38473),
('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
for s, v in [ (b'\x02\x03$\x8a\xf9', 2394873),
(b'\x05\x00', None),
(b'\x02\x03\x00\x96I', 38473),
(b'\x04\x81\xc8' + b'\x00' * 200, b'\x00' * 200),
]:
self.assertEqual(loads(s), v)

@@ -595,7 +614,7 @@ class TestCode(unittest.TestCase):
sys.float_info.max, sys.float_info.min,
float('.15625'),
'weoifjwef',
u'\U0001f4a9',
'\U0001f4a9',
[], [ 1,2,3 ],
{}, { 5: 10, 'adfkj': 34 },
set(), set((1,2,3)),
@@ -609,8 +628,8 @@ class TestCode(unittest.TestCase):
o = loads(s)
self.assertEqual(i, o)

tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1,
'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
tobj = { 1: 'dflkj', 5: 'sdlkfj', 'float': 1,
'largeint': 1<<342, 'list': [ 1, 2, 'str', 'str' ] }

out = dumps(tobj)
self.assertEqual(tobj, loads(out))
@@ -641,7 +660,7 @@ class TestCode(unittest.TestCase):
self.assertEqual(ac.loads(ac.dumps(o)), v)
self.assertRaises(TypeError, ac.dumps, Bar())

v = u'oiejfd'
v = b'oiejfd'
o = Baz()
o.s = v

@@ -652,7 +671,7 @@ class TestCode(unittest.TestCase):
self.assertRaises(TypeError, dumps, o)

def test_loads(self):
self.assertRaises(ValueError, loads, '\x00\x02\x00')
self.assertRaises(ValueError, loads, b'\x00\x02\x00')

def test_nodict(self):
'''Verify that ASN1Coder does not support dict.'''


+ 1
- 0
requirements.txt View File

@@ -1 +1,2 @@
coverage
mock

Loading…
Cancel
Save