from __future__ import absolute_import import unittest from unittest import TestCase import copy import pickle import functools from lark.tree import Tree from lark.lexer import Token from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard, Transformer_InPlace, \ Transformer_InPlaceRecursive, Transformer_NonRecursive class TestTrees(TestCase): def setUp(self): self.tree1 = Tree('a', [Tree(x, y) for x, y in zip('bcd', 'xyz')]) def test_deepcopy(self): assert self.tree1 == copy.deepcopy(self.tree1) def test_pickle(self): s = copy.deepcopy(self.tree1) data = pickle.dumps(s, protocol=pickle.HIGHEST_PROTOCOL) assert pickle.loads(data) == s def test_repr_runnable(self): assert self.tree1 == eval(repr(self.tree1)) def test_iter_subtrees(self): expected = [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z'), Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')])] nodes = list(self.tree1.iter_subtrees()) self.assertEqual(nodes, expected) def test_iter_subtrees_topdown(self): expected = [Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')]), Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')] nodes = list(self.tree1.iter_subtrees_topdown()) self.assertEqual(nodes, expected) def test_visitor(self): class Visitor1(Visitor): def __init__(self): self.nodes=[] def __default__(self,tree): self.nodes.append(tree) class Visitor1_Recursive(Visitor_Recursive): def __init__(self): self.nodes=[] def __default__(self,tree): self.nodes.append(tree) visitor1=Visitor1() visitor1_recursive=Visitor1_Recursive() expected_top_down = [Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')]), Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')] expected_botton_up= [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z'), Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')])] visitor1.visit(self.tree1) self.assertEqual(visitor1.nodes,expected_botton_up) visitor1_recursive.visit(self.tree1) self.assertEqual(visitor1_recursive.nodes,expected_botton_up) visitor1.nodes=[] visitor1_recursive.nodes=[] visitor1.visit_topdown(self.tree1) self.assertEqual(visitor1.nodes,expected_top_down) visitor1_recursive.visit_topdown(self.tree1) self.assertEqual(visitor1_recursive.nodes,expected_top_down) def test_interp(self): t = Tree('a', [Tree('b', []), Tree('c', []), 'd']) class Interp1(Interpreter): def a(self, tree): return self.visit_children(tree) + ['e'] def b(self, tree): return 'B' def c(self, tree): return 'C' self.assertEqual(Interp1().visit(t), list('BCde')) class Interp2(Interpreter): @visit_children_decor def a(self, values): return values + ['e'] def b(self, tree): return 'B' def c(self, tree): return 'C' self.assertEqual(Interp2().visit(t), list('BCde')) class Interp3(Interpreter): def b(self, tree): return 'B' def c(self, tree): return 'C' self.assertEqual(Interp3().visit(t), list('BCd')) def test_transformer(self): t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) class T(Transformer): i = v_args(inline=True)(int) f = v_args(inline=True)(float) sub = lambda self, values: values[0] - values[1] def add(self, values): return sum(values) res = T().transform(t) self.assertEqual(res, 2.9) @v_args(inline=True) class T(Transformer): i = int f = float sub = lambda self, a, b: a-b def add(self, a, b): return a + b res = T().transform(t) self.assertEqual(res, 2.9) @v_args(inline=True) class T(Transformer): i = int f = float from operator import sub, add res = T().transform(t) self.assertEqual(res, 2.9) def test_vargs(self): @v_args() class MyTransformer(Transformer): @staticmethod def integer(args): return 1 # some code here @classmethod def integer2(cls, args): return 2 # some code here hello = staticmethod(lambda args: 'hello') x = MyTransformer().transform( Tree('integer', [2])) self.assertEqual(x, 1) x = MyTransformer().transform( Tree('integer2', [2])) self.assertEqual(x, 2) x = MyTransformer().transform( Tree('hello', [2])) self.assertEqual(x, 'hello') def test_inline_static(self): @v_args(inline=True) class T(Transformer): @staticmethod def test(a, b): return a + b x = T().transform(Tree('test', ['a', 'b'])) self.assertEqual(x, 'ab') def test_vargs_override(self): t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) @v_args(inline=True) class T(Transformer): i = int f = float sub = lambda self, a, b: a-b not_a_method = {'other': 'stuff'} @v_args(inline=False) def add(self, values): return sum(values) res = T().transform(t) self.assertEqual(res, 2.9) def test_partial(self): tree = Tree("start", [Tree("a", ["test1"]), Tree("b", ["test2"])]) def test(prefix, s, postfix): return prefix + s.upper() + postfix @v_args(inline=True) class T(Transformer): a = functools.partial(test, "@", postfix="!") b = functools.partial(lambda s: s + "!") res = T().transform(tree) assert res.children == ["@TEST1!", "test2!"] def test_discard(self): class MyTransformer(Transformer): def a(self, args): return 1 # some code here def b(cls, args): raise Discard() t = Tree('root', [ Tree('b', []), Tree('a', []), Tree('b', []), Tree('c', []), Tree('b', []), ]) t2 = Tree('root', [1, Tree('c', [])]) x = MyTransformer().transform( t ) self.assertEqual(x, t2) def test_transformer_variants(self): tree = Tree('start', [Tree('add', [Token('N', '1'), Token('N', '2')]), Tree('add', [Token('N', '3'), Token('N', '4')])]) for base in (Transformer, Transformer_InPlace, Transformer_NonRecursive, Transformer_InPlaceRecursive): class T(base): def add(self, children): return sum(children) def N(self, token): return int(token) copied = copy.deepcopy(tree) result = T().transform(copied) self.assertEqual(result, Tree('start', [3, 7])) if __name__ == '__main__': unittest.main()