A pure Python ASN.1 library. Supports dicts and sets.
#!/usr/bin/env python
'''A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit
of pickle.

The default dumps/loads uses a profile of ASN.1 that supports serialization
of key/value pairs. This is non-standard. Instantiate the class ASN1Coder
to get a pure ASN.1 serializer/deserializer.

All lengths must be specified. That is, End-of-contents octets
MUST NOT be used. The shortest form of length encoding MUST be used.
A longer length encoding MUST be rejected.'''

__author__ = 'John-Mark Gurney'
__copyright__ = 'Copyright 2016-2019 John-Mark Gurney. All rights reserved.'
__license__ = '2-clause BSD license'

# Copyright 2016-2019, John-Mark Gurney
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are those
# of the authors and should not be interpreted as representing official policies,
# either expressed or implied, of the Project.

import datetime
import math
import os
import sys
import unittest

__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]
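# Usage sketch (added for illustration, not part of the original module).
# The module-level dumps/loads use the non-standard dict extension provided
# by ASN1DictCoder below:
#
#	buf = dumps({ 1: 'one', u'two': 2.0, 'three': [ None, True ] })
#	loads(buf) == { 1: 'one', u'two': 2.0, 'three': [ None, True ] }	# True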
def _numtostr(n):
	'''Convert the non-negative integer n into a big-endian byte string.'''

	hs = '%x' % n
	if len(hs) & 1 == 1:
		hs = '0' + hs
	bs = hs.decode('hex')

	return bs

def _encodelen(l):
	'''Takes l as a length value, and returns a byte string that
	represents l per ASN.1 rules.'''

	if l < 128:
		return chr(l)

	bs = _numtostr(l)
	return chr(len(bs) | 0x80) + bs

def _decodelen(d, pos=0):
	'''Returns the length, and number of bytes required.'''

	odp = ord(d[pos])
	if odp < 128:
		return ord(d[pos]), 1
	else:
		l = odp & 0x7f
		return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
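# Worked length-octet example for _encodelen/_decodelen above (illustration
# only): lengths below 128 use the short form, a single octet; longer lengths
# use the long form, where the first octet is 0x80 | <number of length octets>.
#
#	_encodelen(5)			# -> '\x05'
#	_encodelen(300)			# -> '\x82\x01\x2c'
#	_decodelen('\x82\x01\x2c')	# -> (300, 3)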
class Test_codelen(unittest.TestCase):
	_testdata = [
		(2, '\x02'),
		(127, '\x7f'),
		(128, '\x81\x80'),
		(255, '\x81\xff'),
		(256, '\x82\x01\x00'),
		(65536 - 1, '\x82\xff\xff'),
		(65536, '\x83\x01\x00\x00'),
	]

	def test_el(self):
		for i, j in self._testdata:
			self.assertEqual(_encodelen(i), j)
			self.assertEqual(_decodelen(j), (i, len(j)))

def _splitfloat(f):
	'''Split the float f into an integral mantissa m and a base-2
	exponent e such that f == m * 2**e.'''

	m, e = math.frexp(f)

	# XXX - less than ideal
	while m != math.trunc(m):
		m *= 2
		e -= 1

	return m, e

class TestSplitFloat(unittest.TestCase):
	def test_sf(self):
		for a, b in [ (0x2421, -32), (0x5382f, 238),
		    (0x1fa8c3b094adf1, 971) ]:
			self.assertEqual(_splitfloat(a * 2**b), (a, b))

class ASN1Coder(object):
	'''A class that contains a PASN.1 encoder/decoder.

	Exports two methods, loads and dumps.'''

	def __init__(self, coerce=None):
		'''If the arg coerce is provided, when dumping the object,
		if the type is not found, the coerce function will be called
		with the obj. It is expected to return a tuple of a string
		and an object that has the method w/ the string as defined:
			'bool': __nonzero__
			'float': compatible w/ float
			'int': compatible w/ int
			'list': __iter__
			'set': __iter__
			'bytes': __str__
			'null': no method needed
			'unicode': encode method returns UTF-8 encoded bytes
			'datetime': strftime and microsecond
		'''

		self.coerce = coerce
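	# Illustrative coerce sketch (Point and coerce_point are hypothetical
	# names, not part of this module):
	#
	#	def coerce_point(obj):
	#		if isinstance(obj, Point):
	#			return 'list', [ obj.x, obj.y ]
	#		raise TypeError('unknown type')
	#
	#	coder = ASN1Coder(coerce_point)
	#	coder.loads(coder.dumps(Point(1, 2)))	# -> [ 1, 2 ]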
	_typemap = {
		bool: 'bool',
		float: 'float',
		int: 'int',
		list: 'list',
		long: 'int',
		set: 'set',
		str: 'bytes',
		type(None): 'null',
		unicode: 'unicode',
		#decimal.Decimal: 'float',
		datetime.datetime: 'datetime',
		#datetime.timedelta: 'timedelta',
	}

	_tagmap = {
		'\x01': 'bool',
		'\x02': 'int',
		'\x04': 'bytes',
		'\x05': 'null',
		'\x09': 'float',
		'\x0c': 'unicode',
		'\x18': 'datetime',
		'\x30': 'list',
		'\x31': 'set',
	}

	_typetag = dict((v, k) for k, v in _tagmap.iteritems())

	@staticmethod
	def enc_int(obj, **kwargs):
		l = obj.bit_length()
		l += 1	# space for sign bit

		l = (l + 7) // 8

		if obj < 0:
			obj += 1 << (l * 8)	# twos-complement conversion

		v = _numtostr(obj)
		if len(v) != l:
			# XXX - is this a problem for signed values?
			v = '\x00' + v	# add sign octet

		return _encodelen(l) + v

	@staticmethod
	def dec_int(d, pos, end):
		if pos == end:
			return 0, end

		v = int(d[pos:end].encode('hex'), 16)
		av = 1 << ((end - pos) * 8 - 1)	# sign bit
		if v >= av:	# sign bit set, value is negative
			v -= av * 2	# twos-complement conversion

		return v, end
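	# Worked integer example for enc_int/dec_int above (illustration only):
	# values are stored two's-complement, so a leading zero octet is added
	# when the high bit of the magnitude would otherwise read as a sign bit.
	#
	#	ASN1Coder.enc_int(5)	# -> '\x01\x05'
	#	ASN1Coder.enc_int(128)	# -> '\x02\x00\x80'
	#	ASN1Coder.enc_int(-1)	# -> '\x01\xff'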
	@staticmethod
	def enc_bool(obj, **kwargs):
		return '\x01' + ('\xff' if obj else '\x00')

	def dec_bool(self, d, pos, end):
		v = self.dec_int(d, pos, end)[0]
		if v not in (-1, 0):
			raise ValueError('invalid bool value: %d' % v)

		return bool(v), end

	@staticmethod
	def enc_null(obj, **kwargs):
		return '\x00'

	@staticmethod
	def dec_null(d, pos, end):
		return None, end

	def enc_list(self, obj, **kwargs):
		r = ''.join(self.dumps(x, **kwargs) for x in obj)

		return _encodelen(len(r)) + r

	def dec_list(self, d, pos, end):
		r = []
		vend = pos
		while pos < end:
			v, vend = self._loads(d, pos, end)
			if vend > end:
				raise ValueError('load past end')

			r.append(v)
			pos = vend

		return r, vend

	enc_set = enc_list

	def dec_set(self, d, pos, end):
		r, end = self.dec_list(d, pos, end)
		return set(r), end

	@staticmethod
	def enc_bytes(obj, **kwargs):
		return _encodelen(len(obj)) + bytes(obj)

	@staticmethod
	def dec_bytes(d, pos, end):
		return d[pos:end], end

	@staticmethod
	def enc_unicode(obj, **kwargs):
		encobj = obj.encode('utf-8')

		return _encodelen(len(encobj)) + encobj

	def dec_unicode(self, d, pos, end):
		return d[pos:end].decode('utf-8'), end

	@staticmethod
	def enc_float(obj, **kwargs):
		s = math.copysign(1, obj)
		if math.isnan(obj):
			return _encodelen(1) + chr(0b01000010)
		elif math.isinf(obj):
			if s == 1:
				return _encodelen(1) + chr(0b01000000)
			else:
				return _encodelen(1) + chr(0b01000001)
		elif obj == 0:
			if s == 1:
				return _encodelen(0)
			else:
				return _encodelen(1) + chr(0b01000011)

		m, e = _splitfloat(obj)

		# Binary encoding
		val = 0x80
		if m < 0:
			val |= 0x40
			m = -m

		# Base 2
		el = (e.bit_length() + 7 + 1) // 8	# + 1 is sign bit
		if el > 2:
			raise ValueError('exponent too large')

		if e < 0:
			e += 256**el	# convert negative to twos-complement

		v = el - 1
		encexp = _numtostr(e)

		val |= v
		r = chr(val) + encexp + _numtostr(m)
		return _encodelen(len(r)) + r

	def dec_float(self, d, pos, end):
		if pos == end:
			return float(0), end

		v = ord(d[pos])
		if v == 0b01000000:
			return float('inf'), end
		elif v == 0b01000001:
			return float('-inf'), end
		elif v == 0b01000010:
			return float('nan'), end
		elif v == 0b01000011:
			return float('-0'), end
		elif v & 0b110000:
			raise ValueError('base must be 2')
		elif v & 0b1100:
			raise ValueError('scaling factor must be 0')
		elif v & 0b11000000 == 0:
			raise ValueError('decimal encoding not supported')
		#elif v & 0b11000000 == 0b01000000:
		#	raise ValueError('invalid encoding')

		if (v & 3) >= 2:
			raise ValueError('large exponents not supported')

		pexp = pos + 1
		eexp = pos + 1 + (v & 3) + 1

		exp = self.dec_int(d, pexp, eexp)[0]
		n = float(int(d[eexp:end].encode('hex'), 16))
		r = n * 2 ** exp
		if v & 0b1000000:
			r = -r

		return r, end
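	# Worked REAL example for enc_float above (illustration only):
	# 0.15625 == 5 * 2**-5, so the mantissa is 5 and the exponent is -5.
	# The header octet 0x80 selects binary encoding, base 2, positive
	# mantissa, one exponent octet; the exponent -5 is stored
	# two's-complement as 0xfb.
	#
	#	ASN1Coder.enc_float(0.15625)	# -> '\x03\x80\xfb\x05'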
	def dumps(self, obj, default=None):
		'''Convert obj into an array of bytes.

		``default(obj)`` is a function that should return a
		serializable version of obj or raise TypeError. The
		default simply raises TypeError.
		'''

		try:
			tf = self._typemap[type(obj)]
		except KeyError:
			if default is not None:
				try:
					return self.dumps(default(obj), default=default)
				except TypeError:
					pass

			if self.coerce is None:
				raise TypeError('unhandled object: %s' % `obj`)

			tf, obj = self.coerce(obj)

		fun = getattr(self, 'enc_%s' % tf)
		return self._typetag[tf] + fun(obj, default=default)
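	# Sketch of the ``default`` hook used by dumps above (Celsius and
	# to_serializable are made-up names for illustration):
	#
	#	def to_serializable(obj):
	#		if isinstance(obj, Celsius):
	#			return obj.degrees
	#		raise TypeError
	#
	#	dumps(Celsius(21.5), default=to_serializable)	# encodes 21.5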
	def _loads(self, data, pos, end):
		tag = data[pos]
		l, b = _decodelen(data, pos + 1)
		if len(data) < pos + 1 + b + l:
			raise ValueError('string not long enough')

		# XXX - enforce that len(data) == end?
		end = pos + 1 + b + l
		t = self._tagmap[tag]
		fun = getattr(self, 'dec_%s' % t)
		return fun(data, pos + 1 + b, end)

	def enc_datetime(self, obj, **kwargs):
		ts = obj.strftime('%Y%m%d%H%M%S')
		if obj.microsecond:
			ts += ('.%06d' % obj.microsecond).rstrip('0')

		ts += 'Z'
		return _encodelen(len(ts)) + ts

	def dec_datetime(self, data, pos, end):
		ts = data[pos:end]
		if ts[-1] != 'Z':
			raise ValueError('last character must be Z')

		# Real bug is in strptime, but work around it here.
		if ' ' in ts:
			raise ValueError('no spaces are allowed')

		if '.' in ts:
			fstr = '%Y%m%d%H%M%S.%fZ'
			if ts.endswith('0Z'):
				raise ValueError('invalid trailing zeros')
		else:
			fstr = '%Y%m%d%H%M%SZ'

		return datetime.datetime.strptime(ts, fstr), end
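	# Worked datetime example for enc_datetime above (illustration only):
	# timestamps are encoded as UTC GeneralizedTime strings with a trailing
	# 'Z' and no timezone offset.
	#
	#	ASN1Coder().enc_datetime(
	#	    datetime.datetime(2016, 2, 15, 8, 40, 16))
	#	# -> '\x0f20160215084016Z'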
	def loads(self, data, pos=0, end=None, consume=False):
		'''Load from data, starting at pos (optional), and ending
		at end (optional). If it is required to consume the
		whole string (not the default), set consume to True, and
		a ValueError will be raised if the string is not
		completely consumed. The second item in the ValueError will
		be the position that was the detected end.'''

		if end is None:
			end = len(data)

		r, e = self._loads(data, pos, end)

		if consume and e != end:
			raise ValueError('entire string not consumed', e)

		return r
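	# Consume sketch for loads above: trailing garbage after a valid value
	# is ignored unless consume=True is passed.
	#
	#	loads(dumps(5) + 'junk')		# -> 5
	#	loads(dumps(5) + 'junk', consume=True)	# raises ValueError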
class ASN1DictCoder(ASN1Coder):
	'''This adds support for the non-standard dict serialization.

	The coerce method also supports the following type:
		'dict': iteritems
	'''

	_typemap = ASN1Coder._typemap.copy()
	_typemap[dict] = 'dict'
	_tagmap = ASN1Coder._tagmap.copy()
	_tagmap['\xe0'] = 'dict'
	_typetag = dict((v, k) for k, v in _tagmap.iteritems())

	def enc_dict(self, obj, **kwargs):
		#it = list(obj.iteritems())
		#it.sort()
		r = ''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
		    obj.iteritems())

		return _encodelen(len(r)) + r

	def dec_dict(self, d, pos, end):
		r = {}
		vend = pos
		while pos < end:
			k, kend = self._loads(d, pos, end)
			#if kend > end:
			#	raise ValueError('key past end')
			v, vend = self._loads(d, kend, end)
			if vend > end:
				raise ValueError('value past end')

			r[k] = v
			pos = vend

		return r, vend

_coder = ASN1DictCoder()
dumps = _coder.dumps
loads = _coder.loads
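# Illustration: the module-level dumps/loads accept dicts via ASN1DictCoder,
# while a plain ASN1Coder rejects them.
#
#	loads(dumps({ 1: 'one' }))	# -> { 1: 'one' }
#	ASN1Coder().dumps({ 1: 'one' })	# raises TypeError (unhandled object)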
def deeptypecmp(obj, o):
	#print 'dtc:', `obj`, `o`
	if type(obj) != type(o):
		return False

	if type(obj) in (str, unicode):
		return True

	if type(obj) in (list, set):
		for i, j in zip(obj, o):
			if not deeptypecmp(i, j):
				return False

	if type(obj) in (dict,):
		itms = obj.items()
		itms.sort()
		nitms = o.items()
		nitms.sort()
		for (k, v), (nk, nv) in zip(itms, nitms):
			if not deeptypecmp(k, nk):
				return False
			if not deeptypecmp(v, nv):
				return False

	return True

class Test_deeptypecmp(unittest.TestCase):
	def test_true(self):
		for i in ((1, 1), ('sldkfj', 'sldkfj')
		    ):
			self.assertTrue(deeptypecmp(*i))

	def test_false(self):
		for i in (([[]], [{}]), ([1], ['str']), ([], set()),
		    ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
		    ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
		    ):
			self.assertFalse(deeptypecmp(*i))

def genfailures(obj):
	s = dumps(obj)
	for i in xrange(len(s)):
		for j in (chr(x) for x in xrange(256)):
			ts = s[:i] + j + s[i + 1:]
			if ts == s:
				continue
			try:
				o = loads(ts, consume=True)
				if o != obj or not deeptypecmp(o, obj):
					raise ValueError
			except (ValueError, KeyError, IndexError, TypeError):
				pass
			else:	# pragma: no cover
				raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))

class TestCode(unittest.TestCase):
	def test_primv(self):
		self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
		self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
		self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
		self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
		self.assertEqual(dumps(5), '020105'.decode('hex'))
		self.assertEqual(dumps(128), '02020080'.decode('hex'))
		self.assertEqual(dumps(256), '02020100'.decode('hex'))

		self.assertEqual(dumps(False), '010100'.decode('hex'))
		self.assertEqual(dumps(True), '0101ff'.decode('hex'))

		self.assertEqual(dumps(None), '0500'.decode('hex'))

		self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))

	def test_fuzzing(self):
		# Make sure that when a failure is detected here, that it
		# gets added to test_invalids, so that this function may be
		# disabled.
		genfailures(float(1))
		genfailures([ 1, 2, 'sdlkfj' ])
		genfailures({ 1: 2, 5: 'sdlkfj' })
		genfailures(set([ 1, 2, 'sdlkfj' ]))
		genfailures(True)
		genfailures(datetime.datetime.utcnow())

	def test_invalids(self):
		# Add tests for base 8, 16 floats among others
		for v in [ '010101',
		    '0903040001',	# float scaling factor
		    '0903840001',	# float scaling factor
		    '0903100001',	# float base
		    '0903900001',	# float base
		    '0903000001',	# float decimal encoding
		    '0903830001',	# float exponent encoding
		    '090b827fffcc0df505d0fa58f7',	# float large exponent
		    '3007020101020102040673646c6b666a',	# list short string still valid
		    'e007020101020102020105040673646c6b666a',	# dict short value still valid
		    '181632303136303231353038343031362e3539303839305a',	# datetime w/ trailing zero
		    '181632303136303231373136343034372e3035343433367a',	# datetime w/ lower z
		    '181632303136313220383031303933302e3931353133385a',	# datetime w/ space
		    ]:
			self.assertRaises(ValueError, loads, v.decode('hex'))

	def test_invalid_floats(self):
		import mock

		with mock.patch('math.frexp', return_value=(.87232, 1 << 23)):
			self.assertRaises(ValueError, dumps, 1.1)

	def test_consume(self):
		b = dumps(5)
		self.assertRaises(ValueError, loads, b + '398473',
		    consume=True)

		# XXX - still possible that an internal data member
		# doesn't consume all

	# XXX - test that sets are ordered properly
	# XXX - test that dicts are ordered properly..

	def test_nan(self):
		s = dumps(float('nan'))
		v = loads(s)
		self.assertTrue(math.isnan(v))

	def test_cryptoutilasn1(self):
		'''Test DER sequences generated by Crypto.Util.asn1.'''

		for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
		    ('\x05\x00', None),
		    ('\x02\x03\x00\x96I', 38473),
		    ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
		    ]:
			self.assertEqual(loads(s), v)

	def test_longstrings(self):
		for i in (203, 65484):
			s = os.urandom(i)
			v = dumps(s)
			self.assertEqual(loads(v), s)

	def test_invaliddate(self):
		pass
		# XXX - add test to reject datetime w/ tzinfo, or that it
		# handles it properly

	def test_dumps(self):
		for i in [ None,
		    True, False,
		    -1, 0, 1, 255, 256, -255, -256,
		    23498732498723, -2398729387234,
		    (1<<2383) + 23984734, (-1<<1983) + 23984723984,
		    float(0), float('-0'), float('inf'), float('-inf'),
		    float(1.0), float(-1.0), float('353.3487'),
		    float('2.38723873e+307'), float('2.387349e-317'),
		    sys.float_info.max, sys.float_info.min,
		    float('.15625'),
		    'weoifjwef',
		    u'\U0001f4a9',
		    [], [ 1, 2, 3 ],
		    {}, { 5: 10, 'adfkj': 34 },
		    set(), set((1, 2, 3)),
		    set((1, 'sjlfdkj', None, float('inf'))),
		    datetime.datetime.utcnow(),
		    [ datetime.datetime.utcnow(), ' ' ],
		    datetime.datetime.utcnow().replace(microsecond=0),
		    datetime.datetime.utcnow().replace(microsecond=1000),
		    ]:
			s = dumps(i)
			o = loads(s)
			self.assertEqual(i, o)

		tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1,
		    'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
		out = dumps(tobj)
		self.assertEqual(tobj, loads(out))

	def test_coerce(self):
		class Foo:
			pass

		class Bar:
			pass

		class Baz:
			pass

		def coerce(obj):
			if isinstance(obj, Foo):
				return 'list', obj.lst
			elif isinstance(obj, Baz):
				return 'bytes', obj.s

			raise TypeError('unknown type')

		ac = ASN1Coder(coerce)

		v = [1, 2, 3]
		o = Foo()
		o.lst = v
		self.assertEqual(ac.loads(ac.dumps(o)), v)
		self.assertRaises(TypeError, ac.dumps, Bar())

		v = u'oiejfd'
		o = Baz()
		o.s = v
		es = ac.dumps(o)
		self.assertEqual(ac.loads(es), v)
		self.assertIsInstance(es, bytes)

		self.assertRaises(TypeError, dumps, o)

	def test_loads(self):
		self.assertRaises(ValueError, loads, '\x00\x02\x00')

	def test_nodict(self):
		'''Verify that ASN1Coder does not support dict.'''

		self.assertRaises(KeyError, ASN1Coder().loads, dumps({}))

	def test_dumps_default(self):
		'''Test that dumps supports the default method, and that
		it works.'''

		class Dummy(object):
			def somefun(self):
				return 5

		class Dummy2(object):
			def somefun(self):
				return [ Dummy() ]

		def deffun(obj):
			try:
				return obj.somefun()
			except Exception:
				raise TypeError

		self.assertEqual(dumps(5), dumps(Dummy(), default=deffun))

		# Make sure it works for the various containers
		self.assertEqual(dumps([5]), dumps([Dummy()], default=deffun))
		self.assertEqual(dumps({ 5: 5 }), dumps({ Dummy(): Dummy() },
		    default=deffun))
		self.assertEqual(dumps([5]), dumps(Dummy2(), default=deffun))

		# Make sure that an error is raised when the function doesn't work
		self.assertRaises(TypeError, dumps, object(), default=deffun)