A pure Python ASN.1 library. Supports dicts and sets.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

585 lines
14 KiB

  1. #!/usr/bin/env python
  # A pure Python ASN.1 encoder/decoder with a calling interface in the
  # spirit of pickle. It will automatically do the correct thing if possible.
  #
  # This uses a profile of ASN.1.
  #
  # All lengths must be specified; that is, end-of-contents octets
  # MUST NOT be used. The shortest form of length encoding MUST be used,
  # and a longer length encoding MUST be rejected.
  10. import datetime
  11. import math
  12. import os
  13. import pdb
  14. import sys
  15. import unittest
  # Public API: the pickle-like module functions and the configurable coder.
  __all__ = ['dumps', 'loads', 'ASN1Coder']
  17. def _numtostr(n):
  18. hs = '%x' % n
  19. if len(hs) & 1 == 1:
  20. hs = '0' + hs
  21. bs = hs.decode('hex')
  22. return bs
  def _encodelen(l):
      '''Return the ASN.1 length octets for the length l.

      Lengths below 128 use the one-octet short form. Larger lengths use the
      long form: an octet 0x80 | n followed by the n big-endian octets of l.'''
      if l >= 128:
          octets = _numtostr(l)
          return chr(0x80 | len(octets)) + octets
      return chr(l)
  30. def _decodelen(d, pos=0):
  31. '''Returns the length, and number of bytes required.'''
  32. odp = ord(d[pos])
  33. if odp < 128:
  34. return ord(d[pos]), 1
  35. else:
  36. l = odp & 0x7f
  37. return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
  class Test_codelen(unittest.TestCase):
      # (length, expected length octets): the short form and the 1-, 2- and
      # 3-octet long forms, including the boundaries between them.
      _testdata = [
          (2, '\x02'),
          (127, '\x7f'),
          (128, '\x81\x80'),
          (255, '\x81\xff'),
          (256, '\x82\x01\x00'),
          (65535, '\x82\xff\xff'),
          (65536, '\x83\x01\x00\x00'),
      ]

      def test_el(self):
          '''Encoding gives the expected octets; decoding them gives back
          the length and the number of octets consumed.'''
          for length, octets in self._testdata:
              self.assertEqual(_encodelen(length), octets)
              decoded = _decodelen(octets)
              self.assertEqual(decoded, (length, len(octets)))
  52. def _splitfloat(f):
  53. m, e = math.frexp(f)
  54. # XXX - less than ideal
  55. while m != math.trunc(m):
  56. m *= 2
  57. e -= 1
  58. return m, e
  class TestSplitFloat(unittest.TestCase):
      def test_sf(self):
          '''Odd mantissas times powers of two split back into their parts.'''
          cases = [
              (0x2421, -32),
              (0x5382f, 238),
              (0x1fa8c3b094adf1, 971),
          ]
          for mantissa, exponent in cases:
              value = mantissa * 2**exponent
              self.assertEqual(_splitfloat(value), (mantissa, exponent))
  def _twoscomplement(n):
      '''Return the shortest big-endian two's-complement octet string for
      the integer n (always at least one octet).'''
      # Magnitude bits plus one sign bit. For negative n, use ~n (== -n - 1)
      # so that -2**k stays minimal (-128 becomes a single 0x80 octet).
      nbytes = ((n if n >= 0 else ~n).bit_length() + 8) // 8
      if n < 0:
          n += 1 << (nbytes * 8)  # twos-complement conversion
      v = _numtostr(n)
      # _numtostr drops leading zero octets. A non-negative value whose top
      # bit is set needs one restored so it does not read as negative.
      return '\x00' * (nbytes - len(v)) + v


  class ASN1Coder(object):
      '''A class that contains a PASN.1 (profiled ASN.1) encoder/decoder.
      Exports two methods, loads and dumps.'''

      def __init__(self, coerce=None):
          '''If the arg coerce is provided, when dumping the object,
          if the type is not found, the coerce function will be called
          with the obj. It is expected to return a tuple of a string
          and an object that has the method w/ the string as defined:
              'bool': __nonzero__
              'dict': iteritems
              'float': compatible w/ float
              'int': compatible w/ int
              'list': __iter__
              'set': __iter__
              'bytes': __str__
              'null': no method needed
              'unicode': encode method returns UTF-8 encoded bytes
              'datetime': strftime and microsecond
          '''
          self.coerce = coerce

      # Python type -> name of its enc_<name>/dec_<name> method pair.
      _typemap = {
          bool: 'bool',
          dict: 'dict',
          float: 'float',
          int: 'int',
          list: 'list',
          long: 'int',
          set: 'set',
          str: 'bytes',
          type(None): 'null',
          unicode: 'unicode',
          #decimal.Decimal: 'float',
          datetime.datetime: 'datetime',
          #datetime.timedelta: 'timedelta',
      }

      # Identifier octet -> type name. '\xc0' is this module's tag for dicts.
      _tagmap = {
          '\x01': 'bool',
          '\x02': 'int',
          '\x04': 'bytes',
          '\x05': 'null',
          '\x09': 'float',
          '\x0c': 'unicode',
          '\x18': 'datetime',
          '\x30': 'list',
          '\x31': 'set',
          '\xc0': 'dict',
      }

      # Type name -> identifier octet (the inverse of _tagmap).
      _typetag = dict((v, k) for k, v in _tagmap.iteritems())

      @staticmethod
      def enc_int(obj):
          '''Encode an integer as length + minimal two's-complement octets.'''
          # int() lets coerced "compatible w/ int" objects that lack
          # bit_length() be encoded as well.
          v = _twoscomplement(int(obj))
          return _encodelen(len(v)) + v

      @staticmethod
      def dec_int(d, pos, end):
          '''Decode d[pos:end] as a big-endian two's-complement integer.
          Empty content decodes as 0.'''
          if pos == end:
              return 0, end
          v = int(d[pos:end].encode('hex'), 16)
          av = 1 << ((end - pos) * 8 - 1)  # value of the sign bit
          # Use >= rather than >: a lone sign bit (e.g. '\x80') is the most
          # negative value (-128), not a positive one.
          if v >= av:
              v -= av * 2  # twos-complement conversion
          return v, end

      @staticmethod
      def enc_bool(obj):
          return '\x01' + ('\xff' if obj else '\x00')

      def dec_bool(self, d, pos, end):
          '''Decode a bool: exactly one content octet, 0x00 or 0xff.'''
          if end - pos != 1:
              raise ValueError('invalid bool length: %d' % (end - pos))
          v = self.dec_int(d, pos, end)[0]
          if v not in (-1, 0):
              raise ValueError('invalid bool value: %d' % v)
          return bool(v), end

      @staticmethod
      def enc_null(obj):
          return '\x00'

      @staticmethod
      def dec_null(d, pos, end):
          '''Decode None. A null must have empty content.'''
          if pos != end:
              raise ValueError('null must have no content')
          return None, end

      def enc_dict(self, obj):
          '''Encode a dict as concatenated key/value element encodings.'''
          # Sort the pairs by encoded key (then value) so the same dict
          # always yields the same octets, whatever its iteration order.
          pairs = sorted((self.dumps(k), self.dumps(v))
                         for k, v in obj.iteritems())
          r = ''.join(k + v for k, v in pairs)
          return _encodelen(len(r)) + r

      def dec_dict(self, d, pos, end):
          '''Decode alternating key and value elements in d[pos:end].'''
          r = {}
          vend = pos
          while pos < end:
              k, kend = self._loads(d, pos, end)
              v, vend = self._loads(d, kend, end)
              if vend > end:
                  raise ValueError('value past end')
              if k in r:
                  # enc_dict never writes a key twice.
                  raise ValueError('duplicate key')
              r[k] = v
              pos = vend
          return r, vend

      def enc_list(self, obj):
          r = ''.join(self.dumps(x) for x in obj)
          return _encodelen(len(r)) + r

      def dec_list(self, d, pos, end):
          r = []
          vend = pos
          while pos < end:
              v, vend = self._loads(d, pos, end)
              if vend > end:
                  raise ValueError('load past end')
              r.append(v)
              pos = vend
          return r, vend

      def enc_set(self, obj):
          '''Encode a set like a list, but with the element encodings sorted
          (the DER SET OF ordering), so the output is deterministic.'''
          r = ''.join(sorted(self.dumps(x) for x in obj))
          return _encodelen(len(r)) + r

      def dec_set(self, d, pos, end):
          r, end = self.dec_list(d, pos, end)
          return set(r), end

      @staticmethod
      def enc_bytes(obj):
          # Convert first: coerced objects only promise __str__, and the
          # length must be that of the converted octets.
          b = bytes(obj)
          return _encodelen(len(b)) + b

      @staticmethod
      def dec_bytes(d, pos, end):
          return d[pos:end], end

      @staticmethod
      def enc_unicode(obj):
          encobj = obj.encode('utf-8')
          return _encodelen(len(encobj)) + encobj

      def dec_unicode(self, d, pos, end):
          return d[pos:end].decode('utf-8'), end

      @staticmethod
      def enc_float(obj):
          '''Encode a float as an ASN.1 REAL.

          NaN and +/-inf use a single special octet. +0.0 has empty content
          and -0.0 is the octet 0x43. Every other value uses the binary form
          with base 2 and scaling factor 0.'''
          s = math.copysign(1, obj)
          if math.isnan(obj):
              return _encodelen(1) + chr(0b01000010)
          elif math.isinf(obj):
              if s == 1:
                  return _encodelen(1) + chr(0b01000000)
              else:
                  return _encodelen(1) + chr(0b01000001)
          elif obj == 0:
              if s == 1:
                  return _encodelen(0)
              else:
                  return _encodelen(1) + chr(0b01000011)
          m, e = _splitfloat(obj)
          m = int(m)
          # First content octet: bit 8 = binary encoding, bit 7 = sign,
          # bits 6-3 = 0 (base 2, scaling factor 0), bits 2-1 = exponent
          # format.
          val = 0x80
          if m < 0:
              val |= 0x40
              m = -m
          exp = _twoscomplement(e)
          if len(exp) > 3:
              # Format 3: an explicit octet count precedes the exponent.
              val |= 0x3
              encexp = _encodelen(len(exp)) + exp
          else:
              # Formats 0-2: the exponent is 1, 2 or 3 octets long. These
              # bits must be set, or the decoder reads a 1-octet exponent.
              val |= len(exp) - 1
              encexp = exp
          r = chr(val) + encexp + _numtostr(m)
          return _encodelen(len(r)) + r

      def dec_float(self, d, pos, end):
          '''Decode an ASN.1 REAL (binary, base 2, scaling factor 0 only).'''
          if pos == end:
              return float(0), end
          v = ord(d[pos])
          if 0b01000000 <= v <= 0b01000011:
              # Special values are exactly one content octet.
              if end - pos != 1:
                  raise ValueError('special real value must be one octet')
              if v == 0b01000000:
                  return float('inf'), end
              elif v == 0b01000001:
                  return float('-inf'), end
              elif v == 0b01000010:
                  return float('nan'), end
              return float('-0'), end
          elif v & 0b110000:
              raise ValueError('base must be 2')
          elif v & 0b1100:
              raise ValueError('scaling factor must be 0')
          elif v & 0b11000000 == 0:
              raise ValueError('decimal encoding not supported')
          if v & 3 == 3:
              if pos + 1 >= end:
                  raise ValueError('missing exponent length')
              pexp = pos + 2
              explen = ord(d[pos + 1])
              if explen <= 3:
                  raise ValueError('must use other length encoding')
              eexp = pexp + explen
          else:
              pexp = pos + 1
              eexp = pexp + (v & 3) + 1
          # The exponent must fit inside this element, followed by at least
          # one mantissa octet.
          if eexp >= end:
              raise ValueError('float truncated')
          exp = self.dec_int(d, pexp, eexp)[0]
          n = int(d[eexp:end].encode('hex'), 16)
          try:
              r = math.ldexp(float(n), exp)
          except OverflowError:
              raise ValueError('float out of range')
          if v & 0b1000000:
              r = -r
          return r, end

      def dumps(self, obj):
          '''Convert obj into a string.

          Types missing from _typemap are passed to the coerce function
          given to the constructor. Without one, KeyError is raised.'''
          try:
              tf = self._typemap[type(obj)]
          except KeyError:
              if self.coerce is None:
                  raise
              tf, obj = self.coerce(obj)
          fun = getattr(self, 'enc_%s' % tf)
          return self._typetag[tf] + fun(obj)

      def _loads(self, data, pos, end):
          '''Decode the element whose identifier octet is data[pos].

          The element must end at or before end. Returns (value, position
          just past the element).'''
          tag = data[pos]
          l, b = _decodelen(data, pos + 1)
          eend = pos + 1 + b + l
          if len(data) < eend:
              raise ValueError('string not long enough')
          # Keep a nested element inside its container, and a top-level
          # element inside the end given to loads.
          if eend > end:
              raise ValueError('element extends past end')
          t = self._tagmap[tag]
          fun = getattr(self, 'dec_%s' % t)
          return fun(data, pos + 1 + b, eend)

      def enc_datetime(self, obj):
          '''Encode as 'YYYYMMDDHHMMSS[.fff]Z', omitting trailing zeros of
          the fraction.

          The encoding is always marked 'Z', so timezone-aware values are
          first converted to UTC. Naive values are written unchanged.'''
          utcoffset = getattr(obj, 'utcoffset', None)
          offset = utcoffset() if utcoffset is not None else None
          if offset is not None:
              obj = (obj - offset).replace(tzinfo=None)
          ts = obj.strftime('%Y%m%d%H%M%S')
          if obj.microsecond:
              ts += ('.%06d' % obj.microsecond).rstrip('0')
          ts += 'Z'
          return _encodelen(len(ts)) + ts

      def dec_datetime(self, data, pos, end):
          '''Decode 'YYYYMMDDHHMMSS[.fff]Z' into a naive datetime.'''
          ts = data[pos:end]
          if not ts.endswith('Z'):
              raise ValueError('last character must be Z')
          if '.' in ts:
              fstr = '%Y%m%d%H%M%S.%fZ'
              # enc_datetime strips trailing fraction zeros; reject them.
              if ts.endswith('0Z'):
                  raise ValueError('invalid trailing zeros')
          else:
              fstr = '%Y%m%d%H%M%SZ'
          return datetime.datetime.strptime(ts, fstr), end

      def loads(self, data, pos=0, end=None, consume=False):
          '''Load from data, starting at pos (optional), and ending
          at end (optional). If it is required to consume the
          whole string (not the default), set consume to True, and
          a ValueError will be raised if the string is not
          completely consumed. The second item in ValueError will
          be the position that was the detected end.'''
          if end is None:
              end = len(data)
          r, e = self._loads(data, pos, end)
          if consume and e != end:
              raise ValueError('entire string not consumed', e)
          return r
  # Module-level pickle-style API, backed by a shared coder without coerce.
  _coder = ASN1Coder()
  dumps, loads = _coder.dumps, _coder.loads
  def deeptypecmp(obj, o):
      '''Return True if obj and o have the same type, recursively.

      Lists are compared position by position. Set elements and dict keys
      are paired with the *equal* element or key of o, not by iteration
      order. Containers of different sizes, or with elements/keys that have
      no equal counterpart in o, compare False. Other values compare only
      by their own type.'''
      if type(obj) != type(o):
          return False
      if type(obj) is list:
          if len(obj) != len(o):
              return False
          return all(deeptypecmp(i, j) for i, j in zip(obj, o))
      if type(obj) is set:
          if len(obj) != len(o):
              return False
          # Map each element of o to itself to find the object in o that
          # is equal to a given element of obj.
          others = dict((x, x) for x in o)
          for i in obj:
              if i not in others or not deeptypecmp(i, others[i]):
                  return False
          return True
      if type(obj) is dict:
          if len(obj) != len(o):
              return False
          okeys = dict((k, k) for k in o)
          for k, v in obj.items():
              if k not in okeys:
                  return False
              nk = okeys[k]
              if not deeptypecmp(k, nk) or not deeptypecmp(v, o[nk]):
                  return False
          return True
      return True
  class Test_deeptypecmp(unittest.TestCase):
      def test_true(self):
          '''Identical scalars are type-equal.'''
          for left, right in [(1, 1), ('sldkfj', 'sldkfj')]:
              self.assertTrue(deeptypecmp(left, right))

      def test_false(self):
          '''A type mismatch anywhere in a structure is detected.'''
          mismatches = [
              ([[]], [{}]),
              ([1], ['str']),
              ([], set()),
              ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
              ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
          ]
          for left, right in mismatches:
              self.assertFalse(deeptypecmp(left, right))
  def genfailures(obj):
      '''Fuzz the encoding of obj, one single-octet substitution at a time.

      Each altered encoding must either fail to load (ValueError, KeyError,
      IndexError or TypeError), or load to something that differs from obj
      in value or in type. Otherwise AssertionError is raised, describing
      the modification that went undetected.'''
      s = dumps(obj)
      for i in xrange(len(s)):
          for x in xrange(256):
              ts = s[:i] + chr(x) + s[i + 1:]
              if ts == s:
                  continue
              try:
                  o = loads(ts, consume=True)
                  undetected = not (o != obj) and deeptypecmp(o, obj)
              except (ValueError, KeyError, IndexError, TypeError):
                  continue
              if undetected:
                  raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
  364. class TestCode(unittest.TestCase):
  def test_primv(self):
      '''Primitive values encode to the expected octets.'''
      expected = [
          (-257, '0202feff'),
          (-256, '0202ff00'),
          (-255, '0202ff01'),
          (-1, '0201ff'),
          (5, '020105'),
          (128, '02020080'),
          (256, '02020100'),
          (False, '010100'),
          (True, '0101ff'),
          (None, '0500'),
          (.15625, '090380fb05'),
      ]
      for value, hexstr in expected:
          self.assertEqual(dumps(value), hexstr.decode('hex'))
  def test_fuzzing(self):
      # When this detects a failure, add the offending encoding to
      # test_invalids so that this (slow) test may be disabled.
      samples = [
          float(1),
          [1, 2, 'sdlkfj'],
          {1: 2, 5: 'sdlkfj'},
          set([1, 2, 'sdlkfj']),
          True,
      ]
      for sample in samples:
          genfailures(sample)
      # Taken last, at call time, like the other samples.
      genfailures(datetime.datetime.utcnow())
  def test_invalids(self):
      '''Malformed or non-canonical encodings are rejected.'''
      # TODO: add tests for base 8 and base 16 floats, among others.
      invalid = (
          '010101',
          '0903040001',  # float scaling factor
          '0903840001',  # float scaling factor
          '0903100001',  # float base
          '0903900001',  # float base
          '0903000001',  # float decimal encoding
          '0903830001',  # float exponent encoding
          '3007020101020102040673646c6b666a',  # list short string still valid
          'c007020101020102020105040673646c6b666a',  # dict short value still valid
          '181632303136303231353038343031362e3539303839305a',  # datetime w/ trailing zero
          '181632303136303231373136343034372e3035343433367a',  # datetime w/ lower z
      )
      for hexstr in invalid:
          self.assertRaises(ValueError, loads, hexstr.decode('hex'))
  def test_consume(self):
      '''Trailing bytes after a complete element are an error when
      consume=True.'''
      encoded = dumps(5)
      with self.assertRaises(ValueError):
          loads(encoded + '398473', consume=True)
      # TODO: it is still possible that an internal data member
      # doesn't consume all of its content.
      # TODO: test that sets are ordered properly.
      # TODO: test that dicts are ordered properly.
  def test_nan(self):
      '''NaN round-trips. NaN != NaN, so check with isnan.'''
      decoded = loads(dumps(float('nan')))
      self.assertTrue(math.isnan(decoded))
  def test_cryptoutilasn1(self):
      '''Test DER sequences generated by Crypto.Util.asn1.'''
      vectors = [
          ('\x02\x03$\x8a\xf9', 2394873),
          ('\x05\x00', None),
          ('\x02\x03\x00\x96I', 38473),
          ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
      ]
      for encoded, expected in vectors:
          self.assertEqual(loads(encoded), expected)
  def test_longstrings(self):
      '''Byte strings that need 1- and 2-octet long-form lengths
      round-trip.'''
      for size in (203, 65484):
          payload = os.urandom(size)
          self.assertEqual(loads(dumps(payload)), payload)
  def test_invaliddate(self):
      '''Placeholder; currently checks nothing.'''
      # TODO: add a test that rejects datetimes with tzinfo, or that
      # checks they are handled properly.
  def test_dumps(self):
      '''Every supported kind of value survives a dumps/loads round trip.'''
      # NOTE(review): float('2387.23873e492') overflows to inf and
      # float('2387.348732e-392') underflows to 0.0, so neither one
      # exercises a multi-octet float exponent.
      samples = [
          None,
          True, False,
          -1, 0, 1, 255, 256, -255, -256,
          23498732498723, -2398729387234,
          (1 << 2383) + 23984734, (-1 << 1983) + 23984723984,
          float(0), float('-0'), float('inf'), float('-inf'),
          float(1.0), float(-1.0),
          float('353.3487'), float('2387.23873e492'),
          float('2387.348732e-392'),
          float('.15625'),
          'weoifjwef',
          u'\U0001f4a9',
          [], [1, 2, 3],
          {}, {5: 10, 'adfkj': 34},
          set(), set((1, 2, 3)), set((1, 'sjlfdkj', None, float('inf'))),
          datetime.datetime.utcnow(),
          datetime.datetime.utcnow().replace(microsecond=0),
          datetime.datetime.utcnow().replace(microsecond=1000),
      ]
      for sample in samples:
          self.assertEqual(sample, loads(dumps(sample)))

      nested = {
          1: 'dflkj',
          5: u'sdlkfj',
          'float': 1,
          'largeint': 1 << 342,
          'list': [1, 2, u'str', 'str'],
      }
      self.assertEqual(nested, loads(dumps(nested)))
  def test_coerce(self):
      '''A coerce hook lets otherwise unsupported objects be dumped.'''
      class Foo:
          pass

      class Bar:
          pass

      class Baz:
          pass

      def coerce(obj):
          if isinstance(obj, Foo):
              return 'list', obj.lst
          if isinstance(obj, Baz):
              return 'bytes', obj.s
          raise TypeError('unknown type')

      coder = ASN1Coder(coerce)

      lst = [1, 2, 3]
      foo = Foo()
      foo.lst = lst
      self.assertEqual(coder.loads(coder.dumps(foo)), lst)

      # Types the hook does not know about propagate its TypeError.
      self.assertRaises(TypeError, coder.dumps, Bar())

      text = u'oiejfd'
      baz = Baz()
      baz.s = text
      encoded = coder.dumps(baz)
      self.assertEqual(coder.loads(encoded), text)
      self.assertIsInstance(encoded, bytes)
  476. def test_loads(self):
  477. self.assertRaises(ValueError, loads, '\x00\x02\x00')