A pure Python ASN.1 library. Supports dicts and sets.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

590 lines
14 KiB

  1. #!/usr/bin/env python
  2. # A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit
  3. # of pickle. It will automatically do the correct thing if possible.
  4. #
  5. # This uses a profile of ASN.1.
  6. #
  7. # All lengths must be specified. That is, End-of-contents octets
  8. # MUST NOT be used. The shortest form of length encoding MUST be used.
  9. # A longer length encoding MUST be rejected.
  10. import datetime
  11. import math
  12. import mock
  13. import os
  14. import pdb
  15. import sys
  16. import unittest
  17. __all__ = [ 'dumps', 'loads', 'ASN1Coder' ]
  18. def _numtostr(n):
  19. hs = '%x' % n
  20. if len(hs) & 1 == 1:
  21. hs = '0' + hs
  22. bs = hs.decode('hex')
  23. return bs
  24. def _encodelen(l):
  25. '''Takes l as a length value, and returns a byte string that
  26. represents l per ASN.1 rules.'''
  27. if l < 128:
  28. return chr(l)
  29. bs = _numtostr(l)
  30. return chr(len(bs) | 0x80) + bs
  31. def _decodelen(d, pos=0):
  32. '''Returns the length, and number of bytes required.'''
  33. odp = ord(d[pos])
  34. if odp < 128:
  35. return ord(d[pos]), 1
  36. else:
  37. l = odp & 0x7f
  38. return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
  39. class Test_codelen(unittest.TestCase):
  40. _testdata = [
  41. (2, '\x02'),
  42. (127, '\x7f'),
  43. (128, '\x81\x80'),
  44. (255, '\x81\xff'),
  45. (256, '\x82\x01\x00'),
  46. (65536-1, '\x82\xff\xff'),
  47. (65536, '\x83\x01\x00\x00'),
  48. ]
  49. def test_el(self):
  50. for i, j in self._testdata:
  51. self.assertEqual(_encodelen(i), j)
  52. self.assertEqual(_decodelen(j), (i, len(j)))
  53. def _splitfloat(f):
  54. m, e = math.frexp(f)
  55. # XXX - less than ideal
  56. while m != math.trunc(m):
  57. m *= 2
  58. e -= 1
  59. return m, e
  60. class TestSplitFloat(unittest.TestCase):
  61. def test_sf(self):
  62. for a, b in [ (0x2421, -32), (0x5382f, 238),
  63. (0x1fa8c3b094adf1, 971) ]:
  64. self.assertEqual(_splitfloat(a * 2**b), (a, b))
  65. class ASN1Coder(object):
  66. '''A class that contains an PASN.1 encoder/decoder.
  67. Exports two methods, loads and dumps.'''
  68. def __init__(self, coerce=None):
  69. '''If the arg coerce is provided, when dumping the object,
  70. if the type is not found, the coerce function will be called
  71. with the obj. It is expected to return a tuple of a string
  72. and an object that has the method w/ the string as defined:
  73. 'bool': __nonzero__
  74. 'dict': iteritems
  75. 'float': compatible w/ float
  76. 'int': compatible w/ int
  77. 'list': __iter__
  78. 'set': __iter__
  79. 'bytes': __str__
  80. 'null': no method needed
  81. 'unicode': encode method returns UTF-8 encoded bytes
  82. 'datetime': strftime and microsecond
  83. '''
  84. self.coerce = coerce
  85. _typemap = {
  86. bool: 'bool',
  87. dict: 'dict',
  88. float: 'float',
  89. int: 'int',
  90. list: 'list',
  91. long: 'int',
  92. set: 'set',
  93. str: 'bytes',
  94. type(None): 'null',
  95. unicode: 'unicode',
  96. #decimal.Decimal: 'float',
  97. datetime.datetime: 'datetime',
  98. #datetime.timedelta: 'timedelta',
  99. }
  100. _tagmap = {
  101. '\x01': 'bool',
  102. '\x02': 'int',
  103. '\x04': 'bytes',
  104. '\x05': 'null',
  105. '\x09': 'float',
  106. '\x0c': 'unicode',
  107. '\x18': 'datetime',
  108. '\x30': 'list',
  109. '\x31': 'set',
  110. '\xc0': 'dict',
  111. }
  112. _typetag = dict((v, k) for k, v in _tagmap.iteritems())
  113. @staticmethod
  114. def enc_int(obj):
  115. l = obj.bit_length()
  116. l += 1 # space for sign bit
  117. l = (l + 7) // 8
  118. if obj < 0:
  119. obj += 1 << (l * 8) # twos-complement conversion
  120. v = _numtostr(obj)
  121. if len(v) != l:
  122. # XXX - is this a problem for signed values?
  123. v = '\x00' + v # add sign octect
  124. return _encodelen(l) + v
  125. @staticmethod
  126. def dec_int(d, pos, end):
  127. if pos == end:
  128. return 0, end
  129. v = int(d[pos:end].encode('hex'), 16)
  130. av = 1 << ((end - pos) * 8 - 1) # sign bit
  131. if v > av:
  132. v -= av * 2 # twos-complement conversion
  133. return v, end
  134. @staticmethod
  135. def enc_bool(obj):
  136. return '\x01' + ('\xff' if obj else '\x00')
  137. def dec_bool(self, d, pos, end):
  138. v = self.dec_int(d, pos, end)[0]
  139. if v not in (-1, 0):
  140. raise ValueError('invalid bool value: %d' % v)
  141. return bool(v), end
  142. @staticmethod
  143. def enc_null(obj):
  144. return '\x00'
  145. @staticmethod
  146. def dec_null(d, pos, end):
  147. return None, end
  148. def enc_dict(self, obj):
  149. #it = list(obj.iteritems())
  150. #it.sort()
  151. r = ''.join(self.dumps(k) + self.dumps(v) for k, v in obj.iteritems())
  152. return _encodelen(len(r)) + r
  153. def dec_dict(self, d, pos, end):
  154. r = {}
  155. vend = pos
  156. while pos < end:
  157. k, kend = self._loads(d, pos, end)
  158. #if kend > end:
  159. # raise ValueError('key past end')
  160. v, vend = self._loads(d, kend, end)
  161. if vend > end:
  162. raise ValueError('value past end')
  163. r[k] = v
  164. pos = vend
  165. return r, vend
  166. def enc_list(self, obj):
  167. r = ''.join(self.dumps(x) for x in obj)
  168. return _encodelen(len(r)) + r
  169. def dec_list(self, d, pos, end):
  170. r = []
  171. vend = pos
  172. while pos < end:
  173. v, vend = self._loads(d, pos, end)
  174. if vend > end:
  175. raise ValueError('load past end')
  176. r.append(v)
  177. pos = vend
  178. return r, vend
  179. enc_set = enc_list
  180. def dec_set(self, d, pos, end):
  181. r, end = self.dec_list(d, pos, end)
  182. return set(r), end
  183. @staticmethod
  184. def enc_bytes(obj):
  185. return _encodelen(len(obj)) + bytes(obj)
  186. @staticmethod
  187. def dec_bytes(d, pos, end):
  188. return d[pos:end], end
  189. @staticmethod
  190. def enc_unicode(obj):
  191. encobj = obj.encode('utf-8')
  192. return _encodelen(len(encobj)) + encobj
  193. def dec_unicode(self, d, pos, end):
  194. return d[pos:end].decode('utf-8'), end
  195. @staticmethod
  196. def enc_float(obj):
  197. s = math.copysign(1, obj)
  198. if math.isnan(obj):
  199. return _encodelen(1) + chr(0b01000010)
  200. elif math.isinf(obj):
  201. if s == 1:
  202. return _encodelen(1) + chr(0b01000000)
  203. else:
  204. return _encodelen(1) + chr(0b01000001)
  205. elif obj == 0:
  206. if s == 1:
  207. return _encodelen(0)
  208. else:
  209. return _encodelen(1) + chr(0b01000011)
  210. m, e = _splitfloat(obj)
  211. # Binary encoding
  212. val = 0x80
  213. if m < 0:
  214. val |= 0x40
  215. m = -m
  216. # Base 2
  217. el = (e.bit_length() + 7 + 1) // 8 # + 1 is sign bit
  218. if el > 2:
  219. raise ValueError('exponent too large')
  220. if e < 0:
  221. e += 256**el # convert negative to twos-complement
  222. v = el - 1
  223. encexp = _numtostr(e)
  224. val |= v
  225. r = chr(val) + encexp + _numtostr(m)
  226. return _encodelen(len(r)) + r
  227. def dec_float(self, d, pos, end):
  228. if pos == end:
  229. return float(0), end
  230. v = ord(d[pos])
  231. if v == 0b01000000:
  232. return float('inf'), end
  233. elif v == 0b01000001:
  234. return float('-inf'), end
  235. elif v == 0b01000010:
  236. return float('nan'), end
  237. elif v == 0b01000011:
  238. return float('-0'), end
  239. elif v & 0b110000:
  240. raise ValueError('base must be 2')
  241. elif v & 0b1100:
  242. raise ValueError('scaling factor must be 0')
  243. elif v & 0b11000000 == 0:
  244. raise ValueError('decimal encoding not supported')
  245. #elif v & 0b11000000 == 0b01000000:
  246. # raise ValueError('invalid encoding')
  247. if (v & 3) >= 2:
  248. raise ValueError('large exponents not supported')
  249. pexp = pos + 1
  250. eexp = pos + 1 + (v & 3) + 1
  251. exp = self.dec_int(d, pexp, eexp)[0]
  252. n = float(int(d[eexp:end].encode('hex'), 16))
  253. r = n * 2 ** exp
  254. if v & 0b1000000:
  255. r = -r
  256. return r, end
  257. def dumps(self, obj):
  258. '''Convert obj into a string.'''
  259. try:
  260. tf = self._typemap[type(obj)]
  261. except KeyError:
  262. if self.coerce is None:
  263. raise TypeError('unhandled object: %s' % `obj`)
  264. tf, obj = self.coerce(obj)
  265. fun = getattr(self, 'enc_%s' % tf)
  266. return self._typetag[tf] + fun(obj)
  267. def _loads(self, data, pos, end):
  268. tag = data[pos]
  269. l, b = _decodelen(data, pos + 1)
  270. if len(data) < pos + 1 + b + l:
  271. raise ValueError('string not long enough')
  272. # XXX - enforce that len(data) == end?
  273. end = pos + 1 + b + l
  274. t = self._tagmap[tag]
  275. fun = getattr(self, 'dec_%s' % t)
  276. return fun(data, pos + 1 + b, end)
  277. def enc_datetime(self, obj):
  278. ts = obj.strftime('%Y%m%d%H%M%S')
  279. if obj.microsecond:
  280. ts += ('.%06d' % obj.microsecond).rstrip('0')
  281. ts += 'Z'
  282. return _encodelen(len(ts)) + ts
  283. def dec_datetime(self, data, pos, end):
  284. ts = data[pos:end]
  285. if ts[-1] != 'Z':
  286. raise ValueError('last character must be Z')
  287. if '.' in ts:
  288. fstr = '%Y%m%d%H%M%S.%fZ'
  289. if ts.endswith('0Z'):
  290. raise ValueError('invalid trailing zeros')
  291. else:
  292. fstr = '%Y%m%d%H%M%SZ'
  293. return datetime.datetime.strptime(ts, fstr), end
  294. def loads(self, data, pos=0, end=None, consume=False):
  295. '''Load from data, starting at pos (option), and ending
  296. at end (optional). If it is required to consume the
  297. whole string (not the default), set consume to True, and
  298. a ValueError will be raised if the string is not
  299. completely consumed. The second item in ValueError will
  300. be the possition that was the detected end.'''
  301. if end is None:
  302. end = len(data)
  303. r, e = self._loads(data, pos, end)
  304. if consume and e != end:
  305. raise ValueError('entire string not consumed', e)
  306. return r
