Browse Source

add a test that fuzzes the loading process to make sure that

any changes to the laod string is either caught, or makes a
change to the returned objects...

Add a few tests found by the fuzzing routines to make sure
they are continue to be detected even if fuzzing is turned off..

fixes caught, reading past end of list, set and dict... Various float
parameters now raise errors when invalid..

[git-p4: depot-paths = "//depot/python/pypasn1/main/": change = 1823]
python2
John-Mark Gurney 8 years ago
parent
commit
02e8165cd1
1 changed files with 95 additions and 6 deletions
  1. +95
    -6
      pasn1.py

+ 95
- 6
pasn1.py View File

@@ -15,6 +15,8 @@ import pdb
import sys import sys
import unittest import unittest


__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]

def _numtostr(n): def _numtostr(n):
hs = '%x' % n hs = '%x' % n
if len(hs) & 1 == 1: if len(hs) & 1 == 1:
@@ -165,9 +167,14 @@ class ASN1Coder(object):


def dec_dict(self, d, pos, end): def dec_dict(self, d, pos, end):
r = {} r = {}
vend = pos
while pos < end: while pos < end:
k, kend = self._loads(d, pos, end) k, kend = self._loads(d, pos, end)
#if kend > end:
# raise ValueError('key past end')
v, vend = self._loads(d, kend, end) v, vend = self._loads(d, kend, end)
if vend > end:
raise ValueError('value past end')


r[k] = v r[k] = v
pos = vend pos = vend
@@ -188,8 +195,11 @@ class ASN1Coder(object):


def dec_list(self, d, pos, end): def dec_list(self, d, pos, end):
r = [] r = []
vend = pos
while pos < end: while pos < end:
v, vend = self._loads(d, pos, end) v, vend = self._loads(d, pos, end)
if vend > end:
raise ValueError('load past end')
r.append(v) r.append(v)
pos = vend pos = vend


@@ -262,15 +272,21 @@ class ASN1Coder(object):
return float('nan'), end return float('nan'), end
elif v == 0b01000011: elif v == 0b01000011:
return float('-0'), end return float('-0'), end
elif v & 0b110000:
raise ValueError('base must be 2')
elif v & 0b1100:
raise ValueError('scaling factor must be 0')
elif v & 0b11000000 == 0:
raise ValueError('decimal encoding not supported')
#elif v & 0b11000000 == 0b01000000: #elif v & 0b11000000 == 0b01000000:
# raise ValueError('invalid encoding') # raise ValueError('invalid encoding')


if not (v & 0b10000000):
raise NotImplementedError

if v & 3 == 3: if v & 3 == 3:
pexp = pos + 2 pexp = pos + 2
eexp = pos + 2 + ord(d[pos + 1])
explen = ord(d[pos + 1])
if explen <= 3:
raise ValueError('must use other length encoding')
eexp = pos + 2 + explen
else: else:
pexp = pos + 1 pexp = pos + 1
eexp = pos + 1 + (v & 3) + 1 eexp = pos + 1 + (v & 3) + 1
@@ -311,6 +327,62 @@ class ASN1Coder(object):
raise ValueError('entire string not consumed') raise ValueError('entire string not consumed')
return r return r


def deeptypecmp(obj, o):
#print 'dtc:', `obj`, `o`
if type(obj) != type(o):
return False
if type(obj) in (str, unicode):
return True

if type(obj) in (list, set):
for i, j in zip(obj, o):
if not deeptypecmp(i, j):
return False

if type(obj) in (dict,):
itms = obj.items()
itms.sort()
nitms = o.items()
nitms.sort()
for (k, v), (nk, nv) in zip(itms, nitms):
if not deeptypecmp(k, nk):
return False
if not deeptypecmp(v, nv):
return False
return True

class Test_deeptypecmp(unittest.TestCase):
def test_true(self):
for i in ((1,1), ('sldkfj', 'sldkfj')
):
self.assertTrue(deeptypecmp(*i))

def test_false(self):
for i in (([[]], [{}]), ([1], ['str']), ([], set()),
({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
):
self.assertFalse(deeptypecmp(*i))

def genfailures(obj):
s = dumps(obj)
for i in xrange(len(s)):
for j in (chr(x) for x in xrange(256)):
ts = s[:i] + j + s[i + 1:]
if ts == s:
continue
try:
o = loads(ts, consume=True)
if o != obj or not deeptypecmp(o, obj):
raise ValueError
except (ValueError, KeyError, IndexError):
pass
except Exception:
raise
else:
raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))

_coder = ASN1Coder() _coder = ASN1Coder()
dumps = _coder.dumps dumps = _coder.dumps
loads = _coder.loads loads = _coder.loads
@@ -332,6 +404,12 @@ class TestCode(unittest.TestCase):


self.assertEqual(dumps(.15625), '090380fb05'.decode('hex')) self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))


def test_fuzzing(self):
genfailures(float(1))
genfailures([ 1, 2, 'sdlkfj' ])
genfailures({ 1: 2, 5: 'sdlkfj' })
genfailures(set([ 1, 2, 'sdlkfj' ]))

def test_consume(self): def test_consume(self):
b = dumps(5) b = dumps(5)
self.assertRaises(ValueError, loads, b + '398473', consume=True) self.assertRaises(ValueError, loads, b + '398473', consume=True)
@@ -349,7 +427,16 @@ class TestCode(unittest.TestCase):


def test_invalids(self): def test_invalids(self):
# Add tests for base 8, 16 floats among others # Add tests for base 8, 16 floats among others
for v in [ '010101', ]:
for v in [ '010101',
'0903040001', # float scaling factor
'0903840001', # float scaling factor
'0903100001', # float base
'0903900001', # float base
'0903000001', # float decimal encoding
'0903830001', # float exponent encoding
'3007020101020102040673646c6b666a', # list short string still valid
'c007020101020102020105040673646c6b666a', # dict short value still valid
]:
self.assertRaises(ValueError, loads, v.decode('hex')) self.assertRaises(ValueError, loads, v.decode('hex'))


def test_cryptoutilasn1(self): def test_cryptoutilasn1(self):
@@ -377,7 +464,9 @@ class TestCode(unittest.TestCase):
float('.15625'), float('.15625'),
'weoifjwef', 'weoifjwef',
u'\U0001f4a9', u'\U0001f4a9',
set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
[],
{},
set(), set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
]: ]:
s = dumps(i) s = dumps(i)
o = loads(s) o = loads(s)


Loading…
Cancel
Save