|
- #!/usr/bin/env python
-
- '''A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit
- of pickle.
-
- The default dumps/loads uses a profile of ASN.1 that supports serialization
- of key/value pairs. This is non-standard. Instantiate the class ASN1Coder
- to get a pure ASN.1 serializer/deserializer.
-
- All lengths must be definite; that is, End-of-contents octets
- MUST NOT be used. The shortest form of length encoding MUST be used.
- A longer length encoding MUST be rejected.'''
-
- __author__ = 'John-Mark Gurney'
- __copyright__ = 'Copyright 2016-2020 John-Mark Gurney. All rights reserved.'
- __license__ = '2-clause BSD license'
-
- # Copyright 2016-2020, John-Mark Gurney
- # All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- #
- # 1. Redistributions of source code must retain the above copyright notice, this
- # list of conditions and the following disclaimer.
- # 2. Redistributions in binary form must reproduce the above copyright notice,
- # this list of conditions and the following disclaimer in the documentation
- # and/or other materials provided with the distribution.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
- # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- #
- # The views and conclusions contained in the software and documentation are those
- # of the authors and should not be interpreted as representing official policies,
- # either expressed or implied, of the Project.
-
- import datetime
- import functools
- import math
- import os
- import sys
- import unittest
-
# This module relies on Python 3 bytes/str semantics; refuse older runtimes.
if sys.version_info[0] != 3: # pragma: no cover
    raise RuntimeError('this module only supports python 3')

# Names exported by ``from <module> import *``.
__all__ = ['dumps', 'loads', 'ASN1Coder']
-
- def _numtobytes(n):
- hs = '%x' % n
- if len(hs) & 1 == 1:
- hs = '0' + hs
- bs = bytes.fromhex(hs)
-
- return bs
-
- def _encodelen(l):
- '''Takes l as a length value, and returns a byte string that
- represents l per ASN.1 rules.'''
-
- if l < 128:
- return bytes([l])
-
- bs = _numtobytes(l)
- return bytes([len(bs) | 0x80]) + bs
-
- def _decodelen(d, pos=0):
- '''Returns the length, and number of bytes required.'''
-
- odp = d[pos]
- if odp < 128:
- return d[pos], 1
- else:
- l = odp & 0x7f
- return int(d[pos + 1:pos + 1 + l].hex(), 16), l + 1
-
class Test_codelen(unittest.TestCase):
    '''Round-trip checks for the ASN.1 length encoder and decoder.'''

    # (length value, expected encoding): short form plus the 1-, 2- and
    # 3-octet long forms, including the boundaries between them.
    _testdata = [
        (2, b'\x02'),
        (127, b'\x7f'),
        (128, b'\x81\x80'),
        (255, b'\x81\xff'),
        (256, b'\x82\x01\x00'),
        (65536-1, b'\x82\xff\xff'),
        (65536, b'\x83\x01\x00\x00'),
    ]

    def test_el(self):
        for length, encoded in self._testdata:
            with self.subTest(length=length):
                self.assertEqual(_encodelen(length), encoded)
                self.assertEqual(_decodelen(encoded), (length, len(encoded)))
-
- def _splitfloat(f):
- m, e = math.frexp(f)
- # XXX - less than ideal
- while m != math.trunc(m):
- m *= 2
- e -= 1
-
- return math.trunc(m), e
-
class TestSplitFloat(unittest.TestCase):
    '''Checks that _splitfloat recovers (mantissa, exponent) exactly.'''

    def test_sf(self):
        cases = [
            (0x2421, -32),
            (0x5382f, 238),
            (0x1fa8c3b094adf1, 971),
        ]
        for mant, exp in cases:
            with self.subTest(mant=mant, exp=exp):
                self.assertEqual(_splitfloat(mant * 2**exp), (mant, exp))
