diff --git a/lark/parse_tree_builder.py b/lark/parse_tree_builder.py
index e0041b9..de88cd1 100644
--- a/lark/parse_tree_builder.py
+++ b/lark/parse_tree_builder.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 from .common import is_terminal, GrammarError
 from .utils import suppress
 from .lexer import Token
@@ -5,15 +7,7 @@ from .grammar import Rule
 
 ###{standalone
 
-class NodeBuilder:
-    def __init__(self, tree_class, name):
-        self.tree_class = tree_class
-        self.name = name
-
-    def __call__(self, children):
-        return self.tree_class(self.name, children)
-
-class Expand1:
+class ExpandSingleChild:
     def __init__(self, node_builder):
         self.node_builder = node_builder
 
@@ -23,57 +17,17 @@ class Expand1:
         else:
             return self.node_builder(children)
 
-class Factory:
-    def __init__(self, cls, *args):
-        self.cls = cls
-        self.args = args
-
-    def __call__(self, node_builder):
-        return self.cls(node_builder, *self.args)
-
-class TokenWrapper:
+class CreateToken:
     "Used for fixing the results of scanless parsing"
 
-    def __init__(self, node_builder, token_name):
+    def __init__(self, token_name, node_builder):
         self.node_builder = node_builder
         self.token_name = token_name
 
     def __call__(self, children):
         return self.node_builder( [Token(self.token_name, ''.join(children))] )
 
-def identity(node_builder):
-    return node_builder
-
-
-class ChildFilter:
-    def __init__(self, node_builder, to_include):
-        self.node_builder = node_builder
-        self.to_include = to_include
-
-    def __call__(self, children):
-        filtered = []
-        for i, to_expand in self.to_include:
-            if to_expand:
-                filtered += children[i].children
-            else:
-                filtered.append(children[i])
-
-        return self.node_builder(filtered)
-
-def create_rule_handler(expansion, keep_all_tokens, filter_out):
-    # if not keep_all_tokens:
-    to_include = [(i, not is_terminal(sym) and sym.startswith('_'))
-                  for i, sym in enumerate(expansion)
-                  if keep_all_tokens
-                  or not ((is_terminal(sym) and sym.startswith('_')) or sym in filter_out)
-                  ]
-
-    if len(to_include) < len(expansion) or any(to_expand for i, to_expand in to_include):
-        return Factory(ChildFilter, to_include)
-
-    # else, if no filtering required..
-    return identity
 
 class PropagatePositions:
     def __init__(self, node_builder):
@@ -98,6 +52,31 @@ class PropagatePositions:
 
         return res
 
+class ChildFilter:
+    def __init__(self, to_include, node_builder):
+        self.node_builder = node_builder
+        self.to_include = to_include
+
+    def __call__(self, children):
+        filtered = []
+        for i, to_expand in self.to_include:
+            if to_expand:
+                filtered += children[i].children
+            else:
+                filtered.append(children[i])
+
+        return self.node_builder(filtered)
+
+def _should_expand(sym):
+    return not is_terminal(sym) and sym.startswith('_')
+
+def maybe_create_child_filter(expansion, filter_out):
+    to_include = [(i, _should_expand(sym)) for i, sym in enumerate(expansion) if sym not in filter_out]
+
+    if len(to_include) < len(expansion) or any(to_expand for i, to_expand in to_include):
+        return partial(ChildFilter, to_include)
+
+
 class Callback(object):
     pass
 
@@ -112,22 +91,20 @@ class ParseTreeBuilder:
         self.user_aliases = {}
 
     def _init_builders(self, rules):
-        filter_out = set()
-        for rule in rules:
-            if rule.options and rule.options.filter_out:
-                assert rule.origin.startswith('_')      # Just to make sure
-                filter_out.add(rule.origin)
+        filter_out = {rule.origin for rule in rules if rule.options and rule.options.filter_out}
+        filter_out |= {sym for rule in rules for sym in rule.expansion if is_terminal(sym) and sym.startswith('_')}
+        assert all(x.startswith('_') for x in filter_out)
 
         for rule in rules:
             options = rule.options
             keep_all_tokens = self.always_keep_all_tokens or (options.keep_all_tokens if options else False)
-            expand1 = options.expand1 if options else False
+            expand_single_child = options.expand1 if options else False
             create_token = options.create_token if options else False
 
             wrapper_chain = filter(None, [
-                (expand1 and not rule.alias) and Expand1,
-                create_token and Factory(TokenWrapper, create_token),
-                create_rule_handler(rule.expansion, keep_all_tokens, filter_out),
+                create_token and partial(CreateToken, create_token),
+                (expand_single_child and not rule.alias) and ExpandSingleChild,
+                maybe_create_child_filter(rule.expansion, () if keep_all_tokens else filter_out),
                 self.propagate_positions and PropagatePositions,
             ])
 
@@ -144,7 +121,7 @@ class ParseTreeBuilder:
             try:
                 f = transformer._get_func(user_callback_name)
             except AttributeError:
-                f = NodeBuilder(self.tree_class, user_callback_name)
+                f = partial(self.tree_class, user_callback_name)
 
             self.user_aliases[rule] = rule.alias
             rule.alias = internal_callback_name
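
Reviewer note: below is a minimal sketch (not part of the patch) of how the partial-based wrappers are intended to compose. Each entry in wrapper_chain takes the callback built so far and returns a new callable around it, which is why partial(CreateToken, create_token) and partial(ChildFilter, to_include) pre-bind everything except the final node_builder argument. The application loop and the Tree stand-in are assumptions for illustration; the loop that actually applies wrapper_chain lives in the unchanged callback-creation code and is not shown in this diff.

    from functools import partial

    class Tree:
        # Stand-in for lark's tree class, only here to make the sketch runnable.
        def __init__(self, name, children):
            self.name, self.children = name, children
        def __repr__(self):
            return 'Tree(%r, %r)' % (self.name, self.children)

    class ExpandSingleChild:
        # Same shape as the wrapper in the patch: wraps a node_builder callable.
        def __init__(self, node_builder):
            self.node_builder = node_builder
        def __call__(self, children):
            return children[0] if len(children) == 1 else self.node_builder(children)

    # Innermost builder, same shape as partial(self.tree_class, user_callback_name).
    callback = partial(Tree, 'expr')

    # Hypothetical chain; the patch builds one with filter(None, [...]) per rule.
    wrapper_chain = [ExpandSingleChild]
    for wrapper in wrapper_chain:
        callback = wrapper(callback)

    print(callback(['a', 'b']))   # Tree('expr', ['a', 'b'])
    print(callback(['a']))        # 'a'  (single child expanded in place)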