A pure Python ASN.1 library. Supports dicts and sets.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

596 lines
14 KiB

  1. #!/usr/bin/env python
# A pure Python ASN.1 encoder/decoder with a calling interface in the spirit
# of pickle.
#
# It uses a profile of ASN.1.
#
# All lengths must be specified.  That is, End-of-contents octets MUST NOT
# be used.  The shortest form of length encoding MUST be used; a longer
# length encoding MUST be rejected.
import binascii
import datetime
import math
import os
import pdb
import sys
import unittest

import mock
# Names exported by ``from <module> import *``: the two module-level
# functions and the coder class.
__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]
  18. def _numtostr(n):
  19. hs = '%x' % n
  20. if len(hs) & 1 == 1:
  21. hs = '0' + hs
  22. bs = hs.decode('hex')
  23. return bs
  24. def _encodelen(l):
  25. '''Takes l as a length value, and returns a byte string that
  26. represents l per ASN.1 rules.'''
  27. if l < 128:
  28. return chr(l)
  29. bs = _numtostr(l)
  30. return chr(len(bs) | 0x80) + bs
  31. def _decodelen(d, pos=0):
  32. '''Returns the length, and number of bytes required.'''
  33. odp = ord(d[pos])
  34. if odp < 128:
  35. return ord(d[pos]), 1
  36. else:
  37. l = odp & 0x7f
  38. return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
  39. class Test_codelen(unittest.TestCase):
  40. _testdata = [
  41. (2, '\x02'),
  42. (127, '\x7f'),
  43. (128, '\x81\x80'),
  44. (255, '\x81\xff'),
  45. (256, '\x82\x01\x00'),
  46. (65536-1, '\x82\xff\xff'),
  47. (65536, '\x83\x01\x00\x00'),
  48. ]
  49. def test_el(self):
  50. for i, j in self._testdata:
  51. self.assertEqual(_encodelen(i), j)
  52. self.assertEqual(_decodelen(j), (i, len(j)))
  53. def _splitfloat(f):
  54. m, e = math.frexp(f)
  55. # XXX - less than ideal
  56. while m != math.trunc(m):
  57. m *= 2
  58. e -= 1
  59. return m, e
  60. class TestSplitFloat(unittest.TestCase):
  61. def test_sf(self):
  62. for a, b in [ (0x2421, -32), (0x5382f, 238),
  63. (0x1fa8c3b094adf1, 971) ]:
  64. self.assertEqual(_splitfloat(a * 2**b), (a, b))
  65. class ASN1Coder(object):
  66. '''A class that contains an PASN.1 encoder/decoder.
  67. Exports two methods, loads and dumps.'''
  68. def __init__(self, coerce=None):
  69. '''If the arg coerce is provided, when dumping the object,
  70. if the type is not found, the coerce function will be called
  71. with the obj. It is expected to return a tuple of a string
  72. and an object that has the method w/ the string as defined:
  73. 'bool': __nonzero__
  74. 'dict': iteritems
  75. 'float': compatible w/ float
  76. 'int': compatible w/ int
  77. 'list': __iter__
  78. 'set': __iter__
  79. 'bytes': __str__
  80. 'null': no method needed
  81. 'unicode': encode method returns UTF-8 encoded bytes
  82. 'datetime': strftime and microsecond
  83. '''
  84. self.coerce = coerce
  85. _typemap = {
  86. bool: 'bool',
  87. dict: 'dict',
  88. float: 'float',
  89. int: 'int',
  90. list: 'list',
  91. long: 'int',
  92. set: 'set',
  93. str: 'bytes',
  94. type(None): 'null',
  95. unicode: 'unicode',
  96. #decimal.Decimal: 'float',
  97. datetime.datetime: 'datetime',
  98. #datetime.timedelta: 'timedelta',
  99. }
  100. _tagmap = {
  101. '\x01': 'bool',
  102. '\x02': 'int',
  103. '\x04': 'bytes',
  104. '\x05': 'null',
  105. '\x09': 'float',
  106. '\x0c': 'unicode',
  107. '\x18': 'datetime',
  108. '\x30': 'list',
  109. '\x31': 'set',
  110. '\xe0': 'dict',
  111. }
  112. _typetag = dict((v, k) for k, v in _tagmap.iteritems())
  113. @staticmethod
  114. def enc_int(obj):
  115. l = obj.bit_length()
  116. l += 1 # space for sign bit
  117. l = (l + 7) // 8
  118. if obj < 0:
  119. obj += 1 << (l * 8) # twos-complement conversion
  120. v = _numtostr(obj)
  121. if len(v) != l:
  122. # XXX - is this a problem for signed values?
  123. v = '\x00' + v # add sign octect
  124. return _encodelen(l) + v
  125. @staticmethod
  126. def dec_int(d, pos, end):
  127. if pos == end:
  128. return 0, end
  129. v = int(d[pos:end].encode('hex'), 16)
  130. av = 1 << ((end - pos) * 8 - 1) # sign bit
  131. if v > av:
  132. v -= av * 2 # twos-complement conversion
  133. return v, end
  134. @staticmethod
  135. def enc_bool(obj):
  136. return '\x01' + ('\xff' if obj else '\x00')
  137. def dec_bool(self, d, pos, end):
  138. v = self.dec_int(d, pos, end)[0]
  139. if v not in (-1, 0):
  140. raise ValueError('invalid bool value: %d' % v)
  141. return bool(v), end
  142. @staticmethod
  143. def enc_null(obj):
  144. return '\x00'
  145. @staticmethod
  146. def dec_null(d, pos, end):
  147. return None, end
  148. def enc_dict(self, obj):
  149. #it = list(obj.iteritems())
  150. #it.sort()
  151. r = ''.join(self.dumps(k) + self.dumps(v) for k, v in
  152. obj.iteritems())
  153. return _encodelen(len(r)) + r
  154. def dec_dict(self, d, pos, end):
  155. r = {}
  156. vend = pos
  157. while pos < end:
  158. k, kend = self._loads(d, pos, end)
  159. #if kend > end:
  160. # raise ValueError('key past end')
  161. v, vend = self._loads(d, kend, end)
  162. if vend > end:
  163. raise ValueError('value past end')
  164. r[k] = v
  165. pos = vend
  166. return r, vend
  167. def enc_list(self, obj):
  168. r = ''.join(self.dumps(x) for x in obj)
  169. return _encodelen(len(r)) + r
  170. def dec_list(self, d, pos, end):
  171. r = []
  172. vend = pos
  173. while pos < end:
  174. v, vend = self._loads(d, pos, end)
  175. if vend > end:
  176. raise ValueError('load past end')
  177. r.append(v)
  178. pos = vend
  179. return r, vend
  180. enc_set = enc_list
  181. def dec_set(self, d, pos, end):
  182. r, end = self.dec_list(d, pos, end)
  183. return set(r), end
  184. @staticmethod
  185. def enc_bytes(obj):
  186. return _encodelen(len(obj)) + bytes(obj)
  187. @staticmethod
  188. def dec_bytes(d, pos, end):
  189. return d[pos:end], end
  190. @staticmethod
  191. def enc_unicode(obj):
  192. encobj = obj.encode('utf-8')
  193. return _encodelen(len(encobj)) + encobj
  194. def dec_unicode(self, d, pos, end):
  195. return d[pos:end].decode('utf-8'), end
  196. @staticmethod
  197. def enc_float(obj):
  198. s = math.copysign(1, obj)
  199. if math.isnan(obj):
  200. return _encodelen(1) + chr(0b01000010)
  201. elif math.isinf(obj):
  202. if s == 1:
  203. return _encodelen(1) + chr(0b01000000)
  204. else:
  205. return _encodelen(1) + chr(0b01000001)
  206. elif obj == 0:
  207. if s == 1:
  208. return _encodelen(0)
  209. else:
  210. return _encodelen(1) + chr(0b01000011)
  211. m, e = _splitfloat(obj)
  212. # Binary encoding
  213. val = 0x80
  214. if m < 0:
  215. val |= 0x40
  216. m = -m
  217. # Base 2
  218. el = (e.bit_length() + 7 + 1) // 8 # + 1 is sign bit
  219. if el > 2:
  220. raise ValueError('exponent too large')
  221. if e < 0:
  222. e += 256**el # convert negative to twos-complement
  223. v = el - 1
  224. encexp = _numtostr(e)
  225. val |= v
  226. r = chr(val) + encexp + _numtostr(m)
  227. return _encodelen(len(r)) + r
  228. def dec_float(self, d, pos, end):
  229. if pos == end:
  230. return float(0), end
  231. v = ord(d[pos])
  232. if v == 0b01000000:
  233. return float('inf'), end
  234. elif v == 0b01000001:
  235. return float('-inf'), end
  236. elif v == 0b01000010:
  237. return float('nan'), end
  238. elif v == 0b01000011:
  239. return float('-0'), end
  240. elif v & 0b110000:
  241. raise ValueError('base must be 2')
  242. elif v & 0b1100:
  243. raise ValueError('scaling factor must be 0')
  244. elif v & 0b11000000 == 0:
  245. raise ValueError('decimal encoding not supported')
  246. #elif v & 0b11000000 == 0b01000000:
  247. # raise ValueError('invalid encoding')
  248. if (v & 3) >= 2:
  249. raise ValueError('large exponents not supported')
  250. pexp = pos + 1
  251. eexp = pos + 1 + (v & 3) + 1
  252. exp = self.dec_int(d, pexp, eexp)[0]
  253. n = float(int(d[eexp:end].encode('hex'), 16))
  254. r = n * 2 ** exp
  255. if v & 0b1000000:
  256. r = -r
  257. return r, end
  258. def dumps(self, obj):
  259. '''Convert obj into an array of bytes.'''
  260. try:
  261. tf = self._typemap[type(obj)]
  262. except KeyError:
  263. if self.coerce is None:
  264. raise TypeError('unhandled object: %s' % `obj`)
  265. tf, obj = self.coerce(obj)
  266. fun = getattr(self, 'enc_%s' % tf)
  267. return self._typetag[tf] + fun(obj)
  268. def _loads(self, data, pos, end):
  269. tag = data[pos]
  270. l, b = _decodelen(data, pos + 1)
  271. if len(data) < pos + 1 + b + l:
  272. raise ValueError('string not long enough')
  273. # XXX - enforce that len(data) == end?
  274. end = pos + 1 + b + l
  275. t = self._tagmap[tag]
  276. fun = getattr(self, 'dec_%s' % t)
  277. return fun(data, pos + 1 + b, end)
  278. def enc_datetime(self, obj):
  279. ts = obj.strftime('%Y%m%d%H%M%S')
  280. if obj.microsecond:
  281. ts += ('.%06d' % obj.microsecond).rstrip('0')
  282. ts += 'Z'
  283. return _encodelen(len(ts)) + ts
  284. def dec_datetime(self, data, pos, end):
  285. ts = data[pos:end]
  286. if ts[-1] != 'Z':
  287. raise ValueError('last character must be Z')
  288. if '.' in ts:
  289. fstr = '%Y%m%d%H%M%S.%fZ'
  290. if ts.endswith('0Z'):
  291. raise ValueError('invalid trailing zeros')
  292. else:
  293. fstr = '%Y%m%d%H%M%SZ'
  294. return datetime.datetime.strptime(ts, fstr), end
  295. def loads(self, data, pos=0, end=None, consume=False):
  296. '''Load from data, starting at pos (optional), and ending
  297. at end (optional). If it is required to consume the
  298. whole string (not the default), set consume to True, and
  299. a ValueError will be raised if the string is not
  300. completely consumed. The second item in ValueError will
  301. be the possition that was the detected end.'''
  302. if end is None:
  303. end = len(data)
  304. r, e = self._loads(data, pos, end)
  305. if consume and e != end:
  306. raise ValueError('entire string not consumed', e)
  307. return r
