|
|
@@ -43,18 +43,22 @@ __license__ = '2-clause BSD license' |
|
|
|
# either expressed or implied, of the Project. |
|
|
|
|
|
|
|
import datetime |
|
|
|
import functools |
|
|
|
import math |
|
|
|
import os |
|
|
|
import sys |
|
|
|
import unittest |
|
|
|
|
|
|
|
if sys.version_info.major != 3: # pragma: no cover |
|
|
|
raise RuntimeError('this module only supports python 3') |
|
|
|
|
|
|
|
__all__ = [ 'dumps', 'loads', 'ASN1Coder' ] |
|
|
|
|
|
|
|
def _numtostr(n): |
|
|
|
def _numtobytes(n): |
|
|
|
hs = '%x' % n |
|
|
|
if len(hs) & 1 == 1: |
|
|
|
hs = '0' + hs |
|
|
|
bs = hs.decode('hex') |
|
|
|
bs = bytes.fromhex(hs) |
|
|
|
|
|
|
|
return bs |
|
|
|
|
|
|
@@ -63,30 +67,30 @@ def _encodelen(l): |
|
|
|
represents l per ASN.1 rules.''' |
|
|
|
|
|
|
|
if l < 128: |
|
|
|
return chr(l) |
|
|
|
return bytes([l]) |
|
|
|
|
|
|
|
bs = _numtostr(l) |
|
|
|
return chr(len(bs) | 0x80) + bs |
|
|
|
bs = _numtobytes(l) |
|
|
|
return bytes([len(bs) | 0x80]) + bs |
|
|
|
|
|
|
|
def _decodelen(d, pos=0): |
|
|
|
'''Returns the length, and number of bytes required.''' |
|
|
|
|
|
|
|
odp = ord(d[pos]) |
|
|
|
odp = d[pos] |
|
|
|
if odp < 128: |
|
|
|
return ord(d[pos]), 1 |
|
|
|
return d[pos], 1 |
|
|
|
else: |
|
|
|
l = odp & 0x7f |
|
|
|
return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1 |
|
|
|
return int(d[pos + 1:pos + 1 + l].hex(), 16), l + 1 |
|
|
|
|
|
|
|
class Test_codelen(unittest.TestCase): |
|
|
|
_testdata = [ |
|
|
|
(2, '\x02'), |
|
|
|
(127, '\x7f'), |
|
|
|
(128, '\x81\x80'), |
|
|
|
(255, '\x81\xff'), |
|
|
|
(256, '\x82\x01\x00'), |
|
|
|
(65536-1, '\x82\xff\xff'), |
|
|
|
(65536, '\x83\x01\x00\x00'), |
|
|
|
(2, b'\x02'), |
|
|
|
(127, b'\x7f'), |
|
|
|
(128, b'\x81\x80'), |
|
|
|
(255, b'\x81\xff'), |
|
|
|
(256, b'\x82\x01\x00'), |
|
|
|
(65536-1, b'\x82\xff\xff'), |
|
|
|
(65536, b'\x83\x01\x00\x00'), |
|
|
|
] |
|
|
|
|
|
|
|
def test_el(self): |
|
|
@@ -101,7 +105,7 @@ def _splitfloat(f): |
|
|
|
m *= 2 |
|
|
|
e -= 1 |
|
|
|
|
|
|
|
return m, e |
|
|
|
return math.trunc(m), e |
|
|
|
|
|
|
|
class TestSplitFloat(unittest.TestCase): |
|
|
|
def test_sf(self): |
|
|
@@ -124,7 +128,7 @@ class ASN1Coder(object): |
|
|
|
'int': compatible w/ int |
|
|
|
'list': __iter__ |
|
|
|
'set': __iter__ |
|
|
|
'bytes': __str__ |
|
|
|
'bytes': __str__ # XXX what is correct here |
|
|
|
'null': no method needed |
|
|
|
'unicode': encode method returns UTF-8 encoded bytes |
|
|
|
'datetime': strftime and microsecond |
|
|
@@ -137,28 +141,27 @@ class ASN1Coder(object): |
|
|
|
float: 'float', |
|
|
|
int: 'int', |
|
|
|
list: 'list', |
|
|
|
long: 'int', |
|
|
|
set: 'set', |
|
|
|
str: 'bytes', |
|
|
|
bytes: 'bytes', |
|
|
|
type(None): 'null', |
|
|
|
unicode: 'unicode', |
|
|
|
str: 'unicode', |
|
|
|
#decimal.Decimal: 'float', |
|
|
|
datetime.datetime: 'datetime', |
|
|
|
#datetime.timedelta: 'timedelta', |
|
|
|
} |
|
|
|
_tagmap = { |
|
|
|
'\x01': 'bool', |
|
|
|
'\x02': 'int', |
|
|
|
'\x04': 'bytes', |
|
|
|
'\x05': 'null', |
|
|
|
'\x09': 'float', |
|
|
|
'\x0c': 'unicode', |
|
|
|
'\x18': 'datetime', |
|
|
|
'\x30': 'list', |
|
|
|
'\x31': 'set', |
|
|
|
b'\x01': 'bool', |
|
|
|
b'\x02': 'int', |
|
|
|
b'\x04': 'bytes', |
|
|
|
b'\x05': 'null', |
|
|
|
b'\x09': 'float', |
|
|
|
b'\x0c': 'unicode', |
|
|
|
b'\x18': 'datetime', |
|
|
|
b'\x30': 'list', |
|
|
|
b'\x31': 'set', |
|
|
|
} |
|
|
|
|
|
|
|
_typetag = dict((v, k) for k, v in _tagmap.iteritems()) |
|
|
|
_typetag = dict((v, k) for k, v in _tagmap.items()) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def enc_int(obj, **kwargs): |
|
|
@@ -170,10 +173,10 @@ class ASN1Coder(object): |
|
|
|
if obj < 0: |
|
|
|
obj += 1 << (l * 8) # twos-complement conversion |
|
|
|
|
|
|
|
v = _numtostr(obj) |
|
|
|
v = _numtobytes(obj) |
|
|
|
if len(v) != l: |
|
|
|
# XXX - is this a problem for signed values? |
|
|
|
v = '\x00' + v # add sign octect |
|
|
|
v = b'\x00' + v # add sign octect |
|
|
|
|
|
|
|
return _encodelen(l) + v |
|
|
|
|
|
|
@@ -182,7 +185,7 @@ class ASN1Coder(object): |
|
|
|
if pos == end: |
|
|
|
return 0, end |
|
|
|
|
|
|
|
v = int(d[pos:end].encode('hex'), 16) |
|
|
|
v = int(bytes.hex(d[pos:end]), 16) |
|
|
|
av = 1 << ((end - pos) * 8 - 1) # sign bit |
|
|
|
if v > av: |
|
|
|
v -= av * 2 # twos-complement conversion |
|
|
@@ -191,7 +194,7 @@ class ASN1Coder(object): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def enc_bool(obj, **kwargs): |
|
|
|
return '\x01' + ('\xff' if obj else '\x00') |
|
|
|
return b'\x01' + (b'\xff' if obj else b'\x00') |
|
|
|
|
|
|
|
def dec_bool(self, d, pos, end): |
|
|
|
v = self.dec_int(d, pos, end)[0] |
|
|
@@ -202,14 +205,14 @@ class ASN1Coder(object): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def enc_null(obj, **kwargs): |
|
|
|
return '\x00' |
|
|
|
return b'\x00' |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def dec_null(d, pos, end): |
|
|
|
return None, end |
|
|
|
|
|
|
|
def enc_list(self, obj, **kwargs): |
|
|
|
r = ''.join(self.dumps(x, **kwargs) for x in obj) |
|
|
|
r = b''.join(self.dumps(x, **kwargs) for x in obj) |
|
|
|
return _encodelen(len(r)) + r |
|
|
|
|
|
|
|
def dec_list(self, d, pos, end): |
|
|
@@ -250,17 +253,17 @@ class ASN1Coder(object): |
|
|
|
def enc_float(obj, **kwargs): |
|
|
|
s = math.copysign(1, obj) |
|
|
|
if math.isnan(obj): |
|
|
|
return _encodelen(1) + chr(0b01000010) |
|
|
|
return _encodelen(1) + bytes([0b01000010]) |
|
|
|
elif math.isinf(obj): |
|
|
|
if s == 1: |
|
|
|
return _encodelen(1) + chr(0b01000000) |
|
|
|
return _encodelen(1) + bytes([0b01000000]) |
|
|
|
else: |
|
|
|
return _encodelen(1) + chr(0b01000001) |
|
|
|
return _encodelen(1) + bytes([0b01000001]) |
|
|
|
elif obj == 0: |
|
|
|
if s == 1: |
|
|
|
return _encodelen(0) |
|
|
|
else: |
|
|
|
return _encodelen(1) + chr(0b01000011) |
|
|
|
return _encodelen(1) + bytes([0b01000011]) |
|
|
|
|
|
|
|
m, e = _splitfloat(obj) |
|
|
|
|
|
|
@@ -279,17 +282,17 @@ class ASN1Coder(object): |
|
|
|
e += 256**el # convert negative to twos-complement |
|
|
|
|
|
|
|
v = el - 1 |
|
|
|
encexp = _numtostr(e) |
|
|
|
encexp = _numtobytes(e) |
|
|
|
|
|
|
|
val |= v |
|
|
|
r = chr(val) + encexp + _numtostr(m) |
|
|
|
r = bytes([val]) + encexp + _numtobytes(m) |
|
|
|
return _encodelen(len(r)) + r |
|
|
|
|
|
|
|
def dec_float(self, d, pos, end): |
|
|
|
if pos == end: |
|
|
|
return float(0), end |
|
|
|
|
|
|
|
v = ord(d[pos]) |
|
|
|
v = d[pos] |
|
|
|
if v == 0b01000000: |
|
|
|
return float('inf'), end |
|
|
|
elif v == 0b01000001: |
|
|
@@ -314,7 +317,7 @@ class ASN1Coder(object): |
|
|
|
|
|
|
|
exp = self.dec_int(d, pexp, eexp)[0] |
|
|
|
|
|
|
|
n = float(int(d[eexp:end].encode('hex'), 16)) |
|
|
|
n = float(int(bytes.hex(d[eexp:end]), 16)) |
|
|
|
r = n * 2 ** exp |
|
|
|
if v & 0b1000000: |
|
|
|
r = -r |
|
|
@@ -340,7 +343,7 @@ class ASN1Coder(object): |
|
|
|
pass |
|
|
|
|
|
|
|
if self.coerce is None: |
|
|
|
raise TypeError('unhandled object: %s' % `obj`) |
|
|
|
raise TypeError('unhandled object: %s' % repr(obj)) |
|
|
|
|
|
|
|
tf, obj = self.coerce(obj) |
|
|
|
|
|
|
@@ -348,7 +351,7 @@ class ASN1Coder(object): |
|
|
|
return self._typetag[tf] + fun(obj, default=default) |
|
|
|
|
|
|
|
def _loads(self, data, pos, end): |
|
|
|
tag = data[pos] |
|
|
|
tag = data[pos:pos + 1] |
|
|
|
l, b = _decodelen(data, pos + 1) |
|
|
|
if len(data) < pos + 1 + b + l: |
|
|
|
raise ValueError('string not long enough') |
|
|
@@ -365,12 +368,12 @@ class ASN1Coder(object): |
|
|
|
if obj.microsecond: |
|
|
|
ts += ('.%06d' % obj.microsecond).rstrip('0') |
|
|
|
ts += 'Z' |
|
|
|
return _encodelen(len(ts)) + ts |
|
|
|
return _encodelen(len(ts)) + ts.encode('utf-8') |
|
|
|
|
|
|
|
def dec_datetime(self, data, pos, end): |
|
|
|
ts = data[pos:end] |
|
|
|
if ts[-1] != 'Z': |
|
|
|
raise ValueError('last character must be Z') |
|
|
|
ts = data[pos:end].decode('ascii') |
|
|
|
if ts[-1:] != 'Z': |
|
|
|
raise ValueError('last character must be Z, was: %s' % repr(ts[-1])) |
|
|
|
|
|
|
|
# Real bug is in strptime, but work around it here. |
|
|
|
if ' ' in ts: |
|
|
@@ -411,14 +414,14 @@ class ASN1DictCoder(ASN1Coder): |
|
|
|
_typemap = ASN1Coder._typemap.copy() |
|
|
|
_typemap[dict] = 'dict' |
|
|
|
_tagmap = ASN1Coder._tagmap.copy() |
|
|
|
_tagmap['\xe0'] = 'dict' |
|
|
|
_typetag = dict((v, k) for k, v in _tagmap.iteritems()) |
|
|
|
_tagmap[b'\xe0'] = 'dict' |
|
|
|
_typetag = dict((v, k) for k, v in _tagmap.items()) |
|
|
|
|
|
|
|
def enc_dict(self, obj, **kwargs): |
|
|
|
#it = list(obj.iteritems()) |
|
|
|
#it.sort() |
|
|
|
r = ''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in |
|
|
|
obj.iteritems()) |
|
|
|
r = b''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in |
|
|
|
obj.items()) |
|
|
|
return _encodelen(len(r)) + r |
|
|
|
|
|
|
|
def dec_dict(self, d, pos, end): |
|
|
@@ -441,11 +444,29 @@ _coder = ASN1DictCoder() |
|
|
|
dumps = _coder.dumps |
|
|
|
loads = _coder.loads |
|
|
|
|
|
|
|
def cmp(a, b): |
|
|
|
return (a > b) - (a < b) |
|
|
|
|
|
|
|
def universal_cmp(a, b): |
|
|
|
# Because Python 3 sucks, make this function that |
|
|
|
# orders first based upon type, then on value. |
|
|
|
if type(a) == type(b): |
|
|
|
if isinstance(a, (tuple, list)): |
|
|
|
for a, b in zip(a, b): |
|
|
|
if a != b: |
|
|
|
return universal_cmp(a, b) |
|
|
|
|
|
|
|
#return cmp(len(a), len(b)) |
|
|
|
else: |
|
|
|
return cmp(a, b) |
|
|
|
else: |
|
|
|
return id(type(a)) < id(type(b)) |
|
|
|
|
|
|
|
def deeptypecmp(obj, o): |
|
|
|
#print 'dtc:', `obj`, `o` |
|
|
|
#print('dtc:', repr(obj), repr(o)) |
|
|
|
if type(obj) != type(o): |
|
|
|
return False |
|
|
|
if type(obj) in (str, unicode): |
|
|
|
if type(obj) is str: |
|
|
|
return True |
|
|
|
|
|
|
|
if type(obj) in (list, set): |
|
|
@@ -454,10 +475,8 @@ def deeptypecmp(obj, o): |
|
|
|
return False |
|
|
|
|
|
|
|
if type(obj) in (dict,): |
|
|
|
itms = obj.items() |
|
|
|
itms.sort() |
|
|
|
nitms = o.items() |
|
|
|
nitms.sort() |
|
|
|
itms = sorted(obj.items(), key=functools.cmp_to_key(universal_cmp)) |
|
|
|
nitms = sorted(o.items(), key=functools.cmp_to_key(universal_cmp)) |
|
|
|
for (k, v), (nk, nv) in zip(itms, nitms): |
|
|
|
if not deeptypecmp(k, nk): |
|
|
|
return False |
|
|
@@ -474,15 +493,15 @@ class Test_deeptypecmp(unittest.TestCase): |
|
|
|
|
|
|
|
def test_false(self): |
|
|
|
for i in (([[]], [{}]), ([1], ['str']), ([], set()), |
|
|
|
({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}), |
|
|
|
({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}), |
|
|
|
({1: 2, 5: 'sdlkfj'}, {1: 2, 5: b'sdlkfj'}), |
|
|
|
({1: 2, 'sdlkfj': 5}, {1: 2, b'sdlkfj': 5}), |
|
|
|
): |
|
|
|
self.assertFalse(deeptypecmp(*i)) |
|
|
|
self.assertFalse(deeptypecmp(*i), '%s != %s' % (i[0], i[1])) |
|
|
|
|
|
|
|
def genfailures(obj): |
|
|
|
s = dumps(obj) |
|
|
|
for i in xrange(len(s)): |
|
|
|
for j in (chr(x) for x in xrange(256)): |
|
|
|
for i in range(len(s)): |
|
|
|
for j in (bytes([x]) for x in range(256)): |
|
|
|
ts = s[:i] + j + s[i + 1:] |
|
|
|
if ts == s: |
|
|
|
continue |
|
|
@@ -493,24 +512,24 @@ def genfailures(obj): |
|
|
|
except (ValueError, KeyError, IndexError, TypeError): |
|
|
|
pass |
|
|
|
else: # pragma: no cover |
|
|
|
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i]))) |
|
|
|
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, s[i])) |
|
|
|
|
|
|
|
class TestCode(unittest.TestCase): |
|
|
|
def test_primv(self): |
|
|
|
self.assertEqual(dumps(-257), '0202feff'.decode('hex')) |
|
|
|
self.assertEqual(dumps(-256), '0202ff00'.decode('hex')) |
|
|
|
self.assertEqual(dumps(-255), '0202ff01'.decode('hex')) |
|
|
|
self.assertEqual(dumps(-1), '0201ff'.decode('hex')) |
|
|
|
self.assertEqual(dumps(5), '020105'.decode('hex')) |
|
|
|
self.assertEqual(dumps(128), '02020080'.decode('hex')) |
|
|
|
self.assertEqual(dumps(256), '02020100'.decode('hex')) |
|
|
|
self.assertEqual(dumps(-257), bytes.fromhex('0202feff')) |
|
|
|
self.assertEqual(dumps(-256), bytes.fromhex('0202ff00')) |
|
|
|
self.assertEqual(dumps(-255), bytes.fromhex('0202ff01')) |
|
|
|
self.assertEqual(dumps(-1), bytes.fromhex('0201ff')) |
|
|
|
self.assertEqual(dumps(5), bytes.fromhex('020105')) |
|
|
|
self.assertEqual(dumps(128), bytes.fromhex('02020080')) |
|
|
|
self.assertEqual(dumps(256), bytes.fromhex('02020100')) |
|
|
|
|
|
|
|
self.assertEqual(dumps(False), '010100'.decode('hex')) |
|
|
|
self.assertEqual(dumps(True), '0101ff'.decode('hex')) |
|
|
|
self.assertEqual(dumps(False), bytes.fromhex('010100')) |
|
|
|
self.assertEqual(dumps(True), bytes.fromhex('0101ff')) |
|
|
|
|
|
|
|
self.assertEqual(dumps(None), '0500'.decode('hex')) |
|
|
|
self.assertEqual(dumps(None), bytes.fromhex('0500')) |
|
|
|
|
|
|
|
self.assertEqual(dumps(.15625), '090380fb05'.decode('hex')) |
|
|
|
self.assertEqual(dumps(.15625), bytes.fromhex('090380fb05')) |
|
|
|
|
|
|
|
def test_fuzzing(self): |
|
|
|
# Make sure that when a failure is detected here, that it |
|
|
@@ -539,7 +558,7 @@ class TestCode(unittest.TestCase): |
|
|
|
'181632303136303231373136343034372e3035343433367a', #datetime w/ lower z |
|
|
|
'181632303136313220383031303933302e3931353133385a', #datetime w/ space |
|
|
|
]: |
|
|
|
self.assertRaises(ValueError, loads, v.decode('hex')) |
|
|
|
self.assertRaises(ValueError, loads, bytes.fromhex(v)) |
|
|
|
|
|
|
|
def test_invalid_floats(self): |
|
|
|
import mock |
|
|
@@ -548,7 +567,7 @@ class TestCode(unittest.TestCase): |
|
|
|
|
|
|
|
def test_consume(self): |
|
|
|
b = dumps(5) |
|
|
|
self.assertRaises(ValueError, loads, b + '398473', |
|
|
|
self.assertRaises(ValueError, loads, b + b'398473', |
|
|
|
consume=True) |
|
|
|
|
|
|
|
# XXX - still possible that an internal data member |
|
|
@@ -565,10 +584,10 @@ class TestCode(unittest.TestCase): |
|
|
|
def test_cryptoutilasn1(self): |
|
|
|
'''Test DER sequences generated by Crypto.Util.asn1.''' |
|
|
|
|
|
|
|
for s, v in [ ('\x02\x03$\x8a\xf9', 2394873), |
|
|
|
('\x05\x00', None), |
|
|
|
('\x02\x03\x00\x96I', 38473), |
|
|
|
('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200), |
|
|
|
for s, v in [ (b'\x02\x03$\x8a\xf9', 2394873), |
|
|
|
(b'\x05\x00', None), |
|
|
|
(b'\x02\x03\x00\x96I', 38473), |
|
|
|
(b'\x04\x81\xc8' + b'\x00' * 200, b'\x00' * 200), |
|
|
|
]: |
|
|
|
self.assertEqual(loads(s), v) |
|
|
|
|
|
|
@@ -595,7 +614,7 @@ class TestCode(unittest.TestCase): |
|
|
|
sys.float_info.max, sys.float_info.min, |
|
|
|
float('.15625'), |
|
|
|
'weoifjwef', |
|
|
|
u'\U0001f4a9', |
|
|
|
'\U0001f4a9', |
|
|
|
[], [ 1,2,3 ], |
|
|
|
{}, { 5: 10, 'adfkj': 34 }, |
|
|
|
set(), set((1,2,3)), |
|
|
@@ -609,8 +628,8 @@ class TestCode(unittest.TestCase): |
|
|
|
o = loads(s) |
|
|
|
self.assertEqual(i, o) |
|
|
|
|
|
|
|
tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, |
|
|
|
'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] } |
|
|
|
tobj = { 1: 'dflkj', 5: 'sdlkfj', 'float': 1, |
|
|
|
'largeint': 1<<342, 'list': [ 1, 2, 'str', 'str' ] } |
|
|
|
|
|
|
|
out = dumps(tobj) |
|
|
|
self.assertEqual(tobj, loads(out)) |
|
|
@@ -641,7 +660,7 @@ class TestCode(unittest.TestCase): |
|
|
|
self.assertEqual(ac.loads(ac.dumps(o)), v) |
|
|
|
self.assertRaises(TypeError, ac.dumps, Bar()) |
|
|
|
|
|
|
|
v = u'oiejfd' |
|
|
|
v = b'oiejfd' |
|
|
|
o = Baz() |
|
|
|
o.s = v |
|
|
|
|
|
|
@@ -652,7 +671,7 @@ class TestCode(unittest.TestCase): |
|
|
|
self.assertRaises(TypeError, dumps, o) |
|
|
|
|
|
|
|
def test_loads(self): |
|
|
|
self.assertRaises(ValueError, loads, '\x00\x02\x00') |
|
|
|
self.assertRaises(ValueError, loads, b'\x00\x02\x00') |
|
|
|
|
|
|
|
def test_nodict(self): |
|
|
|
'''Verify that ASN1Coder does not support dict.''' |
|
|
|