@@ -15,6 +15,8 @@ import pdb
import sys
import sys
import unittest
import unittest
__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]
def _numtostr(n):
def _numtostr(n):
hs = '%x' % n
hs = '%x' % n
if len(hs) & 1 == 1:
if len(hs) & 1 == 1:
@@ -165,9 +167,14 @@ class ASN1Coder(object):
def dec_dict(self, d, pos, end):
def dec_dict(self, d, pos, end):
r = {}
r = {}
vend = pos
while pos < end:
while pos < end:
k, kend = self._loads(d, pos, end)
k, kend = self._loads(d, pos, end)
#if kend > end:
# raise ValueError('key past end')
v, vend = self._loads(d, kend, end)
v, vend = self._loads(d, kend, end)
if vend > end:
raise ValueError('value past end')
r[k] = v
r[k] = v
pos = vend
pos = vend
@@ -188,8 +195,11 @@ class ASN1Coder(object):
def dec_list(self, d, pos, end):
def dec_list(self, d, pos, end):
r = []
r = []
vend = pos
while pos < end:
while pos < end:
v, vend = self._loads(d, pos, end)
v, vend = self._loads(d, pos, end)
if vend > end:
raise ValueError('load past end')
r.append(v)
r.append(v)
pos = vend
pos = vend
@@ -262,15 +272,21 @@ class ASN1Coder(object):
return float('nan'), end
return float('nan'), end
elif v == 0b01000011:
elif v == 0b01000011:
return float('-0'), end
return float('-0'), end
elif v & 0b110000:
raise ValueError('base must be 2')
elif v & 0b1100:
raise ValueError('scaling factor must be 0')
elif v & 0b11000000 == 0:
raise ValueError('decimal encoding not supported')
#elif v & 0b11000000 == 0b01000000:
#elif v & 0b11000000 == 0b01000000:
# raise ValueError('invalid encoding')
# raise ValueError('invalid encoding')
if not (v & 0b10000000):
raise NotImplementedError
if v & 3 == 3:
if v & 3 == 3:
pexp = pos + 2
pexp = pos + 2
eexp = pos + 2 + ord(d[pos + 1])
explen = ord(d[pos + 1])
if explen <= 3:
raise ValueError('must use other length encoding')
eexp = pos + 2 + explen
else:
else:
pexp = pos + 1
pexp = pos + 1
eexp = pos + 1 + (v & 3) + 1
eexp = pos + 1 + (v & 3) + 1
@@ -311,6 +327,62 @@ class ASN1Coder(object):
raise ValueError('entire string not consumed')
raise ValueError('entire string not consumed')
return r
return r
def deeptypecmp(obj, o):
#print 'dtc:', `obj`, `o`
if type(obj) != type(o):
return False
if type(obj) in (str, unicode):
return True
if type(obj) in (list, set):
for i, j in zip(obj, o):
if not deeptypecmp(i, j):
return False
if type(obj) in (dict,):
itms = obj.items()
itms.sort()
nitms = o.items()
nitms.sort()
for (k, v), (nk, nv) in zip(itms, nitms):
if not deeptypecmp(k, nk):
return False
if not deeptypecmp(v, nv):
return False
return True
class Test_deeptypecmp(unittest.TestCase):
def test_true(self):
for i in ((1,1), ('sldkfj', 'sldkfj')
):
self.assertTrue(deeptypecmp(*i))
def test_false(self):
for i in (([[]], [{}]), ([1], ['str']), ([], set()),
({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
):
self.assertFalse(deeptypecmp(*i))
def genfailures(obj):
s = dumps(obj)
for i in xrange(len(s)):
for j in (chr(x) for x in xrange(256)):
ts = s[:i] + j + s[i + 1:]
if ts == s:
continue
try:
o = loads(ts, consume=True)
if o != obj or not deeptypecmp(o, obj):
raise ValueError
except (ValueError, KeyError, IndexError):
pass
except Exception:
raise
else:
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
_coder = ASN1Coder()
_coder = ASN1Coder()
dumps = _coder.dumps
dumps = _coder.dumps
loads = _coder.loads
loads = _coder.loads
@@ -332,6 +404,12 @@ class TestCode(unittest.TestCase):
self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))
self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))
def test_fuzzing(self):
genfailures(float(1))
genfailures([ 1, 2, 'sdlkfj' ])
genfailures({ 1: 2, 5: 'sdlkfj' })
genfailures(set([ 1, 2, 'sdlkfj' ]))
def test_consume(self):
def test_consume(self):
b = dumps(5)
b = dumps(5)
self.assertRaises(ValueError, loads, b + '398473', consume=True)
self.assertRaises(ValueError, loads, b + '398473', consume=True)
@@ -349,7 +427,16 @@ class TestCode(unittest.TestCase):
def test_invalids(self):
def test_invalids(self):
# Add tests for base 8, 16 floats among others
# Add tests for base 8, 16 floats among others
for v in [ '010101', ]:
for v in [ '010101',
'0903040001', # float scaling factor
'0903840001', # float scaling factor
'0903100001', # float base
'0903900001', # float base
'0903000001', # float decimal encoding
'0903830001', # float exponent encoding
'3007020101020102040673646c6b666a', # list short string still valid
'c007020101020102020105040673646c6b666a', # dict short value still valid
]:
self.assertRaises(ValueError, loads, v.decode('hex'))
self.assertRaises(ValueError, loads, v.decode('hex'))
def test_cryptoutilasn1(self):
def test_cryptoutilasn1(self):
@@ -377,7 +464,9 @@ class TestCode(unittest.TestCase):
float('.15625'),
float('.15625'),
'weoifjwef',
'weoifjwef',
u'\U0001f4a9',
u'\U0001f4a9',
set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
[],
{},
set(), set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
]:
]:
s = dumps(i)
s = dumps(i)
o = loads(s)
o = loads(s)