From 02e8165cd152fb9463aa1c5390e750619c2c3ec1 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Sun, 14 Feb 2016 12:20:10 -0800 Subject: [PATCH] add a test that fuzzes the loading process to make sure that any changes to the laod string is either caught, or makes a change to the returned objects... Add a few tests found by the fuzzing routines to make sure they are continue to be detected even if fuzzing is turned off.. fixes caught, reading past end of list, set and dict... Various float parameters now raise errors when invalid.. [git-p4: depot-paths = "//depot/python/pypasn1/main/": change = 1823] --- pasn1.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 95 insertions(+), 6 deletions(-) diff --git a/pasn1.py b/pasn1.py index 1f71e77..7015f40 100644 --- a/pasn1.py +++ b/pasn1.py @@ -15,6 +15,8 @@ import pdb import sys import unittest +__all__ = [ 'dumps', 'loads', 'ASN1Coder' ] + def _numtostr(n): hs = '%x' % n if len(hs) & 1 == 1: @@ -165,9 +167,14 @@ class ASN1Coder(object): def dec_dict(self, d, pos, end): r = {} + vend = pos while pos < end: k, kend = self._loads(d, pos, end) + #if kend > end: + # raise ValueError('key past end') v, vend = self._loads(d, kend, end) + if vend > end: + raise ValueError('value past end') r[k] = v pos = vend @@ -188,8 +195,11 @@ class ASN1Coder(object): def dec_list(self, d, pos, end): r = [] + vend = pos while pos < end: v, vend = self._loads(d, pos, end) + if vend > end: + raise ValueError('load past end') r.append(v) pos = vend @@ -262,15 +272,21 @@ class ASN1Coder(object): return float('nan'), end elif v == 0b01000011: return float('-0'), end + elif v & 0b110000: + raise ValueError('base must be 2') + elif v & 0b1100: + raise ValueError('scaling factor must be 0') + elif v & 0b11000000 == 0: + raise ValueError('decimal encoding not supported') #elif v & 0b11000000 == 0b01000000: # raise ValueError('invalid encoding') - if not (v & 0b10000000): - raise NotImplementedError - if v & 3 == 3: pexp = pos + 2 - eexp = pos + 2 + ord(d[pos + 1]) + explen = ord(d[pos + 1]) + if explen <= 3: + raise ValueError('must use other length encoding') + eexp = pos + 2 + explen else: pexp = pos + 1 eexp = pos + 1 + (v & 3) + 1 @@ -311,6 +327,62 @@ class ASN1Coder(object): raise ValueError('entire string not consumed') return r +def deeptypecmp(obj, o): + #print 'dtc:', `obj`, `o` + if type(obj) != type(o): + return False + if type(obj) in (str, unicode): + return True + + if type(obj) in (list, set): + for i, j in zip(obj, o): + if not deeptypecmp(i, j): + return False + + if type(obj) in (dict,): + itms = obj.items() + itms.sort() + nitms = o.items() + nitms.sort() + for (k, v), (nk, nv) in zip(itms, nitms): + if not deeptypecmp(k, nk): + return False + if not deeptypecmp(v, nv): + return False + + return True + +class Test_deeptypecmp(unittest.TestCase): + def test_true(self): + for i in ((1,1), ('sldkfj', 'sldkfj') + ): + self.assertTrue(deeptypecmp(*i)) + + def test_false(self): + for i in (([[]], [{}]), ([1], ['str']), ([], set()), + ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}), + ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}), + ): + self.assertFalse(deeptypecmp(*i)) + +def genfailures(obj): + s = dumps(obj) + for i in xrange(len(s)): + for j in (chr(x) for x in xrange(256)): + ts = s[:i] + j + s[i + 1:] + if ts == s: + continue + try: + o = loads(ts, consume=True) + if o != obj or not deeptypecmp(o, obj): + raise ValueError + except (ValueError, KeyError, IndexError): + pass + except Exception: + raise + else: + raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i]))) + _coder = ASN1Coder() dumps = _coder.dumps loads = _coder.loads @@ -332,6 +404,12 @@ class TestCode(unittest.TestCase): self.assertEqual(dumps(.15625), '090380fb05'.decode('hex')) + def test_fuzzing(self): + genfailures(float(1)) + genfailures([ 1, 2, 'sdlkfj' ]) + genfailures({ 1: 2, 5: 'sdlkfj' }) + genfailures(set([ 1, 2, 'sdlkfj' ])) + def test_consume(self): b = dumps(5) self.assertRaises(ValueError, loads, b + '398473', consume=True) @@ -349,7 +427,16 @@ class TestCode(unittest.TestCase): def test_invalids(self): # Add tests for base 8, 16 floats among others - for v in [ '010101', ]: + for v in [ '010101', + '0903040001', # float scaling factor + '0903840001', # float scaling factor + '0903100001', # float base + '0903900001', # float base + '0903000001', # float decimal encoding + '0903830001', # float exponent encoding + '3007020101020102040673646c6b666a', # list short string still valid + 'c007020101020102020105040673646c6b666a', # dict short value still valid + ]: self.assertRaises(ValueError, loads, v.decode('hex')) def test_cryptoutilasn1(self): @@ -377,7 +464,9 @@ class TestCode(unittest.TestCase): float('.15625'), 'weoifjwef', u'\U0001f4a9', - set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))), + [], + {}, + set(), set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))), ]: s = dumps(i) o = loads(s)