# Module level convenience interface, in the spirit of pickle: a shared
# coder created without a coerce function, whose bound methods are
# exported as dumps and loads (see __all__).
_coder = ASN1Coder()
dumps = _coder.dumps
loads = _coder.loads
  310. def deeptypecmp(obj, o):
  311. #print 'dtc:', `obj`, `o`
  312. if type(obj) != type(o):
  313. return False
  314. if type(obj) in (str, unicode):
  315. return True
  316. if type(obj) in (list, set):
  317. for i, j in zip(obj, o):
  318. if not deeptypecmp(i, j):
  319. return False
  320. if type(obj) in (dict,):
  321. itms = obj.items()
  322. itms.sort()
  323. nitms = o.items()
  324. nitms.sort()
  325. for (k, v), (nk, nv) in zip(itms, nitms):
  326. if not deeptypecmp(k, nk):
  327. return False
  328. if not deeptypecmp(v, nv):
  329. return False
  330. return True
  331. class Test_deeptypecmp(unittest.TestCase):
  332. def test_true(self):
  333. for i in ((1,1), ('sldkfj', 'sldkfj')
  334. ):
  335. self.assertTrue(deeptypecmp(*i))
  336. def test_false(self):
  337. for i in (([[]], [{}]), ([1], ['str']), ([], set()),
  338. ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
  339. ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
  340. ):
  341. self.assertFalse(deeptypecmp(*i))
  342. def genfailures(obj):
  343. s = dumps(obj)
  344. for i in xrange(len(s)):
  345. for j in (chr(x) for x in xrange(256)):
  346. ts = s[:i] + j + s[i + 1:]
  347. if ts == s:
  348. continue
  349. try:
  350. o = loads(ts, consume=True)
  351. if o != obj or not deeptypecmp(o, obj):
  352. raise ValueError
  353. except (ValueError, KeyError, IndexError, TypeError):
  354. pass
  355. except Exception:
  356. raise
  357. else:
  358. raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
