A pure Python ASN.1 library. Supports dict and sets.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

709 lines
18 KiB

  1. #!/usr/bin/env python
  2. '''A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit
  3. of pickle.
  4. The default dumps/loads uses a profile of ASN.1 that supports serialization
  5. of key/value pairs. This is non-standard. Instantiate the class ASN1Coder
  6. to get a pure ASN.1 serializer/deserializer.
  7. All lengths must be specified. That is that End-of-contents octets
  8. MUST NOT be used. The shortest form of length encoding MUST be used.
  9. A longer length encoding MUST be rejected.'''
  10. __author__ = 'John-Mark Gurney'
  11. __copyright__ = 'Copyright 2016-2020 John-Mark Gurney. All rights reserved.'
  12. __license__ = '2-clause BSD license'
  13. # Copyright 2016-2020, John-Mark Gurney
  14. # All rights reserved.
  15. #
  16. # Redistribution and use in source and binary forms, with or without
  17. # modification, are permitted provided that the following conditions are met:
  18. #
  19. # 1. Redistributions of source code must retain the above copyright notice, this
  20. # list of conditions and the following disclaimer.
  21. # 2. Redistributions in binary form must reproduce the above copyright notice,
  22. # this list of conditions and the following disclaimer in the documentation
  23. # and/or other materials provided with the distribution.
  24. #
  25. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  26. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  27. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  28. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
  29. # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  30. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  31. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  32. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  33. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  34. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  35. #
  36. # The views and conclusions contained in the software and documentation are those
  37. # of the authors and should not be interpreted as representing official policies,
  38. # either expressed or implied, of the Project.
  39. import datetime
  40. import functools
  41. import math
  42. import os
  43. import sys
  44. import unittest
# Refuse to load on anything but Python 3; the helpers below rely on
# Python 3 only APIs such as bytes.fromhex.
if sys.version_info.major != 3: # pragma: no cover
    raise RuntimeError('this module only supports python 3')

