A pure Python ASN.1 library. Supports dicts and sets.
#!/usr/bin/env python

'''A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit
of pickle.

The default dumps/loads uses a profile of ASN.1 that supports serialization
of key/value pairs. This is non-standard. Instantiate the class ASN1Coder
to get a pure ASN.1 serializer/deserializer.

All lengths must be specified. That is, End-of-contents octets
MUST NOT be used. The shortest form of length encoding MUST be used.
A longer length encoding MUST be rejected.'''
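# A minimal usage sketch (the byte values come from test_primv below): the
# module-level dumps/loads round-trip basic Python values through the
# DER-style encoding, e.g. dumps(5) == '020105'.decode('hex') and
# loads(dumps(5)) == 5.  For strict ASN.1 without the dict extension,
# instantiate ASN1Coder and call its dumps/loads methods instead.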
__author__ = 'John-Mark Gurney'
__copyright__ = 'Copyright 2016 John-Mark Gurney. All rights reserved.'
__license__ = '2-clause BSD license'

# Copyright 2016, John-Mark Gurney
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are those
# of the authors and should not be interpreted as representing official policies,
# either expressed or implied, of the Project.
import datetime
import math
import mock
import os
import pdb
import sys
import unittest

__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]
def _numtostr(n):
    hs = '%x' % n
    if len(hs) & 1 == 1:
        hs = '0' + hs
    bs = hs.decode('hex')

    return bs

def _encodelen(l):
    '''Takes l as a length value, and returns a byte string that
    represents l per ASN.1 rules.'''

    if l < 128:
        return chr(l)

    bs = _numtostr(l)
    return chr(len(bs) | 0x80) + bs

def _decodelen(d, pos=0):
    '''Returns the length, and number of bytes required.'''

    odp = ord(d[pos])
    if odp < 128:
        return ord(d[pos]), 1
    else:
        l = odp & 0x7f
        return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
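# A quick illustration of the definite-length forms (these values mirror the
# Test_codelen data below): lengths under 128 use a single octet, while longer
# lengths get a length-of-length octet with the high bit set followed by the
# big-endian length, e.g. _encodelen(2) == '\x02', _encodelen(128) ==
# '\x81\x80', and _decodelen('\x82\x01\x00') == (256, 3).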
class Test_codelen(unittest.TestCase):
    _testdata = [
        (2, '\x02'),
        (127, '\x7f'),
        (128, '\x81\x80'),
        (255, '\x81\xff'),
        (256, '\x82\x01\x00'),
        (65536-1, '\x82\xff\xff'),
        (65536, '\x83\x01\x00\x00'),
    ]

    def test_el(self):
        for i, j in self._testdata:
            self.assertEqual(_encodelen(i), j)
            self.assertEqual(_decodelen(j), (i, len(j)))

def _splitfloat(f):
    m, e = math.frexp(f)
    # XXX - less than ideal
    while m != math.trunc(m):
        m *= 2
        e -= 1

    return m, e

class TestSplitFloat(unittest.TestCase):
    def test_sf(self):
        for a, b in [ (0x2421, -32), (0x5382f, 238),
                (0x1fa8c3b094adf1, 971) ]:
            self.assertEqual(_splitfloat(a * 2**b), (a, b))
class ASN1Coder(object):
    '''A class that contains a PASN.1 encoder/decoder.

    Exports two methods, loads and dumps.'''

    def __init__(self, coerce=None):
        '''If the arg coerce is provided, when dumping the object,
        if the type is not found, the coerce function will be called
        with the obj.  It is expected to return a tuple of a string
        and an object that has the method w/ the string as defined:
            'bool': __nonzero__
            'float': compatible w/ float
            'int': compatible w/ int
            'list': __iter__
            'set': __iter__
            'bytes': __str__
            'null': no method needed
            'unicode': encode method returns UTF-8 encoded bytes
            'datetime': strftime and microsecond
        '''

        self.coerce = coerce
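    # A sketch of the coerce hook (the Point class here is purely
    # illustrative; see test_coerce below for the real usage): the callback
    # maps an otherwise unhandled type onto one of the names listed above and
    # a compatible value, e.g.:
    #     def coerce(obj):
    #         if isinstance(obj, Point):
    #             return 'list', [ obj.x, obj.y ]
    #         raise TypeError('unknown type')
    #     ASN1Coder(coerce).dumps(Point(1, 2))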
    _typemap = {
        bool: 'bool',
        float: 'float',
        int: 'int',
        list: 'list',
        long: 'int',
        set: 'set',
        str: 'bytes',
        type(None): 'null',
        unicode: 'unicode',
        #decimal.Decimal: 'float',
        datetime.datetime: 'datetime',
        #datetime.timedelta: 'timedelta',
    }
    _tagmap = {
        '\x01': 'bool',
        '\x02': 'int',
        '\x04': 'bytes',
        '\x05': 'null',
        '\x09': 'float',
        '\x0c': 'unicode',
        '\x18': 'datetime',
        '\x30': 'list',
        '\x31': 'set',
    }

    _typetag = dict((v, k) for k, v in _tagmap.iteritems())
    @staticmethod
    def enc_int(obj):
        l = obj.bit_length()
        l += 1  # space for sign bit

        l = (l + 7) // 8

        if obj < 0:
            obj += 1 << (l * 8)  # twos-complement conversion

        v = _numtostr(obj)
        if len(v) != l:
            # XXX - is this a problem for signed values?
            v = '\x00' + v  # add sign octet

        return _encodelen(l) + v

    @staticmethod
    def dec_int(d, pos, end):
        if pos == end:
            return 0, end

        v = int(d[pos:end].encode('hex'), 16)
        av = 1 << ((end - pos) * 8 - 1)  # sign bit
        if v > av:
            v -= av * 2  # twos-complement conversion

        return v, end
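    # For example (these byte strings come from test_primv below), integers
    # are emitted big-endian with an explicit sign octet when needed:
    # dumps(5) == '020105'.decode('hex'), dumps(-1) == '0201ff'.decode('hex'),
    # and dumps(128) == '02020080'.decode('hex').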
    @staticmethod
    def enc_bool(obj):
        return '\x01' + ('\xff' if obj else '\x00')

    def dec_bool(self, d, pos, end):
        v = self.dec_int(d, pos, end)[0]
        if v not in (-1, 0):
            raise ValueError('invalid bool value: %d' % v)

        return bool(v), end

    @staticmethod
    def enc_null(obj):
        return '\x00'

    @staticmethod
    def dec_null(d, pos, end):
        return None, end

    def enc_list(self, obj):
        r = ''.join(self.dumps(x) for x in obj)
        return _encodelen(len(r)) + r

    def dec_list(self, d, pos, end):
        r = []
        vend = pos
        while pos < end:
            v, vend = self._loads(d, pos, end)
            if vend > end:
                raise ValueError('load past end')
            r.append(v)
            pos = vend

        return r, vend

    enc_set = enc_list

    def dec_set(self, d, pos, end):
        r, end = self.dec_list(d, pos, end)
        return set(r), end

    @staticmethod
    def enc_bytes(obj):
        return _encodelen(len(obj)) + bytes(obj)

    @staticmethod
    def dec_bytes(d, pos, end):
        return d[pos:end], end

    @staticmethod
    def enc_unicode(obj):
        encobj = obj.encode('utf-8')

        return _encodelen(len(encobj)) + encobj

    def dec_unicode(self, d, pos, end):
        return d[pos:end].decode('utf-8'), end
    @staticmethod
    def enc_float(obj):
        s = math.copysign(1, obj)
        if math.isnan(obj):
            return _encodelen(1) + chr(0b01000010)
        elif math.isinf(obj):
            if s == 1:
                return _encodelen(1) + chr(0b01000000)
            else:
                return _encodelen(1) + chr(0b01000001)
        elif obj == 0:
            if s == 1:
                return _encodelen(0)
            else:
                return _encodelen(1) + chr(0b01000011)

        m, e = _splitfloat(obj)

        # Binary encoding
        val = 0x80
        if m < 0:
            val |= 0x40
            m = -m

        # Base 2
        el = (e.bit_length() + 7 + 1) // 8  # + 1 is sign bit
        if el > 2:
            raise ValueError('exponent too large')

        if e < 0:
            e += 256**el  # convert negative to twos-complement

        v = el - 1
        encexp = _numtostr(e)

        val |= v
        r = chr(val) + encexp + _numtostr(m)
        return _encodelen(len(r)) + r
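    # As a concrete sketch (the byte string appears in test_primv below):
    # 0.15625 is 5 * 2**-5, so it encodes as '090380fb05'.decode('hex') --
    # tag 0x09, length 3, then the flags octet 0x80 (binary, base 2, one
    # exponent octet), the exponent 0xfb (-5), and the mantissa 0x05.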
    def dec_float(self, d, pos, end):
        if pos == end:
            return float(0), end

        v = ord(d[pos])
        if v == 0b01000000:
            return float('inf'), end
        elif v == 0b01000001:
            return float('-inf'), end
        elif v == 0b01000010:
            return float('nan'), end
        elif v == 0b01000011:
            return float('-0'), end
        elif v & 0b110000:
            raise ValueError('base must be 2')
        elif v & 0b1100:
            raise ValueError('scaling factor must be 0')
        elif v & 0b11000000 == 0:
            raise ValueError('decimal encoding not supported')
        #elif v & 0b11000000 == 0b01000000:
        #    raise ValueError('invalid encoding')

        if (v & 3) >= 2:
            raise ValueError('large exponents not supported')

        pexp = pos + 1
        eexp = pos + 1 + (v & 3) + 1

        exp = self.dec_int(d, pexp, eexp)[0]
        n = float(int(d[eexp:end].encode('hex'), 16))
        r = n * 2 ** exp
        if v & 0b1000000:
            r = -r

        return r, end

    def dumps(self, obj):
        '''Convert obj into an array of bytes.'''

        try:
            tf = self._typemap[type(obj)]
        except KeyError:
            if self.coerce is None:
                raise TypeError('unhandled object: %s' % `obj`)

            tf, obj = self.coerce(obj)

        fun = getattr(self, 'enc_%s' % tf)
        return self._typetag[tf] + fun(obj)

    def _loads(self, data, pos, end):
        tag = data[pos]
        l, b = _decodelen(data, pos + 1)
        if len(data) < pos + 1 + b + l:
            raise ValueError('string not long enough')

        # XXX - enforce that len(data) == end?
        end = pos + 1 + b + l

        t = self._tagmap[tag]
        fun = getattr(self, 'dec_%s' % t)
        return fun(data, pos + 1 + b, end)

    def enc_datetime(self, obj):
        ts = obj.strftime('%Y%m%d%H%M%S')
        if obj.microsecond:
            ts += ('.%06d' % obj.microsecond).rstrip('0')

        ts += 'Z'
        return _encodelen(len(ts)) + ts
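    # For instance, a datetime of 2016-02-15 08:40:16 UTC with no microseconds
    # becomes the GeneralizedTime string '20160215084016Z' (15 bytes), so the
    # encoded value is '\x18\x0f' + '20160215084016Z'.  (The date here is only
    # illustrative; the strftime format above is what defines the output.)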
    def dec_datetime(self, data, pos, end):
        ts = data[pos:end]
        if ts[-1] != 'Z':
            raise ValueError('last character must be Z')

        # Real bug is in strptime, but work around it here.
        if ' ' in data:
            raise ValueError('no spaces are allowed')

        if '.' in ts:
            fstr = '%Y%m%d%H%M%S.%fZ'
            if ts.endswith('0Z'):
                raise ValueError('invalid trailing zeros')
        else:
            fstr = '%Y%m%d%H%M%SZ'

        return datetime.datetime.strptime(ts, fstr), end

    def loads(self, data, pos=0, end=None, consume=False):
        '''Load from data, starting at pos (optional), and ending
        at end (optional).  If it is required to consume the
        whole string (not the default), set consume to True, and
        a ValueError will be raised if the string is not
        completely consumed.  The second item in ValueError will
        be the position that was the detected end.'''

        if end is None:
            end = len(data)

        r, e = self._loads(data, pos, end)

        if consume and e != end:
            raise ValueError('entire string not consumed', e)

        return r
class ASN1DictCoder(ASN1Coder):
    '''This adds support for the non-standard dict serialization.

    The coerce method also supports the following type:
        'dict': iteritems
    '''

    _typemap = ASN1Coder._typemap.copy()
    _typemap[dict] = 'dict'
    _tagmap = ASN1Coder._tagmap.copy()
    _tagmap['\xe0'] = 'dict'
    _typetag = dict((v, k) for k, v in _tagmap.iteritems())

    def enc_dict(self, obj):
        #it = list(obj.iteritems())
        #it.sort()
        r = ''.join(self.dumps(k) + self.dumps(v) for k, v in
            obj.iteritems())

        return _encodelen(len(r)) + r

    def dec_dict(self, d, pos, end):
        r = {}
        vend = pos
        while pos < end:
            k, kend = self._loads(d, pos, end)
            #if kend > end:
            #    raise ValueError('key past end')
            v, vend = self._loads(d, kend, end)
            if vend > end:
                raise ValueError('value past end')

            r[k] = v
            pos = vend

        return r, vend

_coder = ASN1DictCoder()
dumps = _coder.dumps
loads = _coder.loads
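# A short sketch of the non-standard dict profile: each key and value is just
# concatenated dumps() output wrapped in the private-class tag '\xe0', so
# dumps({}) == '\xe0\x00' and dumps({1: 2}) ==
# '\xe0\x06\x02\x01\x01\x02\x01\x02'.  The plain ASN1Coder does not know this
# tag and rejects it (see test_nodict below).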
def deeptypecmp(obj, o):
    #print 'dtc:', `obj`, `o`
    if type(obj) != type(o):
        return False

    if type(obj) in (str, unicode):
        return True

    if type(obj) in (list, set):
        for i, j in zip(obj, o):
            if not deeptypecmp(i, j):
                return False

    if type(obj) in (dict,):
        itms = obj.items()
        itms.sort()
        nitms = o.items()
        nitms.sort()
        for (k, v), (nk, nv) in zip(itms, nitms):
            if not deeptypecmp(k, nk):
                return False
            if not deeptypecmp(v, nv):
                return False

    return True

class Test_deeptypecmp(unittest.TestCase):
    def test_true(self):
        for i in ((1,1), ('sldkfj', 'sldkfj')
                ):
            self.assertTrue(deeptypecmp(*i))

    def test_false(self):
        for i in (([[]], [{}]), ([1], ['str']), ([], set()),
                ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
                ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
                ):
            self.assertFalse(deeptypecmp(*i))

def genfailures(obj):
    s = dumps(obj)
    for i in xrange(len(s)):
        for j in (chr(x) for x in xrange(256)):
            ts = s[:i] + j + s[i + 1:]
            if ts == s:
                continue
            try:
                o = loads(ts, consume=True)
                if o != obj or not deeptypecmp(o, obj):
                    raise ValueError
            except (ValueError, KeyError, IndexError, TypeError):
                pass
            else:  # pragma: no cover
                raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
class TestCode(unittest.TestCase):
    def test_primv(self):
        self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
        self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
        self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
        self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
        self.assertEqual(dumps(5), '020105'.decode('hex'))
        self.assertEqual(dumps(128), '02020080'.decode('hex'))
        self.assertEqual(dumps(256), '02020100'.decode('hex'))

        self.assertEqual(dumps(False), '010100'.decode('hex'))
        self.assertEqual(dumps(True), '0101ff'.decode('hex'))

        self.assertEqual(dumps(None), '0500'.decode('hex'))

        self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))

    def test_fuzzing(self):
        # Make sure that when a failure is detected here, that it
        # gets added to test_invalids, so that this function may be
        # disabled.
        genfailures(float(1))
        genfailures([ 1, 2, 'sdlkfj' ])
        genfailures({ 1: 2, 5: 'sdlkfj' })
        genfailures(set([ 1, 2, 'sdlkfj' ]))
        genfailures(True)
        genfailures(datetime.datetime.utcnow())

    def test_invalids(self):
        # Add tests for base 8, 16 floats among others
        for v in [ '010101',
            '0903040001',  # float scaling factor
            '0903840001',  # float scaling factor
            '0903100001',  # float base
            '0903900001',  # float base
            '0903000001',  # float decimal encoding
            '0903830001',  # float exponent encoding
            '090b827fffcc0df505d0fa58f7',  # float large exponent
            '3007020101020102040673646c6b666a',  # list short string still valid
            'e007020101020102020105040673646c6b666a',  # dict short value still valid
            '181632303136303231353038343031362e3539303839305a',  # datetime w/ trailing zero
            '181632303136303231373136343034372e3035343433367a',  # datetime w/ lower z
            '181632303136313220383031303933302e3931353133385a',  # datetime w/ space
            ]:
            self.assertRaises(ValueError, loads, v.decode('hex'))
    def test_invalid_floats(self):
        with mock.patch('math.frexp', return_value=(.87232, 1 << 23)):
            self.assertRaises(ValueError, dumps, 1.1)

    def test_consume(self):
        b = dumps(5)
        self.assertRaises(ValueError, loads, b + '398473',
            consume=True)

        # XXX - still possible that an internal data member
        # doesn't consume all

    # XXX - test that sets are ordered properly
    # XXX - test that dicts are ordered properly..

    def test_nan(self):
        s = dumps(float('nan'))
        v = loads(s)
        self.assertTrue(math.isnan(v))

    def test_cryptoutilasn1(self):
        '''Test DER sequences generated by Crypto.Util.asn1.'''

        for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
                ('\x05\x00', None),
                ('\x02\x03\x00\x96I', 38473),
                ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
                ]:
            self.assertEqual(loads(s), v)

    def test_longstrings(self):
        for i in (203, 65484):
            s = os.urandom(i)
            v = dumps(s)
            self.assertEqual(loads(v), s)

    def test_invaliddate(self):
        pass
        # XXX - add test to reject datetime w/ tzinfo, or that it
        # handles it properly
    def test_dumps(self):
        for i in [ None,
            True, False,
            -1, 0, 1, 255, 256, -255, -256,
            23498732498723, -2398729387234,
            (1<<2383) + 23984734, (-1<<1983) + 23984723984,
            float(0), float('-0'), float('inf'), float('-inf'),
            float(1.0), float(-1.0), float('353.3487'),
            float('2.38723873e+307'), float('2.387349e-317'),
            sys.float_info.max, sys.float_info.min,
            float('.15625'),
            'weoifjwef',
            u'\U0001f4a9',
            [], [ 1,2,3 ],
            {}, { 5: 10, 'adfkj': 34 },
            set(), set((1,2,3)),
            set((1,'sjlfdkj', None, float('inf'))),
            datetime.datetime.utcnow(),
            datetime.datetime.utcnow().replace(microsecond=0),
            datetime.datetime.utcnow().replace(microsecond=1000),
            ]:
            s = dumps(i)
            o = loads(s)
            self.assertEqual(i, o)

        tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1,
            'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
        out = dumps(tobj)
        self.assertEqual(tobj, loads(out))
    def test_coerce(self):
        class Foo:
            pass

        class Bar:
            pass

        class Baz:
            pass

        def coerce(obj):
            if isinstance(obj, Foo):
                return 'list', obj.lst
            elif isinstance(obj, Baz):
                return 'bytes', obj.s

            raise TypeError('unknown type')

        ac = ASN1Coder(coerce)

        v = [1, 2, 3]
        o = Foo()
        o.lst = v
        self.assertEqual(ac.loads(ac.dumps(o)), v)
        self.assertRaises(TypeError, ac.dumps, Bar())

        v = u'oiejfd'
        o = Baz()
        o.s = v
        es = ac.dumps(o)
        self.assertEqual(ac.loads(es), v)
        self.assertIsInstance(es, bytes)

        self.assertRaises(TypeError, dumps, o)

    def test_loads(self):
        self.assertRaises(ValueError, loads, '\x00\x02\x00')

    def test_nodict(self):
        '''Verify that ASN1Coder does not support dict.'''

        self.assertRaises(KeyError, ASN1Coder().loads, dumps({}))