From b272dde726b1207ecf7c8a0688f223992058a847 Mon Sep 17 00:00:00 2001 From: Chanic Panic Date: Mon, 28 Dec 2020 20:42:20 -0800 Subject: [PATCH] Add cache to ForestToParseTree --- lark/parsers/earley_forest.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/lark/parsers/earley_forest.py b/lark/parsers/earley_forest.py index de92fb5..f39f4eb 100644 --- a/lark/parsers/earley_forest.py +++ b/lark/parsers/earley_forest.py @@ -499,14 +499,18 @@ class ForestToParseTree(ForestTransformer): resolve_ambiguity: If True, ambiguities will be resolved based on priorities. Otherwise, `_ambig` nodes will be in the resulting tree. + use_cache: If True, the results of packed node transformations will be + cached. """ - def __init__(self, tree_class=Tree, callbacks=dict(), prioritizer=ForestSumVisitor(), resolve_ambiguity=True): + def __init__(self, tree_class=Tree, callbacks=dict(), prioritizer=ForestSumVisitor(), resolve_ambiguity=True, use_cache=True): super(ForestToParseTree, self).__init__() self.tree_class = tree_class self.callbacks = callbacks self.prioritizer = prioritizer self.resolve_ambiguity = resolve_ambiguity + self._use_cache = use_cache + self._cache = {} self._on_cycle_retreat = False self._cycle_node = None self._successful_visits = set() @@ -515,6 +519,7 @@ class ForestToParseTree(ForestTransformer): if self.prioritizer: self.prioritizer.visit(root) super(ForestToParseTree, self).visit(root) + self._cache = {} def on_cycle(self, node, path): logger.debug("Cycle encountered in the SPPF at node: %s. " @@ -578,6 +583,8 @@ class ForestToParseTree(ForestTransformer): self._check_cycle(node) if self.resolve_ambiguity and id(node.parent) in self._successful_visits: raise Discard() + if self._use_cache and id(node) in self._cache: + return self._cache[id(node)] children = [] assert len(data) <= 2 data = PackedData(node, data) @@ -589,8 +596,8 @@ class ForestToParseTree(ForestTransformer): if data.right is not PackedData.NO_DATA: children.append(data.right) if node.parent.is_intermediate: - return children - return self._call_rule_func(node, children) + return self._cache.setdefault(id(node), children) + return self._cache.setdefault(id(node), self._call_rule_func(node, children)) def visit_symbol_node_in(self, node): super(ForestToParseTree, self).visit_symbol_node_in(node) @@ -602,7 +609,8 @@ class ForestToParseTree(ForestTransformer): self._on_cycle_retreat = False to_visit = super(ForestToParseTree, self).visit_packed_node_in(node) if not self.resolve_ambiguity or id(node.parent) not in self._successful_visits: - return to_visit + if not self._use_cache or id(node) not in self._cache: + return to_visit def visit_packed_node_out(self, node): super(ForestToParseTree, self).visit_packed_node_out(node) @@ -647,10 +655,14 @@ class TreeForestTransformer(ForestToParseTree): nodes in the SPPF. :param resolve_ambiguity: If True, ambiguities will be resolved based on priorities. + :param use_cache: If True, caches the results of some transformations, + potentially improving performance when ``resolve_ambiguity==False``. + Only use if you know what you are doing: i.e. All transformation + functions are pure and referentially transparent. """ - def __init__(self, tree_class=Tree, prioritizer=ForestSumVisitor(), resolve_ambiguity=True): - super(TreeForestTransformer, self).__init__(tree_class, dict(), prioritizer, resolve_ambiguity) + def __init__(self, tree_class=Tree, prioritizer=ForestSumVisitor(), resolve_ambiguity=True, use_cache=False): + super(TreeForestTransformer, self).__init__(tree_class, dict(), prioritizer, resolve_ambiguity, use_cache) def __default__(self, name, data): """Default operation on tree (for override).