# Names exported by a star import of this module.
__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]
  48. def _numtobytes(n):
  49. hs = '%x' % n
  50. if len(hs) & 1 == 1:
  51. hs = '0' + hs
  52. bs = bytes.fromhex(hs)
  53. return bs
  54. def _encodelen(l):
  55. '''Takes l as a length value, and returns a byte string that
  56. represents l per ASN.1 rules.'''
  57. if l < 128:
  58. return bytes([l])
  59. bs = _numtobytes(l)
  60. return bytes([len(bs) | 0x80]) + bs
  61. def _decodelen(d, pos=0):
  62. '''Returns the length, and number of bytes required.'''
  63. odp = d[pos]
  64. if odp < 128:
  65. return d[pos], 1
  66. else:
  67. l = odp & 0x7f
  68. return int(d[pos + 1:pos + 1 + l].hex(), 16), l + 1
  69. class Test_codelen(unittest.TestCase):
  70. _testdata = [
  71. (2, b'\x02'),
  72. (127, b'\x7f'),
  73. (128, b'\x81\x80'),
  74. (255, b'\x81\xff'),
  75. (256, b'\x82\x01\x00'),
  76. (65536-1, b'\x82\xff\xff'),
  77. (65536, b'\x83\x01\x00\x00'),
  78. ]
  79. def test_el(self):
  80. for i, j in self._testdata:
  81. self.assertEqual(_encodelen(i), j)
  82. self.assertEqual(_decodelen(j), (i, len(j)))
  83. def _splitfloat(f):
  84. m, e = math.frexp(f)
  85. # XXX - less than ideal
  86. while m != math.trunc(m):
  87. m *= 2
  88. e -= 1
  89. return math.trunc(m), e
  90. class TestSplitFloat(unittest.TestCase):
  91. def test_sf(self):
  92. for a, b in [ (0x2421, -32), (0x5382f, 238),
  93. (0x1fa8c3b094adf1, 971) ]:
  94. self.assertEqual(_splitfloat(a * 2**b), (a, b))
  95. class ASN1Coder(object):
  96. '''A class that contains an PASN.1 encoder/decoder.
  97. Exports two methods, loads and dumps.'''
  98. def __init__(self, coerce=None):
  99. '''If the arg coerce is provided, when dumping the object,
  100. if the type is not found, the coerce function will be called
  101. with the obj. It is expected to return a tuple of a string
  102. and an object that has the method w/ the string as defined:
  103. 'bool': __nonzero__
  104. 'float': compatible w/ float
  105. 'int': compatible w/ int
  106. 'list': __iter__
  107. 'set': __iter__
  108. 'bytes': __str__ # XXX what is correct here
  109. 'null': no method needed
  110. 'unicode': encode method returns UTF-8 encoded bytes
  111. 'datetime': strftime and microsecond
  112. '''
  113. self.coerce = coerce
  114. _typemap = {
  115. bool: 'bool',
  116. float: 'float',
  117. int: 'int',
  118. list: 'list',
  119. set: 'set',
  120. bytes: 'bytes',
  121. type(None): 'null',
  122. str: 'unicode',
  123. #decimal.Decimal: 'float',
  124. datetime.datetime: 'datetime',
  125. #datetime.timedelta: 'timedelta',
  126. }
  127. _tagmap = {
  128. b'\x01': 'bool',
  129. b'\x02': 'int',
  130. b'\x04': 'bytes',
  131. b'\x05': 'null',
  132. b'\x09': 'float',
  133. b'\x0c': 'unicode',
  134. b'\x18': 'datetime',
  135. b'\x30': 'list',
  136. b'\x31': 'set',
  137. }
  138. _typetag = dict((v, k) for k, v in _tagmap.items())
  139. @staticmethod
  140. def enc_int(obj, **kwargs):
  141. l = obj.bit_length()
  142. l += 1 # space for sign bit
  143. l = (l + 7) // 8
  144. if obj < 0:
  145. obj += 1 << (l * 8) # twos-complement conversion
  146. v = _numtobytes(obj)
  147. if len(v) != l:
  148. # XXX - is this a problem for signed values?
  149. v = b'\x00' + v # add sign octect
  150. return _encodelen(l) + v
  151. @staticmethod
  152. def dec_int(d, pos, end):
  153. if pos == end:
  154. return 0, end
  155. v = int(bytes.hex(d[pos:end]), 16)
  156. av = 1 << ((end - pos) * 8 - 1) # sign bit
  157. if v > av:
  158. v -= av * 2 # twos-complement conversion
  159. return v, end
  160. @staticmethod
  161. def enc_bool(obj, **kwargs):
  162. return b'\x01' + (b'\xff' if obj else b'\x00')
  163. def dec_bool(self, d, pos, end):
  164. v = self.dec_int(d, pos, end)[0]
  165. if v not in (-1, 0):
  166. raise ValueError('invalid bool value: %d' % v)
  167. return bool(v), end
  168. @staticmethod
  169. def enc_null(obj, **kwargs):
  170. return b'\x00'
  171. @staticmethod
  172. def dec_null(d, pos, end):
  173. return None, end
  174. def enc_list(self, obj, **kwargs):
  175. r = b''.join(self.dumps(x, **kwargs) for x in obj)
  176. return _encodelen(len(r)) + r
  177. def dec_list(self, d, pos, end):
  178. r = []
  179. vend = pos
  180. while pos < end:
  181. v, vend = self._loads(d, pos, end)
  182. if vend > end:
  183. raise ValueError('load past end')
  184. r.append(v)
  185. pos = vend
  186. return r, vend
  187. enc_set = enc_list
  188. def dec_set(self, d, pos, end):
  189. r, end = self.dec_list(d, pos, end)
  190. return set(r), end
  191. @staticmethod
  192. def enc_bytes(obj, **kwargs):
  193. return _encodelen(len(obj)) + bytes(obj)
  194. @staticmethod
  195. def dec_bytes(d, pos, end):
  196. return d[pos:end], end
  197. @staticmethod
  198. def enc_unicode(obj, **kwargs):
  199. encobj = obj.encode('utf-8')
  200. return _encodelen(len(encobj)) + encobj
  201. def dec_unicode(self, d, pos, end):
  202. return d[pos:end].decode('utf-8'), end
  203. @staticmethod
  204. def enc_float(obj, **kwargs):
  205. s = math.copysign(1, obj)
  206. if math.isnan(obj):
  207. return _encodelen(1) + bytes([0b01000010])
  208. elif math.isinf(obj):
  209. if s == 1:
  210. return _encodelen(1) + bytes([0b01000000])
  211. else:
  212. return _encodelen(1) + bytes([0b01000001])
  213. elif obj == 0:
  214. if s == 1:
  215. return _encodelen(0)
  216. else:
  217. return _encodelen(1) + bytes([0b01000011])
  218. m, e = _splitfloat(obj)
  219. # Binary encoding
  220. val = 0x80
  221. if m < 0:
  222. val |= 0x40
  223. m = -m
  224. # Base 2
  225. el = (e.bit_length() + 7 + 1) // 8 # + 1 is sign bit
  226. if el > 2:
  227. raise ValueError('exponent too large')
  228. if e < 0:
  229. e += 256**el # convert negative to twos-complement
  230. v = el - 1
  231. encexp = _numtobytes(e)
  232. val |= v
  233. r = bytes([val]) + encexp + _numtobytes(m)
  234. return _encodelen(len(r)) + r
  235. def dec_float(self, d, pos, end):
  236. if pos == end:
  237. return float(0), end
  238. v = d[pos]
  239. if v == 0b01000000:
  240. return float('inf'), end
  241. elif v == 0b01000001:
  242. return float('-inf'), end
  243. elif v == 0b01000010:
  244. return float('nan'), end
  245. elif v == 0b01000011:
  246. return float('-0'), end
  247. elif v & 0b110000:
  248. raise ValueError('base must be 2')
  249. elif v & 0b1100:
  250. raise ValueError('scaling factor must be 0')
  251. elif v & 0b11000000 == 0:
  252. raise ValueError('decimal encoding not supported')
  253. #elif v & 0b11000000 == 0b01000000:
  254. # raise ValueError('invalid encoding')
  255. if (v & 3) >= 2:
  256. raise ValueError('large exponents not supported')
  257. pexp = pos + 1
  258. eexp = pos + 1 + (v & 3) + 1
  259. exp = self.dec_int(d, pexp, eexp)[0]
  260. n = float(int(bytes.hex(d[eexp:end]), 16))
  261. r = n * 2 ** exp
  262. if v & 0b1000000:
  263. r = -r
  264. return r, end
  265. def dumps(self, obj, default=None):
  266. '''Convert obj into an array of bytes.
  267. ``default(obj)`` is a function that should return a
  268. serializable version of obj or raise TypeError. The
  269. default simply raises TypeError.
  270. '''
  271. try:
  272. tf = self._typemap[type(obj)]
  273. except KeyError:
  274. if default is not None:
  275. try:
  276. return self.dumps(default(obj), default=default)
  277. except TypeError:
  278. pass
  279. if self.coerce is None:
  280. raise TypeError('unhandled object: %s' % repr(obj))
  281. tf, obj = self.coerce(obj)
  282. fun = getattr(self, 'enc_%s' % tf)
  283. return self._typetag[tf] + fun(obj, default=default)
  284. def _loads(self, data, pos, end):
  285. tag = data[pos:pos + 1]
  286. l, b = _decodelen(data, pos + 1)
  287. if len(data) < pos + 1 + b + l:
  288. raise ValueError('string not long enough')
  289. # XXX - enforce that len(data) == end?
  290. end = pos + 1 + b + l
  291. t = self._tagmap[tag]
  292. fun = getattr(self, 'dec_%s' % t)
  293. return fun(data, pos + 1 + b, end)
  294. def enc_datetime(self, obj, **kwargs):
  295. ts = obj.strftime('%Y%m%d%H%M%S')
  296. if obj.microsecond:
  297. ts += ('.%06d' % obj.microsecond).rstrip('0')
  298. ts += 'Z'
  299. return _encodelen(len(ts)) + ts.encode('utf-8')
  300. def dec_datetime(self, data, pos, end):
  301. ts = data[pos:end].decode('ascii')
  302. if ts[-1:] != 'Z':
  303. raise ValueError('last character must be Z, was: %s' % repr(ts[-1]))
  304. # Real bug is in strptime, but work around it here.
  305. if ' ' in ts:
  306. raise ValueError('no spaces are allowed')
  307. if '.' in ts:
  308. fstr = '%Y%m%d%H%M%S.%fZ'
  309. if ts.endswith('0Z'):
  310. raise ValueError('invalid trailing zeros')
  311. else:
  312. fstr = '%Y%m%d%H%M%SZ'
  313. return datetime.datetime.strptime(ts, fstr), end
  314. def loads(self, data, pos=0, end=None, consume=False):
  315. '''Load from data, starting at pos (optional), and ending
  316. at end (optional). If it is required to consume the
  317. whole string (not the default), set consume to True, and
  318. a ValueError will be raised if the string is not
  319. completely consumed. The second item in ValueError will
  320. be the possition that was the detected end.'''
  321. if end is None:
  322. end = len(data)
  323. r, e = self._loads(data, pos, end)
  324. if consume and e != end:
  325. raise ValueError('entire string not consumed', e)
  326. return r
  327. class ASN1DictCoder(ASN1Coder):
  328. '''This adds support for the non-standard dict serialization.
  329. The coerce method also supports the following type:
  330. 'dict': iteritems
  331. '''
  332. _typemap = ASN1Coder._typemap.copy()
  333. _typemap[dict] = 'dict'
  334. _tagmap = ASN1Coder._tagmap.copy()
  335. _tagmap[b'\xe0'] = 'dict'
  336. _typetag = dict((v, k) for k, v in _tagmap.items())
  337. def enc_dict(self, obj, **kwargs):
  338. #it = list(obj.iteritems())
  339. #it.sort()
  340. r = b''.join(self.dumps(k, **kwargs) + self.dumps(v, **kwargs) for k, v in
  341. obj.items())
  342. return _encodelen(len(r)) + r
  343. def dec_dict(self, d, pos, end):
  344. r = {}
  345. vend = pos
  346. while pos < end:
  347. k, kend = self._loads(d, pos, end)
  348. #if kend > end:
  349. # raise ValueError('key past end')
  350. v, vend = self._loads(d, kend, end)
  351. if vend > end:
  352. raise ValueError('value past end')
  353. r[k] = v
  354. pos = vend
  355. return r, vend
