try:
    # Python 2: make filter() return a lazy iterator, as on Python 3.
    from future_builtins import filter
except ImportError:
    pass

from copy import deepcopy

from .utils import inline_args


class Tree(object):
    """A parse-tree node.

    ``data`` is the node label; Transformer and Visitor look up a method
    with this name. ``children`` is a list whose items are either nested
    ``Tree`` instances or leaf values of any other type.
    """

    def __init__(self, data, children):
        self.data = data
        # Always store a fresh list so the caller's sequence is not aliased.
        self.children = list(children)

    def __repr__(self):
        return 'Tree(%s, %s)' % (self.data, self.children)

    def _pretty_label(self):
        """Return the label that pretty() prints for this node."""
        return self.data

    def _pretty(self, level, indent_str):
        """Return the list of string fragments that pretty() joins.

        A node whose only child is a leaf is printed on one line as
        ``label<TAB>leaf``. Otherwise the label gets its own line and each
        child follows, indented one level deeper.
        """
        if len(self.children) == 1 and not isinstance(self.children[0], Tree):
            # Wrap the leaf in a 1-tuple: a bare ``'%s' % leaf`` raises
            # TypeError (or prints the wrong thing) when the leaf is a tuple.
            return [indent_str * level, self._pretty_label(), '\t',
                    '%s' % (self.children[0],), '\n']

        parts = [indent_str * level, self._pretty_label(), '\n']
        for child in self.children:
            if isinstance(child, Tree):
                parts += child._pretty(level + 1, indent_str)
            else:
                parts += [indent_str * (level + 1), '%s' % (child,), '\n']
        return parts

    def pretty(self, indent_str=' '):
        """Return an indented, multi-line string rendering of the tree.

        :param indent_str: string repeated once per nesting level.
        """
        return ''.join(self._pretty(0, indent_str))

    def expand_kids_by_index(self, *indices):
        """Splice the children of the child at each index into this node.

        Each listed child (expected to be a Tree) is replaced in place by
        its own children. Indices refer to positions before any splicing.
        """
        # Highest index first, so splicing never shifts a not-yet-processed index.
        for i in sorted(indices, reverse=True):
            kid = self.children[i]
            self.children[i:i + 1] = kid.children

    def __eq__(self, other):
        try:
            return self.data == other.data and self.children == other.children
        except AttributeError:
            # ``other`` lacks data/children, so it is not a comparable tree.
            return False

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # NOTE: the hash is derived from the current (mutable) contents, so a
        # Tree mutated after being used as a set member/dict key is misplaced.
        return hash((self.data, tuple(self.children)))

    def find_pred(self, pred):
        """Return an iterable of the subtrees (from iter_subtrees) satisfying pred."""
        return filter(pred, self.iter_subtrees())

    def find_data(self, data):
        """Return an iterable of the subtrees whose ``data`` equals ``data``."""
        return self.find_pred(lambda t: t.data == data)

    def scan_values(self, pred):
        """Yield every leaf value in the tree for which ``pred`` is true.

        Leaves are visited depth-first, left to right.
        """
        for c in self.children:
            if isinstance(c, Tree):
                for t in c.scan_values(pred):
                    yield t
            else:
                if pred(c):
                    yield c

    def iter_subtrees(self):
        """Yield each distinct subtree (including self) exactly once.

        Distinctness is by object identity. In a tree without shared nodes,
        every subtree is yielded after all of its descendants, so the root
        comes last.
        """
        # TODO: Re-write as a more efficient version
        visited = set()
        q = [self]

        order = []
        while q:
            subtree = q.pop()
            order.append(subtree)
            if id(subtree) in visited:
                continue    # already been here from another branch
            visited.add(id(subtree))
            q += [c for c in subtree.children if isinstance(c, Tree)]

        seen = set()
        for x in reversed(order):
            if id(x) not in seen:
                yield x
                seen.add(id(x))

    def __deepcopy__(self, memo):
        return type(self)(self.data, deepcopy(self.children, memo))

    def copy(self):
        """Shallow copy: a new node and children list; the child objects are shared."""
        return type(self)(self.data, self.children)

    def set(self, data, children):
        """Replace this node's data and children in place.

        Unlike __init__, ``children`` is stored as given, not copied.
        """
        self.data = data
        self.children = children


