|
- import base64
- import copy
- import datetime
- import itertools
- import json
- import unittest
- import uuid
-
- from .utils import _makeuuid, _makedatetime, _makebytes, _asn1coder
-
class _JSONEncoder(json.JSONEncoder):
    '''JSON encoder that additionally understands UUIDs, datetimes and
    bytes.'''

    def default(self, o):
        '''Return a JSON-serializable stand-in for o.

        - uuid.UUID: its canonical string form.
        - datetime.datetime: converted to UTC and rendered as
          YYYY-MM-DDTHH:MM:SS.ffffffZ.  Naive datetimes are treated as
          local time by astimezone().
        - bytes: URL-safe base64 text.

        Any other type is delegated to json.JSONEncoder.default, which
        raises TypeError.'''

        if isinstance(o, uuid.UUID):
            return str(o)

        if isinstance(o, datetime.datetime):
            as_utc = o.astimezone(datetime.timezone.utc)
            return as_utc.strftime('%Y-%m-%dT%H:%M:%S.%fZ')

        if isinstance(o, bytes):
            b64 = base64.urlsafe_b64encode(o)
            return b64.decode('US-ASCII')

        return super().default(o)

# Shared module-level encoder instance.
_jsonencoder = _JSONEncoder()
-
class _TestJSONEncoder(unittest.TestCase):
    def test_defaultfailure(self):
        '''A value of an unsupported type must make encode() raise
        TypeError.'''

        class Unserializable:
            pass

        with self.assertRaises(TypeError):
            _jsonencoder.encode(Unserializable())