# Module level convenience API.  dumps and loads are bound methods of a
# shared ASN1DictCoder instance, so they support the non-standard dict
# serialization.
_coder = ASN1DictCoder()
dumps = _coder.dumps
loads = _coder.loads
  359. def cmp(a, b):
  360. return (a > b) - (a < b)
  361. def universal_cmp(a, b):
  362. # Because Python 3 sucks, make this function that
  363. # orders first based upon type, then on value.
  364. if type(a) == type(b):
  365. if isinstance(a, (tuple, list)):
  366. for a, b in zip(a, b):
  367. if a != b:
  368. return universal_cmp(a, b)
  369. #return cmp(len(a), len(b))
  370. else:
  371. return cmp(a, b)
  372. else:
  373. return id(type(a)) < id(type(b))
  374. def deeptypecmp(obj, o):
  375. #print('dtc:', repr(obj), repr(o))
  376. if type(obj) != type(o):
  377. return False
  378. if type(obj) is str:
  379. return True
  380. if type(obj) in (list, set):
  381. for i, j in zip(obj, o):
  382. if not deeptypecmp(i, j):
  383. return False
  384. if type(obj) in (dict,):
  385. itms = sorted(obj.items(), key=functools.cmp_to_key(universal_cmp))
  386. nitms = sorted(o.items(), key=functools.cmp_to_key(universal_cmp))
  387. for (k, v), (nk, nv) in zip(itms, nitms):
  388. if not deeptypecmp(k, nk):
  389. return False
  390. if not deeptypecmp(v, nv):
  391. return False
  392. return True
  393. class Test_deeptypecmp(unittest.TestCase):
  394. def test_true(self):
  395. for i in ((1,1), ('sldkfj', 'sldkfj')
  396. ):
  397. self.assertTrue(deeptypecmp(*i))
  398. def test_false(self):
  399. for i in (([[]], [{}]), ([1], ['str']), ([], set()),
  400. ({1: 2, 5: 'sdlkfj'}, {1: 2, 5: b'sdlkfj'}),
  401. ({1: 2, 'sdlkfj': 5}, {1: 2, b'sdlkfj': 5}),
  402. ):
  403. self.assertFalse(deeptypecmp(*i), '%s != %s' % (i[0], i[1]))
  404. def genfailures(obj):
  405. s = dumps(obj)
  406. for i in range(len(s)):
  407. for j in (bytes([x]) for x in range(256)):
  408. ts = s[:i] + j + s[i + 1:]
  409. if ts == s:
  410. continue
  411. try:
  412. o = loads(ts, consume=True)
  413. if o != obj or not deeptypecmp(o, obj):
  414. raise ValueError
  415. except (ValueError, KeyError, IndexError, TypeError):
  416. pass
  417. else: # pragma: no cover
  418. raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, s[i]))
  419. class TestCode(unittest.TestCase):
  420. def test_primv(self):
  421. self.assertEqual(dumps(-257), bytes.fromhex('0202feff'))
  422. self.assertEqual(dumps(-256), bytes.fromhex('0202ff00'))
  423. self.assertEqual(dumps(-255), bytes.fromhex('0202ff01'))
  424. self.assertEqual(dumps(-1), bytes.fromhex('0201ff'))
  425. self.assertEqual(dumps(5), bytes.fromhex('020105'))
  426. self.assertEqual(dumps(128), bytes.fromhex('02020080'))
  427. self.assertEqual(dumps(256), bytes.fromhex('02020100'))
  428. self.assertEqual(dumps(False), bytes.fromhex('010100'))
  429. self.assertEqual(dumps(True), bytes.fromhex('0101ff'))
  430. self.assertEqual(dumps(None), bytes.fromhex('0500'))
  431. self.assertEqual(dumps(.15625), bytes.fromhex('090380fb05'))
  432. def test_fuzzing(self):
  433. # Make sure that when a failure is detected here, that it
  434. # gets added to test_invalids, so that this function may be
  435. # disabled.
  436. genfailures(float(1))
  437. genfailures([ 1, 2, 'sdlkfj' ])
  438. genfailures({ 1: 2, 5: 'sdlkfj' })
  439. genfailures(set([ 1, 2, 'sdlkfj' ]))
  440. genfailures(True)
  441. genfailures(datetime.datetime.utcnow())
  442. def test_invalids(self):
  443. # Add tests for base 8, 16 floats among others
  444. for v in [ '010101',
  445. '0903040001', # float scaling factor
  446. '0903840001', # float scaling factor
  447. '0903100001', # float base
  448. '0903900001', # float base
  449. '0903000001', # float decimal encoding
  450. '0903830001', # float exponent encoding
  451. '090b827fffcc0df505d0fa58f7', # float large exponent
  452. '3007020101020102040673646c6b666a', # list short string still valid
  453. 'e007020101020102020105040673646c6b666a', # dict short value still valid
  454. '181632303136303231353038343031362e3539303839305a', #datetime w/ trailing zero
  455. '181632303136303231373136343034372e3035343433367a', #datetime w/ lower z
  456. '181632303136313220383031303933302e3931353133385a', #datetime w/ space
  457. ]:
  458. self.assertRaises(ValueError, loads, bytes.fromhex(v))
  459. def test_invalid_floats(self):
  460. import mock
  461. with mock.patch('math.frexp', return_value=(.87232, 1 << 23)):
  462. self.assertRaises(ValueError, dumps, 1.1)
  463. def test_consume(self):
  464. b = dumps(5)
  465. self.assertRaises(ValueError, loads, b + b'398473',
  466. consume=True)
  467. # XXX - still possible that an internal data member
  468. # doesn't consume all
  469. # XXX - test that sets are ordered properly
  470. # XXX - test that dicts are ordered properly..
  471. def test_nan(self):
  472. s = dumps(float('nan'))
  473. v = loads(s)
  474. self.assertTrue(math.isnan(v))
  475. def test_cryptoutilasn1(self):
  476. '''Test DER sequences generated by Crypto.Util.asn1.'''
  477. for s, v in [ (b'\x02\x03$\x8a\xf9', 2394873),
  478. (b'\x05\x00', None),
  479. (b'\x02\x03\x00\x96I', 38473),
  480. (b'\x04\x81\xc8' + b'\x00' * 200, b'\x00' * 200),
  481. ]:
  482. self.assertEqual(loads(s), v)
  483. def test_longstrings(self):
  484. for i in (203, 65484):
  485. s = os.urandom(i)
  486. v = dumps(s)
  487. self.assertEqual(loads(v), s)
  488. def test_invaliddate(self):
  489. pass
  490. # XXX - add test to reject datetime w/ tzinfo, or that it
  491. # handles it properly
  492. def test_dumps(self):
  493. for i in [ None,
  494. True, False,
  495. -1, 0, 1, 255, 256, -255, -256,
  496. 23498732498723, -2398729387234,
  497. (1<<2383) + 23984734, (-1<<1983) + 23984723984,
  498. float(0), float('-0'), float('inf'), float('-inf'),
  499. float(1.0), float(-1.0), float('353.3487'),
  500. float('2.38723873e+307'), float('2.387349e-317'),
  501. sys.float_info.max, sys.float_info.min,
  502. float('.15625'),
  503. 'weoifjwef',
  504. '\U0001f4a9',
  505. [], [ 1,2,3 ],
  506. {}, { 5: 10, 'adfkj': 34 },
  507. set(), set((1,2,3)),
  508. set((1,'sjlfdkj', None, float('inf'))),
  509. datetime.datetime.utcnow(),
  510. [ datetime.datetime.utcnow(), ' ' ],
  511. datetime.datetime.utcnow().replace(microsecond=0),
  512. datetime.datetime.utcnow().replace(microsecond=1000),
  513. ]:
  514. s = dumps(i)
  515. o = loads(s)
  516. self.assertEqual(i, o)
  517. tobj = { 1: 'dflkj', 5: 'sdlkfj', 'float': 1,
  518. 'largeint': 1<<342, 'list': [ 1, 2, 'str', 'str' ] }
  519. out = dumps(tobj)
  520. self.assertEqual(tobj, loads(out))
  521. def test_coerce(self):
  522. class Foo:
  523. pass
  524. class Bar:
  525. pass
  526. class Baz:
  527. pass
  528. def coerce(obj):
  529. if isinstance(obj, Foo):
  530. return 'list', obj.lst
  531. elif isinstance(obj, Baz):
  532. return 'bytes', obj.s
  533. raise TypeError('unknown type')
  534. ac = ASN1Coder(coerce)
  535. v = [1, 2, 3]
  536. o = Foo()
  537. o.lst = v
  538. self.assertEqual(ac.loads(ac.dumps(o)), v)
  539. self.assertRaises(TypeError, ac.dumps, Bar())
  540. v = b'oiejfd'
  541. o = Baz()
  542. o.s = v
  543. es = ac.dumps(o)
  544. self.assertEqual(ac.loads(es), v)
  545. self.assertIsInstance(es, bytes)
  546. self.assertRaises(TypeError, dumps, o)
  547. def test_loads(self):
  548. self.assertRaises(ValueError, loads, b'\x00\x02\x00')
  549. def test_nodict(self):
  550. '''Verify that ASN1Coder does not support dict.'''
  551. self.assertRaises(KeyError, ASN1Coder().loads, dumps({}))
  552. def test_dumps_default(self):
  553. '''Test that dumps supports the default method, and that
  554. it works.'''
  555. class Dummy(object):
  556. def somefun(self):
  557. return 5
  558. class Dummy2(object):
  559. def somefun(self):
  560. return [ Dummy() ]
  561. def deffun(obj):
  562. try:
  563. return obj.somefun()
  564. except Exception:
  565. raise TypeError
  566. self.assertEqual(dumps(5), dumps(Dummy(), default=deffun))
  567. # Make sure it works for the various containers
  568. self.assertEqual(dumps([5]), dumps([Dummy()], default=deffun))
  569. self.assertEqual(dumps({ 5: 5 }), dumps({ Dummy(): Dummy() },
  570. default=deffun))
  571. self.assertEqual(dumps([5]), dumps(Dummy2(), default=deffun))
  572. # Make sure that an error is raised when the function doesn't work
  573. self.assertRaises(TypeError, dumps, object(), default=deffun)