-
class ASN1Coder(object):
    '''A class that contains an ASN.1 encoder/decoder.

    Exports two methods, loads and dumps.'''

    def __init__(self, coerce=None):
        '''If the arg coerce is provided, when dumping the object,
        if the type is not found, the coerce function will be called
        with the obj. It is expected to return a tuple of a string
        and an object that supports what is listed for that string:
        'bool': truth testing (bool(obj))
        'float': compatible w/ float
        'int': compatible w/ int (bit_length, ~, < and int() are used)
        'list': __iter__
        'set': __iter__
        'bytes': len() and accepted by the bytes() constructor
        'null': nothing needed
        'unicode': encode method returns UTF-8 encoded bytes
        'datetime': astimezone, strftime and microsecond
        '''

        self.coerce = coerce

    # Python type -> encoder/decoder suffix (enc_<name> / dec_<name>).
    _typemap = {
        bool: 'bool',
        float: 'float',
        int: 'int',
        list: 'list',
        set: 'set',
        bytes: 'bytes',
        type(None): 'null',
        str: 'unicode',
        #decimal.Decimal: 'float',
        datetime.datetime: 'datetime',
        #datetime.timedelta: 'timedelta',
    }
    # Single-octet ASN.1 universal tag -> encoder/decoder suffix.
    _tagmap = {
        b'\x01': 'bool',        # BOOLEAN
        b'\x02': 'int',         # INTEGER
        b'\x04': 'bytes',       # OCTET STRING
        b'\x05': 'null',        # NULL
        b'\x09': 'float',       # REAL
        b'\x0c': 'unicode',     # UTF8String
        b'\x18': 'datetime',    # GeneralizedTime
        b'\x30': 'list',        # SEQUENCE (constructed)
        b'\x31': 'set',         # SET (constructed)
    }

    _typetag = dict((v, k) for k, v in _tagmap.items())

    @staticmethod
    def enc_int(obj, **kwargs):
        '''Encode an integer as length octets followed by its minimal
        big-endian two's-complement representation.'''

        # Shortest two's-complement width: magnitude bits plus a sign
        # bit. For negatives ~obj == -obj - 1, so e.g. -128 needs 8 bits
        # (a single 0x80 octet), not the 9 that abs(-128) would suggest.
        nbits = (~obj if obj < 0 else obj).bit_length() + 1
        nbytes = (nbits + 7) // 8

        return _encodelen(nbytes) + int(obj).to_bytes(nbytes, 'big',
            signed=True)

    @staticmethod
    def dec_int(d, pos, end):
        '''Decode the big-endian two's-complement integer in d[pos:end].

        Returns (value, end). An empty field decodes as 0.'''

        # signed=True treats a set top bit as negative, including a lone
        # sign bit such as 0x80 (-128) or 0x8000 (-32768).
        return int.from_bytes(d[pos:end], 'big', signed=True), end

    @staticmethod
    def enc_bool(obj, **kwargs):
        '''Encode a boolean as one content octet: 0xff or 0x00.'''

        return b'\x01' + (b'\xff' if obj else b'\x00')

    def dec_bool(self, d, pos, end):
        '''Decode a boolean; only a single 0x00 or 0xff octet is accepted.'''

        if end - pos != 1:
            raise ValueError('invalid bool length: %d' % (end - pos))

        v = self.dec_int(d, pos, end)[0]
        if v not in (-1, 0):
            raise ValueError('invalid bool value: %d' % v)

        return bool(v), end

    @staticmethod
    def enc_null(obj, **kwargs):
        '''Encode NULL: a zero length and no content.'''

        return b'\x00'

    @staticmethod
    def dec_null(d, pos, end):
        '''Decode NULL, which must have no content octets.'''

        if pos != end:
            raise ValueError('null must have no content')

        return None, end

    def enc_list(self, obj, **kwargs):
        '''Encode each element of obj in iteration order, prefixed by
        the total length of the encoded elements.'''

        r = b''.join(self.dumps(x, **kwargs) for x in obj)
        return _encodelen(len(r)) + r

    def dec_list(self, d, pos, end):
        '''Decode consecutive elements from d[pos:end] into a list.'''

        r = []
        vend = pos
        while pos < end:
            v, vend = self._loads(d, pos, end)
            # A child element must not extend beyond its container.
            if vend > end:
                raise ValueError('load past end')
            r.append(v)
            pos = vend

        return r, vend

    enc_set = enc_list

    def dec_set(self, d, pos, end):
        '''Decode like a list, then return the elements as a set.'''

        r, end = self.dec_list(d, pos, end)
        return set(r), end

    @staticmethod
    def enc_bytes(obj, **kwargs):
        '''Encode raw bytes, prefixed by their length.'''

        return _encodelen(len(obj)) + bytes(obj)

    @staticmethod
    def dec_bytes(d, pos, end):
        '''Return the content octets d[pos:end] unchanged.'''

        return d[pos:end], end

    @staticmethod
    def enc_unicode(obj, **kwargs):
        '''Encode a string as UTF-8, prefixed by the encoded length.'''

        encobj = obj.encode('utf-8')
        return _encodelen(len(encobj)) + encobj

    def dec_unicode(self, d, pos, end):
        '''Decode the UTF-8 content octets d[pos:end] to a str.'''

        return d[pos:end].decode('utf-8'), end

    @staticmethod
    def enc_float(obj, **kwargs):
        '''Encode a float as an ASN.1 REAL.

        Special values use one-octet encodings (0x40 +inf, 0x41 -inf,
        0x42 NaN, 0x43 -0.0); +0.0 has no content octets. Other values
        use the binary form with base 2, scaling factor 0, and a
        1- or 2-octet two's-complement exponent. A ValueError is raised
        if the exponent needs more than 2 octets.'''

        s = math.copysign(1, obj)
        if math.isnan(obj):
            return _encodelen(1) + bytes([0b01000010])
        elif math.isinf(obj):
            if s == 1:
                return _encodelen(1) + bytes([0b01000000])
            else:
                return _encodelen(1) + bytes([0b01000001])
        elif obj == 0:
            if s == 1:
                return _encodelen(0)
            else:
                return _encodelen(1) + bytes([0b01000011])

        m, e = _splitfloat(obj)

        # Binary encoding
        val = 0x80
        if m < 0:
            val |= 0x40
            m = -m

        # Base 2. The exponent length is its minimal two's-complement
        # width, so the declared exponent length always matches the
        # octets written.
        el = ((~e if e < 0 else e).bit_length() + 1 + 7) // 8
        if el > 2:
            raise ValueError('exponent too large')

        # The low two bits give (exponent octet count - 1).
        val |= el - 1
        encexp = e.to_bytes(el, 'big', signed=True)

        r = bytes([val]) + encexp + _numtobytes(m)
        return _encodelen(len(r)) + r

    def dec_float(self, d, pos, end):
        '''Decode an ASN.1 REAL from d[pos:end].

        Only base 2, scaling factor 0 and 1- or 2-octet exponents are
        supported; anything else raises ValueError.'''

        if pos == end:
            return float(0), end

        v = d[pos]
        if v == 0b01000000:
            return float('inf'), end
        elif v == 0b01000001:
            return float('-inf'), end
        elif v == 0b01000010:
            return float('nan'), end
        elif v == 0b01000011:
            return float('-0'), end
        elif v & 0b110000:
            raise ValueError('base must be 2')
        elif v & 0b1100:
            raise ValueError('scaling factor must be 0')
        elif v & 0b11000000 == 0:
            raise ValueError('decimal encoding not supported')
        # Any other 0b01xxxxxx octet has base or scaling bits set, so it
        # was already rejected above.

        if (v & 3) >= 2:
            raise ValueError('large exponents not supported')
        pexp = pos + 1
        eexp = pexp + (v & 3) + 1
        # Keep the exponent inside this element and require at least one
        # mantissa octet.
        if eexp >= end:
            raise ValueError('float truncated: no mantissa')

        exp = self.dec_int(d, pexp, eexp)[0]

        mant = int.from_bytes(d[eexp:end], 'big')
        r = math.ldexp(mant, exp)
        if v & 0b1000000:
            r = -r

        return r, end

    def dumps(self, obj, default=None):
        '''Convert obj into an array of bytes.

        ``default(obj)`` is a function that should return a
        serializable version of obj or raise TypeError. The
        default simply raises TypeError. If default is missing or
        raises TypeError, the coerce function given to the
        constructor (if any) is tried next.
        '''

        try:
            tf = self._typemap[type(obj)]
        except KeyError:
            if default is not None:
                try:
                    return self.dumps(default(obj), default=default)
                except TypeError:
                    pass

            if self.coerce is None:
                raise TypeError('unhandled object: %s' % repr(obj))

            tf, obj = self.coerce(obj)

        fun = getattr(self, 'enc_%s' % tf)
        return self._typetag[tf] + fun(obj, default=default)

    def _loads(self, data, pos, end):
        '''Decode the one element whose tag octet is at data[pos].

        Returns (value, position just past the element). Raises
        KeyError for an unknown tag and ValueError if data is shorter
        than the encoded length.'''

        tag = data[pos:pos + 1]
        l, b = _decodelen(data, pos + 1)
        if len(data) < pos + 1 + b + l:
            raise ValueError('string not long enough')

        # NOTE: the caller's end is not checked here; it is replaced by
        # this element's own end. Containers (e.g. dec_list) check that
        # each child stays within their bounds.
        end = pos + 1 + b + l

        t = self._tagmap[tag]
        fun = getattr(self, 'dec_%s' % t)
        return fun(data, pos + 1 + b, end)

    def enc_datetime(self, obj, **kwargs):
        '''Encode obj as a UTC GeneralizedTime string (YYYYMMDDHHMMSS,
        an optional fraction without trailing zeros, then 'Z').

        Naive datetimes are interpreted by astimezone() as local time.'''

        obj = obj.astimezone(datetime.timezone.utc)
        ts = obj.strftime('%Y%m%d%H%M%S')
        if obj.microsecond:
            ts += ('.%06d' % obj.microsecond).rstrip('0')
        ts += 'Z'
        return _encodelen(len(ts)) + ts.encode('utf-8')

    def dec_datetime(self, data, pos, end):
        '''Decode a GeneralizedTime string into an aware UTC datetime.

        Rejects a missing or lowercase 'Z', embedded spaces and
        fractional seconds with trailing zeros.'''

        ts = data[pos:end].decode('ascii')
        if ts[-1:] != 'Z':
            # ts[-1:] rather than ts[-1] so that empty content raises
            # ValueError instead of IndexError.
            raise ValueError('last character must be Z, was: %s' % repr(ts[-1:]))

        # Real bug is in strptime, but work around it here.
        if ' ' in ts:
            raise ValueError('no spaces are allowed')

        if '.' in ts:
            fstr = '%Y%m%d%H%M%S.%fZ'
            if ts.endswith('0Z'):
                raise ValueError('invalid trailing zeros')
        else:
            fstr = '%Y%m%d%H%M%SZ'
        return datetime.datetime.strptime(ts, fstr).replace(tzinfo=datetime.timezone.utc), end

    def loads(self, data, pos=0, end=None, consume=False):
        '''Load from data, starting at pos (optional), and ending
        at end (optional). If it is required to consume the
        whole string (not the default), set consume to True, and
        a ValueError will be raised if the string is not
        completely consumed. The second item in the ValueError will
        be the position that was the detected end.'''

        if end is None:
            end = len(data)
        r, e = self._loads(data, pos, end)

        if consume and e != end:
            raise ValueError('entire string not consumed', e)

        return r