# Module-wide coder instance created without a coerce function.
_coder = ASN1Coder()
# Module-level dumps/loads are the bound methods of that instance.
dumps = _coder.dumps
loads = _coder.loads
  311. def deeptypecmp(obj, o):
  312. #print 'dtc:', `obj`, `o`
  313. if type(obj) != type(o):
  314. return False
  315. if type(obj) in (str, unicode):
  316. return True
  317. if type(obj) in (list, set):
  318. for i, j in zip(obj, o):
  319. if not deeptypecmp(i, j):
  320. return False
  321. if type(obj) in (dict,):
  322. itms = obj.items()
  323. itms.sort()
  324. nitms = o.items()
  325. nitms.sort()
  326. for (k, v), (nk, nv) in zip(itms, nitms):
  327. if not deeptypecmp(k, nk):
  328. return False
  329. if not deeptypecmp(v, nv):
  330. return False
  331. return True
  332. class Test_deeptypecmp(unittest.TestCase):
  333. def test_true(self):
  334. for i in ((1,1), ('sldkfj', 'sldkfj')
  335. ):
  336. self.assertTrue(deeptypecmp(*i))
  337. def test_false(self):
  338. for i in (([[]], [{}]), ([1], ['str']), ([], set()),
  339. ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
  340. ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
  341. ):
  342. self.assertFalse(deeptypecmp(*i))
  343. def genfailures(obj):
  344. s = dumps(obj)
  345. for i in xrange(len(s)):
  346. for j in (chr(x) for x in xrange(256)):
  347. ts = s[:i] + j + s[i + 1:]
  348. if ts == s:
  349. continue
  350. try:
  351. o = loads(ts, consume=True)
  352. if o != obj or not deeptypecmp(o, obj):
  353. raise ValueError
  354. except (ValueError, KeyError, IndexError, TypeError):
  355. pass
  356. else:
  357. raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
  358. class TestCode(unittest.TestCase):
  359. def test_primv(self):
  360. self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
  361. self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
  362. self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
  363. self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
  364. self.assertEqual(dumps(5), '020105'.decode('hex'))
  365. self.assertEqual(dumps(128), '02020080'.decode('hex'))
  366. self.assertEqual(dumps(256), '02020100'.decode('hex'))
  367. self.assertEqual(dumps(False), '010100'.decode('hex'))
  368. self.assertEqual(dumps(True), '0101ff'.decode('hex'))
  369. self.assertEqual(dumps(None), '0500'.decode('hex'))
  370. self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))
  371. def test_fuzzing(self):
  372. # Make sure that when a failure is detected here, that it
  373. # gets added to test_invalids, so that this function may be
  374. # disabled.
  375. genfailures(float(1))
  376. genfailures([ 1, 2, 'sdlkfj' ])
  377. genfailures({ 1: 2, 5: 'sdlkfj' })
  378. genfailures(set([ 1, 2, 'sdlkfj' ]))
  379. genfailures(True)
  380. genfailures(datetime.datetime.utcnow())
  381. def test_invalids(self):
  382. # Add tests for base 8, 16 floats among others
  383. for v in [ '010101',
  384. '0903040001', # float scaling factor
  385. '0903840001', # float scaling factor
  386. '0903100001', # float base
  387. '0903900001', # float base
  388. '0903000001', # float decimal encoding
  389. '0903830001', # float exponent encoding
  390. '090b827fffcc0df505d0fa58f7', # float large exponent
  391. '3007020101020102040673646c6b666a', # list short string still valid
  392. 'e007020101020102020105040673646c6b666a', # dict short value still valid
  393. '181632303136303231353038343031362e3539303839305a', #datetime w/ trailing zero
  394. '181632303136303231373136343034372e3035343433367a', #datetime w/ lower z
  395. ]:
  396. self.assertRaises(ValueError, loads, v.decode('hex'))
  397. def test_invalid_floats(self):
  398. with mock.patch('math.frexp', return_value=(.87232, 1 << 23)):
  399. self.assertRaises(ValueError, dumps, 1.1)
  400. def test_consume(self):
  401. b = dumps(5)
  402. self.assertRaises(ValueError, loads, b + '398473',
  403. consume=True)
  404. # XXX - still possible that an internal data member
  405. # doesn't consume all
  406. # XXX - test that sets are ordered properly
  407. # XXX - test that dicts are ordered properly..
  408. def test_nan(self):
  409. s = dumps(float('nan'))
  410. v = loads(s)
  411. self.assertTrue(math.isnan(v))
  412. def test_cryptoutilasn1(self):
  413. '''Test DER sequences generated by Crypto.Util.asn1.'''
  414. for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
  415. ('\x05\x00', None),
  416. ('\x02\x03\x00\x96I', 38473),
  417. ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
  418. ]:
  419. self.assertEqual(loads(s), v)
  420. def test_longstrings(self):
  421. for i in (203, 65484):
  422. s = os.urandom(i)
  423. v = dumps(s)
  424. self.assertEqual(loads(v), s)
  425. def test_invaliddate(self):
  426. pass
  427. # XXX - add test to reject datetime w/ tzinfo, or that it
  428. # handles it properly
  429. def test_dumps(self):
  430. for i in [ None,
  431. True, False,
  432. -1, 0, 1, 255, 256, -255, -256,
  433. 23498732498723, -2398729387234,
  434. (1<<2383) + 23984734, (-1<<1983) + 23984723984,
  435. float(0), float('-0'), float('inf'), float('-inf'),
  436. float(1.0), float(-1.0), float('353.3487'),
  437. float('2.38723873e+307'), float('2.387349e-317'),
  438. sys.float_info.max, sys.float_info.min,
  439. float('.15625'),
  440. 'weoifjwef',
  441. u'\U0001f4a9',
  442. [], [ 1,2,3 ],
  443. {}, { 5: 10, 'adfkj': 34 },
  444. set(), set((1,2,3)),
  445. set((1,'sjlfdkj', None, float('inf'))),
  446. datetime.datetime.utcnow(),
  447. datetime.datetime.utcnow().replace(microsecond=0),
  448. datetime.datetime.utcnow().replace(microsecond=1000),
  449. ]:
  450. s = dumps(i)
  451. o = loads(s)
  452. self.assertEqual(i, o)
  453. tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1,
  454. 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
  455. out = dumps(tobj)
  456. self.assertEqual(tobj, loads(out))
  457. def test_coerce(self):
  458. class Foo:
  459. pass
  460. class Bar:
  461. pass
  462. class Baz:
  463. pass
  464. def coerce(obj):
  465. if isinstance(obj, Foo):
  466. return 'list', obj.lst
  467. elif isinstance(obj, Baz):
  468. return 'bytes', obj.s
  469. raise TypeError('unknown type')
  470. ac = ASN1Coder(coerce)
  471. v = [1, 2, 3]
  472. o = Foo()
  473. o.lst = v
  474. self.assertEqual(ac.loads(ac.dumps(o)), v)
  475. self.assertRaises(TypeError, ac.dumps, Bar())
  476. v = u'oiejfd'
  477. o = Baz()
  478. o.s = v
  479. es = ac.dumps(o)
  480. self.assertEqual(ac.loads(es), v)
  481. self.assertIsInstance(es, bytes)
  482. self.assertRaises(TypeError, dumps, o)
  483. def test_loads(self):
  484. self.assertRaises(ValueError, loads, '\x00\x02\x00')