A pure-Python ASN.1 library. Supports dicts and sets.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

576 lines
14 KiB

  1. #!/usr/bin/env python
# A pure-Python ASN.1 encoder/decoder with a calling interface in the spirit
# of pickle. It will automatically do the correct thing if possible.
#
# This uses a profile of ASN.1.
#
# All lengths must be specified; that is, end-of-contents octets
# MUST NOT be used. The shortest form of length encoding MUST be used,
# and a longer length encoding MUST be rejected.
import binascii
import datetime
import math
import os
import pdb
import sys
import unittest
  16. __all__ = [ 'dumps', 'loads', 'ASN1Coder' ]
  17. def _numtostr(n):
  18. hs = '%x' % n
  19. if len(hs) & 1 == 1:
  20. hs = '0' + hs
  21. bs = hs.decode('hex')
  22. return bs
  23. def _encodelen(l):
  24. '''Takes l as a length value, and returns a byte string that
  25. represents l per ASN.1 rules.'''
  26. if l < 128:
  27. return chr(l)
  28. bs = _numtostr(l)
  29. return chr(len(bs) | 0x80) + bs
  30. def _decodelen(d, pos=0):
  31. '''Returns the length, and number of bytes required.'''
  32. odp = ord(d[pos])
  33. if odp < 128:
  34. return ord(d[pos]), 1
  35. else:
  36. l = odp & 0x7f
  37. return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
  38. class Test_codelen(unittest.TestCase):
  39. _testdata = [
  40. (2, '\x02'),
  41. (127, '\x7f'),
  42. (128, '\x81\x80'),
  43. (255, '\x81\xff'),
  44. (256, '\x82\x01\x00'),
  45. (65536-1, '\x82\xff\xff'),
  46. (65536, '\x83\x01\x00\x00'),
  47. ]
  48. def test_el(self):
  49. for i, j in self._testdata:
  50. self.assertEqual(_encodelen(i), j)
  51. self.assertEqual(_decodelen(j), (i, len(j)))
  52. def _splitfloat(f):
  53. m, e = math.frexp(f)
  54. # XXX - less than ideal
  55. while m != math.trunc(m):
  56. m *= 2
  57. e -= 1
  58. return m, e
  59. class TestSplitFloat(unittest.TestCase):
  60. def test_sf(self):
  61. for a, b in [ (0x2421, -32), (0x5382f, 238),
  62. (0x1fa8c3b094adf1, 971) ]:
  63. self.assertEqual(_splitfloat(a * 2**b), (a, b))
  64. class ASN1Coder(object):
  65. '''A class that contains an PASN.1 encoder/decoder.
  66. Exports two methods, loads and dumps.'''
  67. def __init__(self, coerce=None):
  68. '''If the arg coerce is provided, when dumping the object,
  69. if the type is not found, the coerce function will be called
  70. with the obj. It is expected to return a tuple of a string
  71. and an object that has the method w/ the string as defined:
  72. 'bool': __nonzero__
  73. 'dict': iteritems
  74. 'float': compatible w/ float
  75. 'int': compatible w/ int
  76. 'list': __iter__
  77. 'set': __iter__
  78. 'bytes': __str__
  79. 'null': no method needed
  80. 'unicode': encode method returns UTF-8 encoded bytes
  81. 'datetime': strftime and microsecond
  82. '''
  83. self.coerce = coerce
  84. _typemap = {
  85. bool: 'bool',
  86. dict: 'dict',
  87. float: 'float',
  88. int: 'int',
  89. list: 'list',
  90. long: 'int',
  91. set: 'set',
  92. str: 'bytes',
  93. type(None): 'null',
  94. unicode: 'unicode',
  95. #decimal.Decimal: 'float',
  96. datetime.datetime: 'datetime',
  97. #datetime.timedelta: 'timedelta',
  98. }
  99. _tagmap = {
  100. '\x01': 'bool',
  101. '\x02': 'int',
  102. '\x04': 'bytes',
  103. '\x05': 'null',
  104. '\x09': 'float',
  105. '\x0c': 'unicode',
  106. '\x18': 'datetime',
  107. '\x30': 'list',
  108. '\x31': 'set',
  109. '\xc0': 'dict',
  110. }
  111. _typetag = dict((v, k) for k, v in _tagmap.iteritems())
  112. @staticmethod
  113. def enc_int(obj):
  114. l = obj.bit_length()
  115. l += 1 # space for sign bit
  116. l = (l + 7) // 8
  117. if obj < 0:
  118. obj += 1 << (l * 8) # twos-complement conversion
  119. v = _numtostr(obj)
  120. if len(v) != l:
  121. # XXX - is this a problem for signed values?
  122. v = '\x00' + v # add sign octect
  123. return _encodelen(l) + v
  124. @staticmethod
  125. def dec_int(d, pos, end):
  126. if pos == end:
  127. return 0, end
  128. v = int(d[pos:end].encode('hex'), 16)
  129. av = 1 << ((end - pos) * 8 - 1) # sign bit
  130. if v > av:
  131. v -= av * 2 # twos-complement conversion
  132. return v, end
  133. @staticmethod
  134. def enc_bool(obj):
  135. return '\x01' + ('\xff' if obj else '\x00')
  136. def dec_bool(self, d, pos, end):
  137. v = self.dec_int(d, pos, end)[0]
  138. if v not in (-1, 0):
  139. raise ValueError('invalid bool value: %d' % v)
  140. return bool(v), end
  141. @staticmethod
  142. def enc_null(obj):
  143. return '\x00'
  144. @staticmethod
  145. def dec_null(d, pos, end):
  146. return None, end
  147. def enc_dict(self, obj):
  148. #it = list(obj.iteritems())
  149. #it.sort()
  150. r = ''.join(self.dumps(k) + self.dumps(v) for k, v in obj.iteritems())
  151. return _encodelen(len(r)) + r
  152. def dec_dict(self, d, pos, end):
  153. r = {}
  154. vend = pos
  155. while pos < end:
  156. k, kend = self._loads(d, pos, end)
  157. #if kend > end:
  158. # raise ValueError('key past end')
  159. v, vend = self._loads(d, kend, end)
  160. if vend > end:
  161. raise ValueError('value past end')
  162. r[k] = v
  163. pos = vend
  164. return r, vend
  165. def enc_list(self, obj):
  166. r = ''.join(self.dumps(x) for x in obj)
  167. return _encodelen(len(r)) + r
  168. def dec_list(self, d, pos, end):
  169. r = []
  170. vend = pos
  171. while pos < end:
  172. v, vend = self._loads(d, pos, end)
  173. if vend > end:
  174. raise ValueError('load past end')
  175. r.append(v)
  176. pos = vend
  177. return r, vend
  178. enc_set = enc_list
  179. def dec_set(self, d, pos, end):
  180. r, end = self.dec_list(d, pos, end)
  181. return set(r), end
  182. @staticmethod
  183. def enc_bytes(obj):
  184. return _encodelen(len(obj)) + bytes(obj)
  185. @staticmethod
  186. def dec_bytes(d, pos, end):
  187. return d[pos:end], end
  188. @staticmethod
  189. def enc_unicode(obj):
  190. encobj = obj.encode('utf-8')
  191. return _encodelen(len(encobj)) + encobj
  192. def dec_unicode(self, d, pos, end):
  193. return d[pos:end].decode('utf-8'), end
  194. @staticmethod
  195. def enc_float(obj):
  196. s = math.copysign(1, obj)
  197. if math.isnan(obj):
  198. return _encodelen(1) + chr(0b01000010)
  199. elif math.isinf(obj):
  200. if s == 1:
  201. return _encodelen(1) + chr(0b01000000)
  202. else:
  203. return _encodelen(1) + chr(0b01000001)
  204. elif obj == 0:
  205. if s == 1:
  206. return _encodelen(0)
  207. else:
  208. return _encodelen(1) + chr(0b01000011)
  209. m, e = _splitfloat(obj)
  210. # Binary encoding
  211. val = 0x80
  212. if m < 0:
  213. val |= 0x40
  214. m = -m
  215. # Base 2
  216. el = (e.bit_length() + 7 + 1) // 8 # + 1 is sign bit
  217. if e < 0:
  218. e += 256**el # convert negative to twos-complement
  219. if el > 3:
  220. v = 0x3
  221. encexp = _encodelen(el) + _numtostr(e)
  222. else:
  223. v = el - 1
  224. encexp = _numtostr(e)
  225. r = chr(val) + encexp + _numtostr(m)
  226. return _encodelen(len(r)) + r
  227. def dec_float(self, d, pos, end):
  228. if pos == end:
  229. return float(0), end
  230. v = ord(d[pos])
  231. if v == 0b01000000:
  232. return float('inf'), end
  233. elif v == 0b01000001:
  234. return float('-inf'), end
  235. elif v == 0b01000010:
  236. return float('nan'), end
  237. elif v == 0b01000011:
  238. return float('-0'), end
  239. elif v & 0b110000:
  240. raise ValueError('base must be 2')
  241. elif v & 0b1100:
  242. raise ValueError('scaling factor must be 0')
  243. elif v & 0b11000000 == 0:
  244. raise ValueError('decimal encoding not supported')
  245. #elif v & 0b11000000 == 0b01000000:
  246. # raise ValueError('invalid encoding')
  247. if v & 3 == 3:
  248. pexp = pos + 2
  249. explen = ord(d[pos + 1])
  250. if explen <= 3:
  251. raise ValueError('must use other length encoding')
  252. eexp = pos + 2 + explen
  253. else:
  254. pexp = pos + 1
  255. eexp = pos + 1 + (v & 3) + 1
  256. exp = self.dec_int(d, pexp, eexp)[0]
  257. n = float(int(d[eexp:end].encode('hex'), 16))
  258. r = n * 2 ** exp
  259. if v & 0b1000000:
  260. r = -r
  261. return r, end
  262. def dumps(self, obj):
  263. '''Convert obj into a string.'''
  264. try:
  265. tf = self._typemap[type(obj)]
  266. except KeyError:
  267. if self.coerce is None:
  268. raise
  269. tf, obj = self.coerce(obj)
  270. fun = getattr(self, 'enc_%s' % tf)
  271. return self._typetag[tf] + fun(obj)
  272. def _loads(self, data, pos, end):
  273. tag = data[pos]
  274. l, b = _decodelen(data, pos + 1)
  275. if len(data) < pos + 1 + b + l:
  276. raise ValueError('string not long enough')
  277. # XXX - enforce that len(data) == end?
  278. end = pos + 1 + b + l
  279. t = self._tagmap[tag]
  280. fun = getattr(self, 'dec_%s' % t)
  281. return fun(data, pos + 1 + b, end)
  282. def enc_datetime(self, obj):
  283. ts = obj.strftime('%Y%m%d%H%M%S')
  284. if obj.microsecond:
  285. ts += ('.%06d' % obj.microsecond).rstrip('0')
  286. ts += 'Z'
  287. return _encodelen(len(ts)) + ts
  288. def dec_datetime(self, data, pos, end):
  289. ts = data[pos:end]
  290. if '.' in ts:
  291. fstr = '%Y%m%d%H%M%S.%fZ'
  292. if ts.endswith('0Z'):
  293. raise ValueError('invalid trailing zeros')
  294. else:
  295. fstr = '%Y%m%d%H%M%SZ'
  296. return datetime.datetime.strptime(ts, fstr), end
  297. def loads(self, data, pos=0, end=None, consume=False):
  298. '''Load from data, starting at pos (option), and ending
  299. at end (optional). If it is required to consume the
  300. whole string (not the default), set consume to True, and
  301. a ValueError will be raised if the string is not
  302. completely consumed. The second item in ValueError will
  303. be the possition that was the detected end.'''
  304. if end is None:
  305. end = len(data)
  306. r, e = self._loads(data, pos, end)
  307. if consume and e != end:
  308. raise ValueError('entire string not consumed', e)
  309. return r
  310. def deeptypecmp(obj, o):
  311. #print 'dtc:', `obj`, `o`
  312. if type(obj) != type(o):
  313. return False
  314. if type(obj) in (str, unicode):
  315. return True
  316. if type(obj) in (list, set):
  317. for i, j in zip(obj, o):
  318. if not deeptypecmp(i, j):
  319. return False
  320. if type(obj) in (dict,):
  321. itms = obj.items()
  322. itms.sort()
  323. nitms = o.items()
  324. nitms.sort()
  325. for (k, v), (nk, nv) in zip(itms, nitms):
  326. if not deeptypecmp(k, nk):
  327. return False
  328. if not deeptypecmp(v, nv):
  329. return False
  330. return True
  331. class Test_deeptypecmp(unittest.TestCase):
  332. def test_true(self):
  333. for i in ((1,1), ('sldkfj', 'sldkfj')
  334. ):
  335. self.assertTrue(deeptypecmp(*i))
  336. def test_false(self):
  337. for i in (([[]], [{}]), ([1], ['str']), ([], set()),
  338. ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
  339. ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
  340. ):
  341. self.assertFalse(deeptypecmp(*i))
  342. def genfailures(obj):
  343. s = dumps(obj)
  344. for i in xrange(len(s)):
  345. for j in (chr(x) for x in xrange(256)):
  346. ts = s[:i] + j + s[i + 1:]
  347. if ts == s:
  348. continue
  349. try:
  350. o = loads(ts, consume=True)
  351. if o != obj or not deeptypecmp(o, obj):
  352. raise ValueError
  353. except (ValueError, KeyError, IndexError, TypeError):
  354. pass
  355. except Exception:
  356. raise
  357. else:
  358. raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))
  359. _coder = ASN1Coder()
  360. dumps = _coder.dumps
  361. loads = _coder.loads
  362. class TestCode(unittest.TestCase):
  363. def test_primv(self):
  364. self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
  365. self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
  366. self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
  367. self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
  368. self.assertEqual(dumps(5), '020105'.decode('hex'))
  369. self.assertEqual(dumps(128), '02020080'.decode('hex'))
  370. self.assertEqual(dumps(256), '02020100'.decode('hex'))
  371. self.assertEqual(dumps(False), '010100'.decode('hex'))
  372. self.assertEqual(dumps(True), '0101ff'.decode('hex'))
  373. self.assertEqual(dumps(None), '0500'.decode('hex'))
  374. self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))
  375. def test_fuzzing(self):
  376. genfailures(float(1))
  377. genfailures([ 1, 2, 'sdlkfj' ])
  378. genfailures({ 1: 2, 5: 'sdlkfj' })
  379. genfailures(set([ 1, 2, 'sdlkfj' ]))
  380. def test_consume(self):
  381. b = dumps(5)
  382. self.assertRaises(ValueError, loads, b + '398473', consume=True)
  383. # XXX - still possible that an internal data member
  384. # doesn't consume all
  385. # XXX - test that sets are ordered properly
  386. # XXX - test that dicts are ordered properly..
  387. def test_nan(self):
  388. s = dumps(float('nan'))
  389. v = loads(s)
  390. self.assertTrue(math.isnan(v))
  391. def test_invalids(self):
  392. # Add tests for base 8, 16 floats among others
  393. for v in [ '010101',
  394. '0903040001', # float scaling factor
  395. '0903840001', # float scaling factor
  396. '0903100001', # float base
  397. '0903900001', # float base
  398. '0903000001', # float decimal encoding
  399. '0903830001', # float exponent encoding
  400. '3007020101020102040673646c6b666a', # list short string still valid
  401. 'c007020101020102020105040673646c6b666a', # dict short value still valid
  402. '181632303136303231353038343031362e3539303839305a', #datetime w/ trailing zero
  403. ]:
  404. self.assertRaises(ValueError, loads, v.decode('hex'))
  405. def test_cryptoutilasn1(self):
  406. '''Test DER sequences generated by Crypto.Util.asn1.'''
  407. for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
  408. ('\x05\x00', None),
  409. ('\x02\x03\x00\x96I', 38473),
  410. ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
  411. ]:
  412. self.assertEqual(loads(s), v)
  413. def test_longstrings(self):
  414. for i in (203, 65484):
  415. s = os.urandom(i)
  416. v = dumps(s)
  417. self.assertEqual(loads(v), s)
  418. def test_invaliddate(self):
  419. pass
  420. # XXX - add test to reject datetime w/ tzinfo, or that it handles it
  421. # properly
  422. def test_dumps(self):
  423. for i in [ None,
  424. True, False,
  425. -1, 0, 1, 255, 256, -255, -256, 23498732498723, -2398729387234, (1<<2383) + 23984734, (-1<<1983) + 23984723984,
  426. float(0), float('-0'), float('inf'), float('-inf'), float(1.0), float(-1.0),
  427. float('353.3487'), float('2387.23873e492'), float('2387.348732e-392'),
  428. float('.15625'),
  429. 'weoifjwef',
  430. u'\U0001f4a9',
  431. [], [ 1,2,3 ],
  432. {}, { 5: 10, 'adfkj': 34 },
  433. set(), set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
  434. datetime.datetime.utcnow(), datetime.datetime.utcnow().replace(microsecond=0),
  435. datetime.datetime.utcnow().replace(microsecond=1000),
  436. ]:
  437. s = dumps(i)
  438. o = loads(s)
  439. self.assertEqual(i, o)
  440. tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
  441. out = dumps(tobj)
  442. self.assertEqual(tobj, loads(out))
  443. def test_coerce(self):
  444. class Foo:
  445. pass
  446. class Bar:
  447. pass
  448. class Baz:
  449. pass
  450. def coerce(obj):
  451. if isinstance(obj, Foo):
  452. return 'list', obj.lst
  453. elif isinstance(obj, Baz):
  454. return 'bytes', obj.s
  455. raise TypeError('unknown type')
  456. ac = ASN1Coder(coerce)
  457. v = [1, 2, 3]
  458. o = Foo()
  459. o.lst = v
  460. self.assertEqual(ac.loads(ac.dumps(o)), v)
  461. self.assertRaises(TypeError, ac.dumps, Bar())
  462. v = u'oiejfd'
  463. o = Baz()
  464. o.s = v
  465. es = ac.dumps(o)
  466. self.assertEqual(ac.loads(es), v)
  467. self.assertIsInstance(es, bytes)
  468. def test_loads(self):
  469. self.assertRaises(ValueError, loads, '\x00\x02\x00')