diff --git a/lark/parsers/earley_forest.py b/lark/parsers/earley_forest.py index c1b2b82..d8f1395 100644 --- a/lark/parsers/earley_forest.py +++ b/lark/parsers/earley_forest.py @@ -133,7 +133,8 @@ class PackedNode(ForestNode): return [x for x in [self.left, self.right] if x is not None] def __iter__(self): - return iter([self.left, self.right]) + yield self.left + yield self.right def __eq__(self, other): if not isinstance(other, PackedNode): @@ -194,7 +195,7 @@ class ForestVisitor(object): # a list of nodes that are currently being visited # used for the `on_cycle` callback - path = list() + path = [] # We do not use recursion here to walk the Forest due to the limited # stack size in python. Therefore input_stack is essentially our stack. @@ -385,7 +386,8 @@ class ForestSumVisitor(ForestVisitor): final tree. """ def visit_packed_node_in(self, node): - return iter([node.left, node.right]) + yield node.left + yield node.right def visit_symbol_node_in(self, node): return iter(node.children) @@ -448,7 +450,7 @@ class ForestToParseTree(ForestTransformer): def _check_cycle(self, node): if self._on_cycle_retreat: - raise Discard + raise Discard() def _collapse_ambig(self, children): new_children = [] @@ -473,7 +475,7 @@ class ForestToParseTree(ForestTransformer): return self.tree_class('_ambig', data) elif data: return data[0] - raise Discard + raise Discard() def transform_symbol_node(self, node, data): self._check_cycle(node) @@ -489,7 +491,7 @@ class ForestToParseTree(ForestTransformer): def transform_packed_node(self, node, data): self._check_cycle(node) - children = list() + children = [] assert len(data) <= 2 data = PackedData(node, data) if data.left is not None: @@ -568,7 +570,7 @@ class TreeForestTransformer(ForestToParseTree): return self.tree_class('_ambig', data) elif data: return data[0] - raise Discard + raise Discard() def __default_token__(self, node): """Default operation on Token (for override). @@ -631,7 +633,8 @@ class ForestToPyDotVisitor(ForestVisitor): graph_node_shape = "diamond" graph_node = self.pydot.Node(graph_node_id, style=graph_node_style, fillcolor="#{:06x}".format(graph_node_color), shape=graph_node_shape, label=graph_node_label) self.graph.add_node(graph_node) - return iter([node.left, node.right]) + yield node.left + yield node.right def visit_packed_node_out(self, node): graph_node_id = str(id(node)) diff --git a/tests/test_tree_forest_transformer.py b/tests/test_tree_forest_transformer.py index e7ca56b..35c5614 100644 --- a/tests/test_tree_forest_transformer.py +++ b/tests/test_tree_forest_transformer.py @@ -46,99 +46,82 @@ class TestTreeForestTransformer(unittest.TestCase): def test_token_calls(self): - visited_A = False - visited_B = False - visited_C = False - visited_D = False + visited = [False] * 4 class CustomTransformer(TreeForestTransformer): def A(self, node): assert node.type == 'A' - nonlocal visited_A - visited_A = True + visited[0] = True def B(self, node): assert node.type == 'B' - nonlocal visited_B - visited_B = True + visited[1] = True def C(self, node): assert node.type == 'C' - nonlocal visited_C - visited_C = True + visited[2] = True def D(self, node): assert node.type == 'D' - nonlocal visited_D - visited_D = True + visited[3] = True tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest) - self.assertTrue(visited_A) - self.assertTrue(visited_B) - self.assertTrue(visited_C) - self.assertTrue(visited_D) + assert visited == [True] * 4 def test_default_token(self): - token_count = 0 + token_count = [0] class CustomTransformer(TreeForestTransformer): def __default_token__(self, node): - nonlocal token_count - token_count += 1 + token_count[0] += 1 assert isinstance(node, Token) tree = CustomTransformer(resolve_ambiguity=True).transform(self.forest) - self.assertEqual(token_count, 4) + self.assertEqual(token_count[0], 4) def test_rule_calls(self): - visited_start = False - visited_ab = False - visited_bc = False - visited_cd = False + visited_start = [False] + visited_ab = [False] + visited_bc = [False] + visited_cd = [False] class CustomTransformer(TreeForestTransformer): def start(self, data): - nonlocal visited_start - visited_start = True + visited_start[0] = True def ab(self, data): - nonlocal visited_ab - visited_ab = True + visited_ab[0] = True def bc(self, data): - nonlocal visited_bc - visited_bc = True + visited_bc[0] = True def cd(self, data): - nonlocal visited_cd - visited_cd = True + visited_cd[0] = True tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest) - self.assertTrue(visited_start) - self.assertTrue(visited_ab) - self.assertTrue(visited_bc) - self.assertTrue(visited_cd) + self.assertTrue(visited_start[0]) + self.assertTrue(visited_ab[0]) + self.assertTrue(visited_bc[0]) + self.assertTrue(visited_cd[0]) def test_default_rule(self): - rule_count = 0 + rule_count = [0] class CustomTransformer(TreeForestTransformer): def __default__(self, name, data): - nonlocal rule_count - rule_count += 1 + rule_count[0] += 1 tree = CustomTransformer(resolve_ambiguity=True).transform(self.forest) - self.assertEqual(rule_count, 4) + self.assertEqual(rule_count[0], 4) def test_default_ambig(self): - ambig_count = 0 + ambig_count = [0] class CustomTransformer(TreeForestTransformer): def __default_ambig__(self, name, data): - nonlocal ambig_count if len(data) > 1: - ambig_count += 1 + ambig_count[0] += 1 tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest) - self.assertEqual(ambig_count, 1) + self.assertEqual(ambig_count[0], 1) def test_handles_ambiguity(self): @@ -164,10 +147,10 @@ class TestTreeForestTransformer(unittest.TestCase): class CustomTransformer(TreeForestTransformer): def bc(self, data): - raise Discard + raise Discard() def D(self, node): - raise Discard + raise Discard() class TreeChecker(Transformer): def bc(self, children): @@ -181,8 +164,8 @@ class TestTreeForestTransformer(unittest.TestCase): def test_aliases(self): - visited_ambiguous = False - visited_full = False + visited_ambiguous = [False] + visited_full = [False] class CustomTransformer(TreeForestTransformer): @handles_ambiguity @@ -191,8 +174,7 @@ class TestTreeForestTransformer(unittest.TestCase): assert tree.data == 'ambiguous' or tree.data == 'full' def ambiguous(self, data): - nonlocal visited_ambiguous - visited_ambiguous = True + visited_ambiguous[0] = True assert len(data) == 3 assert data[0].data == 'ab' assert data[1].data == 'bc' @@ -200,8 +182,7 @@ class TestTreeForestTransformer(unittest.TestCase): return self.tree_class('ambiguous', data) def full(self, data): - nonlocal visited_full - visited_full = True + visited_full[0] = True assert len(data) == 1 assert data[0].data == 'abcd' return self.tree_class('full', data) @@ -218,8 +199,8 @@ class TestTreeForestTransformer(unittest.TestCase): l = Lark(grammar, parser='earley', ambiguity='forest') forest = l.parse('ABCD') tree = CustomTransformer(resolve_ambiguity=False).transform(forest) - self.assertTrue(visited_ambiguous) - self.assertTrue(visited_full) + self.assertTrue(visited_ambiguous[0]) + self.assertTrue(visited_full[0]) def test_transformation(self):