|
- from __future__ import absolute_import
-
- import unittest
-
- from lark import Lark
- from lark.lexer import Token
- from lark.tree import Tree
- from lark.visitors import Visitor, Transformer, Discard
- from lark.parsers.earley_forest import TreeForestTransformer, handles_ambiguity
-
class TestTreeForestTransformer(unittest.TestCase):
    """Tests for ``TreeForestTransformer`` driven by a small ambiguous grammar.

    In the shared grammar the optional "B" may be claimed by ``ab`` or ``bc``
    and the optional "C" by ``bc`` or ``cd``, so the input "ABCD" has four
    derivations of ``start``.  The forest for that input is parsed once at
    class-definition time and reused by every test.

    Several tests record callback invocations in small lists (e.g.
    ``[False]`` or ``[0]``); the lists act as mutable cells that the nested
    transformer methods can update without ``nonlocal``.
    """

    grammar = """
    start: ab bc cd
    !ab: "A" "B"?
    !bc: "B"? "C"?
    !cd: "C"? "D"
    """

    parser = Lark(grammar, parser='earley', ambiguity='forest')
    forest = parser.parse("ABCD")

    def test_identity_resolve_ambiguity(self):
        """An un-subclassed transformer with ``resolve_ambiguity=True`` must
        produce the same tree as parsing with ``ambiguity='resolve'``."""
        reference_parser = Lark(self.grammar, parser='earley', ambiguity='resolve')
        expected = reference_parser.parse("ABCD")
        actual = TreeForestTransformer(resolve_ambiguity=True).transform(self.forest)
        self.assertEqual(expected, actual)

    def test_identity_explicit_ambiguity(self):
        """An un-subclassed transformer with ``resolve_ambiguity=False`` must
        produce the same tree as parsing with ``ambiguity='explicit'``."""
        reference_parser = Lark(self.grammar, parser='earley', ambiguity='explicit')
        expected = reference_parser.parse("ABCD")
        actual = TreeForestTransformer(resolve_ambiguity=False).transform(self.forest)
        self.assertEqual(expected, actual)

    def test_tree_class(self):
        """Every tree visited in the result must be an instance of the
        ``tree_class`` passed to the transformer."""

        class CustomTree(Tree):
            pass

        class TreeChecker(Visitor):
            def __default__(self, tree):
                assert isinstance(tree, CustomTree)

        tree = TreeForestTransformer(resolve_ambiguity=False, tree_class=CustomTree).transform(self.forest)
        TreeChecker().visit(tree)

    def test_token_calls(self):
        """Methods named after a token type must each be called with a token
        of that type."""

        # One flag per token type, in the order A, B, C, D.
        visited = [False] * 4

        class CustomTransformer(TreeForestTransformer):
            def A(self, node):
                assert node.type == 'A'
                visited[0] = True

            def B(self, node):
                assert node.type == 'B'
                visited[1] = True

            def C(self, node):
                assert node.type == 'C'
                visited[2] = True

            def D(self, node):
                assert node.type == 'D'
                visited[3] = True

        CustomTransformer(resolve_ambiguity=False).transform(self.forest)
        # unittest assertion rather than a bare assert: consistent with the
        # sibling tests, not stripped under -O, and reports both values.
        self.assertEqual(visited, [True] * 4)

    def test_default_token(self):
        """``__default_token__`` must receive ``Token`` instances; with
        ambiguity resolved it is expected to run four times for "ABCD"."""

        token_count = [0]

        class CustomTransformer(TreeForestTransformer):
            def __default_token__(self, node):
                token_count[0] += 1
                assert isinstance(node, Token)

        CustomTransformer(resolve_ambiguity=True).transform(self.forest)
        self.assertEqual(token_count[0], 4)

    def test_rule_calls(self):
        """Methods named after a rule (``start``, ``ab``, ``bc``, ``cd``)
        must each be called at least once."""

        visited_start = [False]
        visited_ab = [False]
        visited_bc = [False]
        visited_cd = [False]

        class CustomTransformer(TreeForestTransformer):
            def start(self, data):
                visited_start[0] = True

            def ab(self, data):
                visited_ab[0] = True

            def bc(self, data):
                visited_bc[0] = True

            def cd(self, data):
                visited_cd[0] = True

        CustomTransformer(resolve_ambiguity=False).transform(self.forest)
        self.assertTrue(visited_start[0])
        self.assertTrue(visited_ab[0])
        self.assertTrue(visited_bc[0])
        self.assertTrue(visited_cd[0])

    def test_default_rule(self):
        """With ambiguity resolved, ``__default__`` is expected to run four
        times (the grammar has the rules ``start``, ``ab``, ``bc``, ``cd``)."""

        rule_count = [0]

        class CustomTransformer(TreeForestTransformer):
            def __default__(self, name, data):
                rule_count[0] += 1

        CustomTransformer(resolve_ambiguity=True).transform(self.forest)
        self.assertEqual(rule_count[0], 4)

    def test_default_ambig(self):
        """``__default_ambig__`` is expected to be handed more than one
        alternative exactly once for this forest."""

        ambig_count = [0]

        class CustomTransformer(TreeForestTransformer):
            def __default_ambig__(self, name, data):
                # Only count calls that actually carry several alternatives.
                if len(data) > 1:
                    ambig_count[0] += 1

        CustomTransformer(resolve_ambiguity=False).transform(self.forest)
        self.assertEqual(ambig_count[0], 1)

    def test_handles_ambiguity(self):
        """A rule method decorated with ``@handles_ambiguity`` must receive
        the list of alternative trees, and the value returned by the
        ``start`` handler must become the transform result."""

        class CustomTransformer(TreeForestTransformer):
            @handles_ambiguity
            def start(self, data):
                assert isinstance(data, list)
                assert len(data) == 4
                for tree in data:
                    assert tree.data == 'start'
                return 'handled'

            @handles_ambiguity
            def ab(self, data):
                assert isinstance(data, list)
                assert len(data) == 1
                assert data[0].data == 'ab'

        tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
        self.assertEqual(tree, 'handled')

    def test_discard(self):
        """Nodes whose callbacks raise ``Discard`` must be absent from the
        result: the checker fails if it meets a ``bc`` rule or a ``D`` token."""

        class CustomTransformer(TreeForestTransformer):
            def bc(self, data):
                raise Discard()

            def D(self, node):
                raise Discard()

        class TreeChecker(Transformer):
            def bc(self, children):
                assert False

            def D(self, token):
                assert False

        tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
        TreeChecker(visit_tokens=True).transform(tree)

    def test_aliases(self):
        """Callbacks must be looked up by alias name: both the ``ambiguous``
        and the ``full`` alternatives of ``start`` have to be visited."""

        visited_ambiguous = [False]
        visited_full = [False]

        class CustomTransformer(TreeForestTransformer):
            @handles_ambiguity
            def start(self, data):
                for tree in data:
                    assert tree.data == 'ambiguous' or tree.data == 'full'

            def ambiguous(self, data):
                visited_ambiguous[0] = True
                assert len(data) == 3
                assert data[0].data == 'ab'
                assert data[1].data == 'bc'
                assert data[2].data == 'cd'
                return self.tree_class('ambiguous', data)

            def full(self, data):
                visited_full[0] = True
                assert len(data) == 1
                assert data[0].data == 'abcd'
                return self.tree_class('full', data)

        grammar = """
        start: ab bc cd -> ambiguous
            | abcd -> full
        !ab: "A" "B"?
        !bc: "B"? "C"?
        !cd: "C"? "D"
        !abcd: "ABCD"
        """

        alias_parser = Lark(grammar, parser='earley', ambiguity='forest')
        forest = alias_parser.parse('ABCD')
        CustomTransformer(resolve_ambiguity=False).transform(forest)
        self.assertTrue(visited_ambiguous[0])
        self.assertTrue(visited_full[0])

    def test_transformation(self):
        """Custom default callbacks can turn the forest into plain data:
        here the expected result is the flat list of lower-cased tokens."""

        class CustomTransformer(TreeForestTransformer):
            def __default__(self, name, data):
                # Flatten one level: child rules have already returned lists.
                result = []
                for item in data:
                    if isinstance(item, list):
                        result.extend(item)
                    else:
                        result.append(item)
                return result

            def __default_token__(self, node):
                return node.lower()

            def __default_ambig__(self, name, data):
                # Keep only the first alternative.
                return data[0]

        result = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
        expected = ['a', 'b', 'c', 'd']
        self.assertEqual(result, expected)
-
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|