# End to end tests of the module level dumps/loads and of ASN1Coder.
class TestCode(unittest.TestCase):
    # Exact encodings of a few primitive values.
    def test_primv(self):
        self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
        self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
        self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
        self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
        self.assertEqual(dumps(5), '020105'.decode('hex'))
        self.assertEqual(dumps(128), '02020080'.decode('hex'))
        self.assertEqual(dumps(256), '02020100'.decode('hex'))

        self.assertEqual(dumps(False), '010100'.decode('hex'))
        self.assertEqual(dumps(True), '0101ff'.decode('hex'))

        self.assertEqual(dumps(None), '0500'.decode('hex'))

        self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))

    # Every single byte modification of these encodings must be noticed.
    def test_fuzzing(self):
        # Make sure that when a failure is detected here, that it
        # gets added to test_invalids, so that this function may be
        # disabled.
        genfailures(float(1))
        genfailures([ 1, 2, 'sdlkfj' ])
        genfailures({ 1: 2, 5: 'sdlkfj' })
        genfailures(set([ 1, 2, 'sdlkfj' ]))
        genfailures(True)
        genfailures(datetime.datetime.utcnow())

    # Hex encoded inputs that loads must reject with ValueError.
    def test_invalids(self):
        # Add tests for base 8, 16 floats among others
        for v in [ '010101',
            '0903040001', # float scaling factor
            '0903840001', # float scaling factor
            '0903100001', # float base
            '0903900001', # float base
            '0903000001', # float decimal encoding
            '0903830001', # float exponent encoding
            '090b827fffcc0df505d0fa58f7', # float large exponent
            '3007020101020102040673646c6b666a', # list short string still valid
            'c007020101020102020105040673646c6b666a', # dict short value still valid
            '181632303136303231353038343031362e3539303839305a', #datetime w/ trailing zero
            '181632303136303231373136343034372e3035343433367a', #datetime w/ lower z
            ]:
            self.assertRaises(ValueError, loads, v.decode('hex'))

    # math.frexp is patched to return an exponent of 1 << 23, which
    # dumps must refuse with ValueError.
    def test_invalid_floats(self):
        with mock.patch('math.frexp', return_value=(.87232, 1 << 23)):
            self.assertRaises(ValueError, dumps, 1.1)

    # Trailing data after the element must be rejected when consume=True.
    def test_consume(self):
        b = dumps(5)
        self.assertRaises(ValueError, loads, b + '398473', consume=True)

        # XXX - still possible that an internal data member
        # doesn't consume all

    # XXX - test that sets are ordered properly
    # XXX - test that dicts are ordered properly..

    # NaN never compares equal, so it is checked with math.isnan.
    def test_nan(self):
        s = dumps(float('nan'))
        v = loads(s)
        self.assertTrue(math.isnan(v))

    def test_cryptoutilasn1(self):
        '''Test DER sequences generated by Crypto.Util.asn1.'''

        for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
            ('\x05\x00', None),
            ('\x02\x03\x00\x96I', 38473),
            ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
            ]:
            self.assertEqual(loads(s), v)

    # Byte strings long enough to need one and two long form length
    # octets must round trip.
    def test_longstrings(self):
        for i in (203, 65484):
            s = os.urandom(i)
            v = dumps(s)
            self.assertEqual(loads(v), s)

    # Placeholder, no checks yet.
    def test_invaliddate(self):
        pass
        # XXX - add test to reject datetime w/ tzinfo, or that it handles it
        # properly

    # Round trip of every supported type through dumps and loads.
    def test_dumps(self):
        for i in [ None,
            True, False,
            -1, 0, 1, 255, 256, -255, -256, 23498732498723, -2398729387234, (1<<2383) + 23984734, (-1<<1983) + 23984723984,
            float(0), float('-0'), float('inf'), float('-inf'), float(1.0), float(-1.0),
            float('353.3487'), float('2.38723873e+307'), float('2.387349e-317'),
            sys.float_info.max, sys.float_info.min,
            float('.15625'),
            'weoifjwef',
            u'\U0001f4a9',
            [], [ 1,2,3 ],
            {}, { 5: 10, 'adfkj': 34 },
            set(), set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
            datetime.datetime.utcnow(), datetime.datetime.utcnow().replace(microsecond=0),
            datetime.datetime.utcnow().replace(microsecond=1000),
            ]:
            s = dumps(i)
            o = loads(s)
            self.assertEqual(i, o)

        # a nested object mixing the supported types
        tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }

        out = dumps(tobj)
        self.assertEqual(tobj, loads(out))

    # Types that are not natively supported go through the coerce
    # function handed to ASN1Coder.
    def test_coerce(self):
        class Foo:
            pass

        class Bar:
            pass

        class Baz:
            pass

        def coerce(obj):
            if isinstance(obj, Foo):
                return 'list', obj.lst
            elif isinstance(obj, Baz):
                return 'bytes', obj.s

            raise TypeError('unknown type')

        ac = ASN1Coder(coerce)

        v = [1, 2, 3]
        o = Foo()
        o.lst = v

        self.assertEqual(ac.loads(ac.dumps(o)), v)
        # Bar is not handled by coerce, whose TypeError must propagate
        self.assertRaises(TypeError, ac.dumps, Bar())

        v = u'oiejfd'
        o = Baz()
        o.s = v

        es = ac.dumps(o)
        self.assertEqual(ac.loads(es), v)
        self.assertIsInstance(es, bytes)

        # the module level dumps has no coerce function
        self.assertRaises(TypeError, dumps, o)

    # Input whose declared length exceeds the data must raise ValueError.
    def test_loads(self):
        self.assertRaises(ValueError, loads, '\x00\x02\x00')