-
class ASN1DictCoder(ASN1Coder):
    '''This adds support for the non-standard dict serialization.

    A dict is encoded under tag 0xe0 as alternating key, value
    elements. The coerce method also supports the following type:
    'dict': items() returning (key, value) pairs
    '''

    _typemap = ASN1Coder._typemap.copy()
    _typemap[dict] = 'dict'
    _tagmap = ASN1Coder._tagmap.copy()
    _tagmap[b'\xe0'] = 'dict'
    _typetag = dict((v, k) for k, v in _tagmap.items())

    def enc_dict(self, obj, **kwargs):
        '''Encode obj's items as key, value, key, value, ... in items()
        order, prefixed by the total encoded length.'''

        #it = list(obj.iteritems())
        #it.sort()
        r = b''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
            obj.items())
        return _encodelen(len(r)) + r

    def dec_dict(self, d, pos, end):
        '''Decode alternating key/value elements in d[pos:end] into a dict.

        Raises ValueError if a key or value extends past end, or if the
        last key has no value.'''

        r = {}
        vend = pos
        while pos < end:
            k, kend = self._loads(d, pos, end)
            # Check the key bounds before decoding the value. Otherwise a
            # trailing key with no value makes the value decode start
            # at or past end and read a neighbouring element's bytes.
            if kend > end:
                raise ValueError('key past end')
            if kend == end:
                raise ValueError('dict key without value')
            v, vend = self._loads(d, kend, end)
            if vend > end:
                raise ValueError('value past end')

            r[k] = v
            pos = vend

        return r, vend
