diff --git a/pasn1.py b/pasn1.py index 1f71e77..7015f40 100644 --- a/pasn1.py +++ b/pasn1.py @@ -15,6 +15,8 @@ import pdb import sys import unittest +__all__ = [ 'dumps', 'loads', 'ASN1Coder' ] + def _numtostr(n): hs = '%x' % n if len(hs) & 1 == 1: @@ -165,9 +167,14 @@ class ASN1Coder(object): def dec_dict(self, d, pos, end): r = {} + vend = pos while pos < end: k, kend = self._loads(d, pos, end) + #if kend > end: + # raise ValueError('key past end') v, vend = self._loads(d, kend, end) + if vend > end: + raise ValueError('value past end') r[k] = v pos = vend @@ -188,8 +195,11 @@ class ASN1Coder(object): def dec_list(self, d, pos, end): r = [] + vend = pos while pos < end: v, vend = self._loads(d, pos, end) + if vend > end: + raise ValueError('load past end') r.append(v) pos = vend @@ -262,15 +272,21 @@ class ASN1Coder(object): return float('nan'), end elif v == 0b01000011: return float('-0'), end + elif v & 0b110000: + raise ValueError('base must be 2') + elif v & 0b1100: + raise ValueError('scaling factor must be 0') + elif v & 0b11000000 == 0: + raise ValueError('decimal encoding not supported') #elif v & 0b11000000 == 0b01000000: # raise ValueError('invalid encoding') - if not (v & 0b10000000): - raise NotImplementedError - if v & 3 == 3: pexp = pos + 2 - eexp = pos + 2 + ord(d[pos + 1]) + explen = ord(d[pos + 1]) + if explen <= 3: + raise ValueError('must use other length encoding') + eexp = pos + 2 + explen else: pexp = pos + 1 eexp = pos + 1 + (v & 3) + 1 @@ -311,6 +327,62 @@ class ASN1Coder(object): raise ValueError('entire string not consumed') return r +def deeptypecmp(obj, o): + #print 'dtc:', `obj`, `o` + if type(obj) != type(o): + return False + if type(obj) in (str, unicode): + return True + + if type(obj) in (list, set): + for i, j in zip(obj, o): + if not deeptypecmp(i, j): + return False + + if type(obj) in (dict,): + itms = obj.items() + itms.sort() + nitms = o.items() + nitms.sort() + for (k, v), (nk, nv) in zip(itms, nitms): + if not deeptypecmp(k, nk): + return False + if not deeptypecmp(v, nv): + return False + + return True + +class Test_deeptypecmp(unittest.TestCase): + def test_true(self): + for i in ((1,1), ('sldkfj', 'sldkfj') + ): + self.assertTrue(deeptypecmp(*i)) + + def test_false(self): + for i in (([[]], [{}]), ([1], ['str']), ([], set()), + ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}), + ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}), + ): + self.assertFalse(deeptypecmp(*i)) + +def genfailures(obj): + s = dumps(obj) + for i in xrange(len(s)): + for j in (chr(x) for x in xrange(256)): + ts = s[:i] + j + s[i + 1:] + if ts == s: + continue + try: + o = loads(ts, consume=True) + if o != obj or not deeptypecmp(o, obj): + raise ValueError + except (ValueError, KeyError, IndexError): + pass + except Exception: + raise + else: + raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i]))) + _coder = ASN1Coder() dumps = _coder.dumps loads = _coder.loads @@ -332,6 +404,12 @@ class TestCode(unittest.TestCase): self.assertEqual(dumps(.15625), '090380fb05'.decode('hex')) + def test_fuzzing(self): + genfailures(float(1)) + genfailures([ 1, 2, 'sdlkfj' ]) + genfailures({ 1: 2, 5: 'sdlkfj' }) + genfailures(set([ 1, 2, 'sdlkfj' ])) + def test_consume(self): b = dumps(5) self.assertRaises(ValueError, loads, b + '398473', consume=True) @@ -349,7 +427,16 @@ class TestCode(unittest.TestCase): def test_invalids(self): # Add tests for base 8, 16 floats among others - for v in [ '010101', ]: + for v in [ '010101', + '0903040001', # float scaling factor + '0903840001', # float scaling factor + '0903100001', # float base + '0903900001', # float base + '0903000001', # float decimal encoding + '0903830001', # float exponent encoding + '3007020101020102040673646c6b666a', # list short string still valid + 'c007020101020102020105040673646c6b666a', # dict short value still valid + ]: self.assertRaises(ValueError, loads, v.decode('hex')) def test_cryptoutilasn1(self): @@ -377,7 +464,9 @@ class TestCode(unittest.TestCase): float('.15625'), 'weoifjwef', u'\U0001f4a9', - set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))), + [], + {}, + set(), set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))), ]: s = dumps(i) o = loads(s)