diff --git a/docs/classes.md b/docs/classes.md index 3a69f73..68c133c 100644 --- a/docs/classes.md +++ b/docs/classes.md @@ -76,7 +76,11 @@ Returns all nodes of the tree whose data equals the given data. #### iter_subtrees(self) -Iterates over all the subtrees, never returning to the same node twice (Lark's parse-tree is actually a DAG) +Iterates over all the subtrees, never returning to the same node twice (Lark's parse-tree is actually a DAG). + +#### iter_subtrees_topdown(self) + +Iterates over all the subtrees, return nodes in order like pretty() does. #### \_\_eq\_\_, \_\_hash\_\_ diff --git a/lark/tree.py b/lark/tree.py index c406f45..e892a44 100644 --- a/lark/tree.py +++ b/lark/tree.py @@ -5,6 +5,7 @@ except ImportError: from copy import deepcopy + ###{standalone class Meta: pass @@ -42,6 +43,7 @@ class Tree(object): def pretty(self, indent_str=' '): return ''.join(self._pretty(0, indent_str)) + def __eq__(self, other): try: return self.data == other.data and self.children == other.children @@ -99,12 +101,22 @@ class Tree(object): yield x seen.add(id(x)) + def iter_subtrees_topdown(self): + stack = [self] + while stack: + node = stack.pop() + if not isinstance(node, Tree): + continue + yield node + for n in reversed(node.children): + stack.append(n) def __deepcopy__(self, memo): return type(self)(self.data, deepcopy(self.children, memo)) def copy(self): return type(self)(self.data, self.children) + def set(self, data, children): self.data = data self.children = children diff --git a/tests/test_trees.py b/tests/test_trees.py index 564d02b..7e1d841 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -21,6 +21,17 @@ class TestTrees(TestCase): data = pickle.dumps(s) assert pickle.loads(data) == s + def test_iter_subtrees(self): + expected = [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z'), + Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')])] + nodes = list(self.tree1.iter_subtrees()) + self.assertEqual(nodes, expected) + + def test_iter_subtrees_topdown(self): + expected = [Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')]), + Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')] + nodes = list(self.tree1.iter_subtrees_topdown()) + self.assertEqual(nodes, expected) def test_interp(self): t = Tree('a', [Tree('b', []), Tree('c', []), 'd'])