-
# Shared coder behind the pickle-style module API. It uses the
# dict-capable (non-standard) profile.
_coder = ASN1DictCoder()
dumps, loads = _coder.dumps, _coder.loads
-
def cmp(a, b):
    '''Python 2 style three-way compare: -1, 0 or 1.'''

    return (a > b) - (a < b)

def universal_cmp(a, b):
    '''Total-order comparison usable across mixed types.

    Values of different types are ordered by type first (by the id() of
    the type, which is arbitrary but consistent within one process).
    Tuples and lists of the same type compare lexicographically, with a
    shorter prefix sorting first. Other same-typed values use cmp().
    Always returns -1, 0 or 1, as functools.cmp_to_key requires.'''

    ta, tb = type(a), type(b)
    if ta != tb:
        # Previously returned the bool id(a) < id(b), i.e. never -1,
        # which made the ordering inconsistent.
        return cmp(id(ta), id(tb))

    if isinstance(a, (tuple, list)):
        # Use fresh names so a and b stay available for the length
        # tiebreak below.
        for x, y in zip(a, b):
            if x != y:
                return universal_cmp(x, y)

        return cmp(len(a), len(b))

    return cmp(a, b)
-
def deeptypecmp(obj, o):
    '''Return True if obj and o have the same type structure.

    The top-level types must match exactly. Lists are compared element
    by element. Sets and dict items are first sorted with
    universal_cmp so that corresponding members are paired. Containers
    of different lengths never match. Values themselves are not
    compared; callers check equality separately.'''

    if type(obj) != type(o):
        return False

    kind = type(obj)
    if kind in (list, set, dict) and len(obj) != len(o):
        # zip() would silently truncate to the shorter container.
        return False

    if kind is list:
        return all(deeptypecmp(i, j) for i, j in zip(obj, o))

    if kind is set:
        # Pair set members by a canonical order, not hash-iteration order
        # (which can differ between equal sets).
        bykey = functools.cmp_to_key(universal_cmp)
        return all(deeptypecmp(i, j) for i, j in
            zip(sorted(obj, key=bykey), sorted(o, key=bykey)))

    if kind is dict:
        bykey = functools.cmp_to_key(universal_cmp)
        itms = sorted(obj.items(), key=bykey)
        nitms = sorted(o.items(), key=bykey)
        return all(deeptypecmp(k, nk) and deeptypecmp(v, nv)
            for (k, v), (nk, nv) in zip(itms, nitms))

    return True