class Transformer(object):
    """Rebuild a tree bottom-up by calling a method named after each node's data."""

    def _get_func(self, name):
        """Return the callback for ``name``; raises AttributeError if absent."""
        return getattr(self, name)

    def transform(self, tree):
        """Recursively transform ``tree`` and return the result.

        Tree children are transformed first; any child whose transformation
        raises Discard is dropped. The callback named ``tree.data`` then
        receives the list of transformed children. If there is no such
        callback, ``__default__(data, children)`` is used instead.
        """
        items = []
        for c in tree.children:
            try:
                items.append(self.transform(c) if isinstance(c, Tree) else c)
            except Discard:
                pass
        try:
            f = self._get_func(tree.data)
        except AttributeError:
            return self.__default__(tree.data, items)
        else:
            # Called outside the try so AttributeErrors raised inside the
            # callback itself are not mistaken for a missing callback.
            return f(items)

    def __default__(self, data, children):
        """Fallback when no callback exists: rebuild a Tree node."""
        return Tree(data, children)

    def __mul__(self, other):
        """``a * b`` returns a chain that applies ``a`` and then ``b``."""
        return TransformerChain(self, other)


class Discard(Exception):
    """Raise from a transformer callback to drop that node from its parent."""
    pass


class TransformerChain(object):
    """Apply several transformers in sequence, each to the previous result."""

    def __init__(self, *transformers):
        self.transformers = transformers

    def transform(self, tree):
        for t in self.transformers:
            tree = t.transform(tree)
        return tree

    def __mul__(self, other):
        return TransformerChain(*self.transformers + (other,))


class InlineTransformer(Transformer):
    """Transformer whose callbacks are wrapped with ``utils.inline_args``.

    NOTE(review): presumably the wrapper passes children as separate
    positional arguments; confirm against utils.inline_args.
    """

    def _get_func(self, name):
        # TODO: could reuse super()._get_func(name) for the lookup.
        return inline_args(getattr(self, name)).__get__(self)


class Visitor(object):
    """Walk a tree, calling a method named after each node's data (children first)."""

    def visit(self, tree):
        """Visit ``tree`` recursively, children before parent; return ``tree``."""
        for child in tree.children:
            if isinstance(child, Tree):
                self.visit(child)

        f = getattr(self, tree.data, self.__default__)
        f(tree)
        return tree

    def __default__(self, tree):
        """Called for nodes without a matching method; does nothing."""
        pass


class Visitor_NoRecurse(Visitor):
    """Visitor that iterates with iter_subtrees() instead of recursing."""

    def visit(self, tree):
        subtrees = list(tree.iter_subtrees())

        for subtree in subtrees:
            getattr(self, subtree.data, self.__default__)(subtree)
        return tree


class Transformer_NoRecurse(Transformer):
    """Non-recursive Transformer that rewrites the input tree IN PLACE.

    Unlike Transformer, callbacks (and __default__) receive the Tree node
    itself, whose children have already been transformed.
    """

    def transform(self, tree):
        subtrees = list(tree.iter_subtrees())

        def _t(t):
            # Assumes t is already transformed
            try:
                f = self._get_func(t.data)
            except AttributeError:
                return self.__default__(t)
            else:
                return f(t)

        # iter_subtrees yields descendants before ancestors, so each node's
        # Tree children are fully processed before the node itself.
        for subtree in subtrees:
            children = []
            for c in subtree.children:
                try:
                    children.append(_t(c) if isinstance(c, Tree) else c)
                except Discard:
                    pass
            # Mutates the caller's tree.
            subtree.children = children

        return _t(tree)

    def __default__(self, t):
        """Fallback when no callback exists: keep the node unchanged."""
        return t


def pydot__tree_to_png(tree, filename):
    """Render ``tree`` as a left-to-right directed graph and write it as PNG.

    Requires the third-party ``pydot`` package (imported lazily). Inner
    nodes are labelled with their data and filled with a colour derived
    from ``hash(data)``; leaves are labelled with ``repr(leaf)``.
    """
    import pydot
    graph = pydot.Dot(graph_type='digraph', rankdir="LR")

    # Mutable counter (list cell) giving every pydot node a unique id; a
    # list is used because nested functions cannot rebind it on Python 2.
    i = [0]

    def new_leaf(leaf):
        node = pydot.Node(i[0], label=repr(leaf))
        i[0] += 1
        graph.add_node(node)
        return node

    def _to_pydot(subtree):
        color = hash(subtree.data) & 0xffffff
        # Force every channel >= 0x80 so fills stay light and labels readable.
        color |= 0x808080

        subnodes = [_to_pydot(child) if isinstance(child, Tree) else new_leaf(child)
                    for child in subtree.children]
        node = pydot.Node(i[0], style="filled", fillcolor="#%x" % color, label=subtree.data)
        i[0] += 1
        graph.add_node(node)

        for subnode in subnodes:
            graph.add_edge(pydot.Edge(node, subnode))

        return node

    _to_pydot(tree)
    graph.write_png(filename)