-
# XXX - add validation
# XXX - how to add singletons
class MDBase(object):
    '''This is a simple wrapper that turns a JSON object into a Pythonic
    object where attribute accesses work.

    Each concrete subclass sets _type to the value of the "type" property
    it represents.  Use MDBase.create_obj (or decode) to build an instance
    of the subclass that matches an object's "type".'''

    _type = 'invalid'

    # Properties that __init__ fills in, by calling the factory, when the
    # source object does not already provide them.
    _generated_properties = {
        'uuid': uuid.uuid4,
        'modified': lambda: datetime.datetime.now(
            tz=datetime.timezone.utc),
    }

    # When decoding, the decoded value should be passed to this function
    # to get the correct type
    _instance_properties = {
        'uuid': _makeuuid,
        'modified': _makedatetime,
        'created_by_ref': _makeuuid,
        'parent_refs': lambda x: [ _makeuuid(y) for y in x ],
        'sig': _makebytes,
    }

    # Override on a per subclass basis
    _class_instance_properties = {
    }

    # Properties that must be present when an object is constructed.
    _common_properties = [ 'type', 'created_by_ref' ] # XXX - add lang?
    # Common properties that may be absent; new_version() treats them as
    # single-valued.
    _common_optional = set(('parent_refs', 'sig'))
    # Required plus generated property names.  items() skips these by
    # default, and new_version() treats them as single-valued.
    _common_names = set(_common_properties + list(
        _generated_properties.keys()))
    # The same names as _common_names, as an ordered list.
    _common_names_list = _common_properties + list(
        _generated_properties.keys())

    def __init__(self, obj=None, **kwargs):
        '''Wrap a deep copy of the mapping obj, updated with kwargs.

        The caller's obj is never modified.  A missing "type" is filled in
        from the class's _type.  Values of known properties are converted
        with the functions in _instance_properties and
        _class_instance_properties.  Missing generated properties (uuid,
        modified) are created.

        Raises ValueError if called on MDBase itself, if obj's "type" does
        not match this class, or if a common property is missing.'''

        # None instead of a shared mutable {} default; an empty mapping is
        # built fresh for each call.
        obj = {} if obj is None else copy.deepcopy(obj)
        obj.update(kwargs)

        if self._type == MDBase._type:
            raise ValueError('call MDBase.create_obj instead so correct class is used.')

        if 'type' in obj and obj['type'] != self._type:
            raise ValueError(
                'trying to create the wrong type of object, got: %s, expected: %s' %
                (repr(obj['type']), repr(self._type)))

        if 'type' not in obj:
            obj['type'] = self._type

        for x in self._common_properties:
            if x not in obj:
                raise ValueError('common property %s not present' % repr(x))

        for x, fun in itertools.chain(
                self._instance_properties.items(),
                self._class_instance_properties.items()):
            if x in obj:
                obj[x] = fun(obj[x])

        for x, fun in self._generated_properties.items():
            if x not in obj:
                obj[x] = fun()

        self._obj = obj

    @staticmethod
    def _all_subclasses():
        '''Yield every subclass of MDBase, breadth-first, each only once.

        Direct subclasses come before any of their descendants, so a
        direct subclass always wins over a descendant with the same
        _type.'''

        pending = list(MDBase.__subclasses__())
        seen = set()
        while pending:
            klass = pending.pop(0)
            # A class reachable through several bases (diamond
            # inheritance) is only reported once.
            if klass in seen:
                continue
            seen.add(klass)
            yield klass
            pending.extend(klass.__subclasses__())

    @classmethod
    def create_obj(cls, obj):
        '''Using obj as a base, create an instance of MDBase of the
        correct type.

        obj may be a mapping with a "type" key or an existing MDBase
        instance of any subclass; the subclass whose _type equals that
        "type" is instantiated with it.

        If the correct type is not found, a ValueError is raised.'''

        # Unwrap any MDBase instance, not only instances of cls.  Otherwise
        # an instance of a sibling subclass would be handed to __init__,
        # which deep-copies it and then calls .update() on it.
        if isinstance(obj, MDBase):
            obj = obj._obj

        ty = obj['type']

        for klass in MDBase._all_subclasses():
            if klass._type == ty:
                return klass(obj)

        raise ValueError('Unable to find class for type %s' %
            repr(ty))

    def new_version(self, *args, dels=(), replaces=()):
        '''For each k, v pair, add the property k as an additional one
        (or new one if first), with the value v.

        Common properties (and parent_refs/sig) are single-valued: they
        are overwritten instead of appended to.  Multi-valued properties
        are de-duplicated, keeping the first occurrence of each value.

        Any key in dels is removed.

        Any k, v pair in replaces, replaces the entire key.

        The "modified" timestamp is dropped, so the new object gets a
        fresh one.  The new object is returned; self is not changed.'''

        obj = copy.deepcopy(self._obj)

        common = self._common_names | self._common_optional
        uniquify = set()
        for k, v in args:
            if k in common:
                obj[k] = v
            else:
                uniquify.add(k)
                obj.setdefault(k, []).append(v)

        for k in uniquify:
            # dict.fromkeys de-duplicates while keeping first-seen order,
            # so the list order does not depend on set iteration order.
            obj[k] = list(dict.fromkeys(obj[k]))

        for i in dels:
            del obj[i]

        for k, v in replaces:
            obj[k] = v

        del obj['modified']

        return self.create_obj(obj)

    def __repr__(self): # pragma: no cover
        return '%s(%s)' % (self.__class__.__name__, repr(self._obj))

    def __getattr__(self, k):
        '''Expose the properties of the wrapped object as attributes.'''

        # Read _obj from the instance dict directly.  Using self._obj here
        # would re-enter __getattr__ and recurse forever on an instance
        # that has no _obj yet, such as the bare objects copy/pickle build
        # with __new__ before restoring state.
        try:
            obj = self.__dict__['_obj']
        except KeyError:
            raise AttributeError(k) from None

        try:
            return obj[k]
        except KeyError:
            raise AttributeError(k) from None

    def __setattr__(self, k, v):
        '''Names starting with "_" are real attributes; anything else is
        stored as a property of the wrapped object.'''

        if k.startswith('_'): # direct attribute
            self.__dict__[k] = v
        else:
            self._obj[k] = v

    def __getitem__(self, k):
        return self._obj[k]

    def __to_dict__(self):
        '''Returns an internal object. If modification is necessary,
        make sure to .copy() it first.'''

        return self._obj

    def __eq__(self, o):
        # When o is another MDBase, dict.__eq__ returns NotImplemented and
        # Python falls back to o.__eq__, so two wrappers compare by their
        # contents.  Defining __eq__ without __hash__ leaves instances
        # unhashable.
        return self._obj == o

    def __contains__(self, k):
        return k in self._obj

    def items(self, skipcommon=True):
        '''Return the (key, value) pairs of the wrapped object as a list.

        When skipcommon is true (the default), keys in _common_names are
        left out.'''

        return [ (k, v) for k, v in self._obj.items() if
            not skipcommon or k not in self._common_names ]

    def encode(self, meth='asn1'):
        '''Serialize this object.

        meth 'asn1' (the default) uses _asn1coder; any other value
        produces JSON of the wrapped dict via _jsonencoder.'''

        if meth == 'asn1':
            # NOTE(review): passes the wrapper itself, not self._obj;
            # presumably _asn1coder knows how to unwrap MDBase — confirm.
            return _asn1coder.dumps(self)

        return _jsonencoder.encode(self._obj)

    @classmethod
    def decode(cls, s, meth='asn1'):
        '''Inverse of encode: parse s and return an instance of the
        subclass that matches its "type" (see create_obj).'''

        if meth == 'asn1':
            obj = _asn1coder.loads(s)
        else:
            obj = json.loads(s)

        return cls.create_obj(obj)
|