-
class Test_deeptypecmp(unittest.TestCase):
    '''Checks for the deeptypecmp test helper.'''

    def test_true(self):
        same = [
            (1, 1),
            ('sldkfj', 'sldkfj'),
        ]
        for pair in same:
            with self.subTest(pair=pair):
                self.assertTrue(deeptypecmp(*pair))

    def test_false(self):
        different = [
            ([[]], [{}]),
            ([1], ['str']),
            ([], set()),
            ({1: 2, 5: 'sdlkfj'}, {1: 2, 5: b'sdlkfj'}),
            ({1: 2, 'sdlkfj': 5}, {1: 2, b'sdlkfj': 5}),
        ]
        for left, right in different:
            with self.subTest(left=left, right=right):
                self.assertFalse(deeptypecmp(left, right), '%s != %s' % (left, right))
-
def genfailures(obj):
    '''Fuzz the encoding of obj one byte at a time.

    Every byte of dumps(obj) is replaced in turn by every other byte
    value. Each mutated string must either fail to load (ValueError,
    KeyError, IndexError or TypeError), or load as something that differs
    from obj in value or type structure. Otherwise AssertionError is
    raised, naming the mutated bytes in hex.'''

    s = dumps(obj)
    for i, orig in enumerate(s):
        for j in range(256):
            if j == orig:
                # Unmodified string; nothing to detect.
                continue
            ts = s[:i] + bytes([j]) + s[i + 1:]
            try:
                o = loads(ts, consume=True)
                if o != obj or not deeptypecmp(o, obj):
                    raise ValueError
            except (ValueError, KeyError, IndexError, TypeError):
                pass
            else: # pragma: no cover
                # bytes.hex(): the Python 2 str.encode('hex') does not
                # exist on bytes and would mask the real failure.
                raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.hex(), i, orig))
-
class TestCode(unittest.TestCase):
    '''End-to-end tests of the module-level dumps/loads.'''

    def test_primv(self):
        self.assertEqual(dumps(-257), bytes.fromhex('0202feff'))
        self.assertEqual(dumps(-256), bytes.fromhex('0202ff00'))
        self.assertEqual(dumps(-255), bytes.fromhex('0202ff01'))
        self.assertEqual(dumps(-1), bytes.fromhex('0201ff'))
        self.assertEqual(dumps(5), bytes.fromhex('020105'))
        self.assertEqual(dumps(128), bytes.fromhex('02020080'))
        self.assertEqual(dumps(256), bytes.fromhex('02020100'))

        self.assertEqual(dumps(False), bytes.fromhex('010100'))
        self.assertEqual(dumps(True), bytes.fromhex('0101ff'))

        self.assertEqual(dumps(None), bytes.fromhex('0500'))

        self.assertEqual(dumps(.15625), bytes.fromhex('090380fb05'))

    def test_fuzzing(self):
        # Make sure that when a failure is detected here, that it
        # gets added to test_invalids, so that this function may be
        # disabled.
        genfailures(float(1))
        genfailures([ 1, 2, 'sdlkfj' ])
        genfailures({ 1: 2, 5: 'sdlkfj' })
        genfailures(set([ 1, 2, 'sdlkfj' ]))
        genfailures(True)
        # Must be an aware datetime. loads() returns aware UTC datetimes,
        # and an aware datetime never compares equal to a naive one, so a
        # naive value (e.g. utcnow()) would make every mutation look
        # "detected" and the fuzz would test nothing.
        genfailures(datetime.datetime.now(datetime.timezone.utc))

    def test_invalids(self):
        # Add tests for base 8, 16 floats among others
        for v in [ '010101',
                '0903040001', # float scaling factor
                '0903840001', # float scaling factor
                '0903100001', # float base
                '0903900001', # float base
                '0903000001', # float decimal encoding
                '0903830001', # float exponent encoding
                '090b827fffcc0df505d0fa58f7', # float large exponent
                '3007020101020102040673646c6b666a', # list short string still valid
                'e007020101020102020105040673646c6b666a', # dict short value still valid
                '181632303136303231353038343031362e3539303839305a', #datetime w/ trailing zero
                '181632303136303231373136343034372e3035343433367a', #datetime w/ lower z
                '181632303136313220383031303933302e3931353133385a', #datetime w/ space
                ]:
            self.assertRaises(ValueError, loads, bytes.fromhex(v))

    def test_invalid_floats(self):
        from unittest import mock
        with mock.patch('math.frexp', return_value=(.87232, 1 << 23)):
            self.assertRaises(ValueError, dumps, 1.1)

    def test_consume(self):
        b = dumps(5)
        self.assertRaises(ValueError, loads, b + b'398473',
            consume=True)

    # XXX - still possible that an internal data member
    # doesn't consume all

    # XXX - test that sets are ordered properly
    # XXX - test that dicts are ordered properly..

    def test_nan(self):
        s = dumps(float('nan'))
        v = loads(s)
        self.assertTrue(math.isnan(v))

    def test_cryptoutilasn1(self):
        '''Test DER sequences generated by Crypto.Util.asn1.'''

        for s, v in [ (b'\x02\x03$\x8a\xf9', 2394873),
                (b'\x05\x00', None),
                (b'\x02\x03\x00\x96I', 38473),
                (b'\x04\x81\xc8' + b'\x00' * 200, b'\x00' * 200),
                ]:
            self.assertEqual(loads(s), v)

    def test_longstrings(self):
        for i in (203, 65484):
            s = os.urandom(i)
            v = dumps(s)
            self.assertEqual(loads(v), s)

    def test_invaliddate(self):
        pass
        # XXX - add test to reject datetime w/ tzinfo, or that it
        # handles it properly

    def test_tzdate(self):
        dlocal = datetime.datetime.now()
        dutc = dlocal.astimezone(datetime.timezone.utc)
        dts = dutc.timestamp()

        # sanity check
        self.assertEqual(dts, dlocal.timestamp())

        # verify that the same datetime, but with different
        # tzinfo, is serialized the same way
        self.assertEqual(dumps(dlocal), dumps(dutc))

        # that when dutc is read back
        dround = loads(dumps(dutc))

        # that it represents the same time
        self.assertEqual(dround.timestamp(), dts)

        # that when dlocal is read back
        dround = loads(dumps(dlocal))

        # that it represents the same time
        self.assertEqual(dround.timestamp(), dts)

    def test_dumps(self):
        # now(timezone.utc) is the non-deprecated equivalent of
        # utcnow().replace(tzinfo=timezone.utc).
        for i in [ None,
                True, False,
                -1, 0, 1, 255, 256, -255, -256,
                23498732498723, -2398729387234,
                (1<<2383) + 23984734, (-1<<1983) + 23984723984,
                float(0), float('-0'), float('inf'), float('-inf'),
                float(1.0), float(-1.0), float('353.3487'),
                float('2.38723873e+307'), float('2.387349e-317'),
                sys.float_info.max, sys.float_info.min,
                float('.15625'),
                'weoifjwef',
                '\U0001f4a9',
                [], [ 1,2,3 ],
                {}, { 5: 10, 'adfkj': 34 },
                set(), set((1,2,3)),
                set((1,'sjlfdkj', None, float('inf'))),
                datetime.datetime.now(datetime.timezone.utc),
                [ datetime.datetime.now(datetime.timezone.utc), ' ' ],
                datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0),
                datetime.datetime.now(datetime.timezone.utc).replace(microsecond=1000),
                ]:
            s = dumps(i)
            o = loads(s)
            self.assertEqual(i, o)

        tobj = { 1: 'dflkj', 5: 'sdlkfj', 'float': 1,
            'largeint': 1<<342, 'list': [ 1, 2, 'str', 'str' ] }

        out = dumps(tobj)
        self.assertEqual(tobj, loads(out))

    def test_coerce(self):
        class Foo:
            pass

        class Bar:
            pass

        class Baz:
            pass
        def coerce(obj):
            if isinstance(obj, Foo):
                return 'list', obj.lst
            elif isinstance(obj, Baz):
                return 'bytes', obj.s

            raise TypeError('unknown type')

        ac = ASN1Coder(coerce)

        v = [1, 2, 3]
        o = Foo()
        o.lst = v

        self.assertEqual(ac.loads(ac.dumps(o)), v)
        self.assertRaises(TypeError, ac.dumps, Bar())

        v = b'oiejfd'
        o = Baz()
        o.s = v

        es = ac.dumps(o)
        self.assertEqual(ac.loads(es), v)
        self.assertIsInstance(es, bytes)

        self.assertRaises(TypeError, dumps, o)

    def test_loads(self):
        self.assertRaises(ValueError, loads, b'\x00\x02\x00')

    def test_nodict(self):
        '''Verify that ASN1Coder does not support dict.'''

        self.assertRaises(KeyError, ASN1Coder().loads, dumps({}))

    def test_dumps_default(self):
        '''Test that dumps supports the default method, and that
        it works.'''

        class Dummy(object):
            def somefun(self):
                return 5

        class Dummy2(object):
            def somefun(self):
                return [ Dummy() ]

        def deffun(obj):
            try:
                return obj.somefun()
            except Exception:
                raise TypeError

        self.assertEqual(dumps(5), dumps(Dummy(), default=deffun))

        # Make sure it works for the various containers
        self.assertEqual(dumps([5]), dumps([Dummy()], default=deffun))
        self.assertEqual(dumps({ 5: 5 }), dumps({ Dummy(): Dummy() },
            default=deffun))
        self.assertEqual(dumps([5]), dumps(Dummy2(), default=deffun))

        # Make sure that an error is raised when the function doesn't work
        self.assertRaises(TypeError, dumps, object(), default=deffun)
|