diff --git a/lark/grammar.py b/lark/grammar.py index cf8cf64..bb84351 100644 --- a/lark/grammar.py +++ b/lark/grammar.py @@ -49,19 +49,21 @@ class NonTerminal(Symbol): class RuleOptions(Serialize): - __serialize_fields__ = 'keep_all_tokens', 'expand1', 'priority', 'empty_indices' + __serialize_fields__ = 'keep_all_tokens', 'expand1', 'priority', 'template_source', 'empty_indices' - def __init__(self, keep_all_tokens=False, expand1=False, priority=None, empty_indices=()): + def __init__(self, keep_all_tokens=False, expand1=False, priority=None, template_source=None, empty_indices=()): self.keep_all_tokens = keep_all_tokens self.expand1 = expand1 self.priority = priority + self.template_source = template_source self.empty_indices = empty_indices def __repr__(self): - return 'RuleOptions(%r, %r, %r)' % ( + return 'RuleOptions(%r, %r, %r, %r)' % ( self.keep_all_tokens, self.expand1, self.priority, + self.template_source ) diff --git a/lark/load_grammar.py b/lark/load_grammar.py index c1ba95f..3deb758 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -383,17 +383,6 @@ class ApplyTemplates(Transformer_InPlace): result_tree = deepcopy(tree) self.replacer.names = dict(zip(params, args)) self.replacer.transform(result_tree) - if name[0] != '_': - if result_tree.data == 'expansions': - t = result_tree - while len(t.children) == 2: - if t.children[-1].data != 'alias': - t.children[-1] = ST('alias', [t.children[-1], name]) - t = t.children[0] - if t.children[-1].data != 'alias': - t.children[-1] = ST('alias', [t.children[-1], name]) - elif result_tree.data != 'alias': - result_tree = ST('alias', [result_tree, name]) self.rule_defs.append((result_name, [], result_tree, deepcopy(options))) return NonTerminal(result_name) @@ -736,7 +725,8 @@ def options_from_rule(name, params, *x): expand1 = name.startswith('?') name = name.lstrip('?') - return name, params, expansions, RuleOptions(keep_all_tokens, expand1, priority=priority) + return name, params, expansions, RuleOptions(keep_all_tokens, expand1, priority=priority, + template_source=(name if params else None)) def symbols_from_strcase(expansion): diff --git a/lark/parse_tree_builder.py b/lark/parse_tree_builder.py index 11c7fac..4a9edd3 100644 --- a/lark/parse_tree_builder.py +++ b/lark/parse_tree_builder.py @@ -227,9 +227,10 @@ class ParseTreeBuilder: options = rule.options keep_all_tokens = self.always_keep_all_tokens or options.keep_all_tokens expand_single_child = options.expand1 + from_template = options.template_source is not None wrapper_chain = list(filter(None, [ - (expand_single_child and not rule.alias) and ExpandSingleChild, + (expand_single_child and not (rule.alias and not from_template)) and ExpandSingleChild, maybe_create_child_filter(rule.expansion, keep_all_tokens, self.ambiguous, options.empty_indices if self.maybe_placeholders else None), self.propagate_positions and PropagatePositions, self.ambiguous and maybe_create_ambiguous_expander(self.tree_class, rule.expansion, keep_all_tokens), @@ -243,7 +244,7 @@ class ParseTreeBuilder: for rule, wrapper_chain in self.rule_builders: - user_callback_name = rule.alias or rule.origin.name + user_callback_name = rule.alias or rule.options.template_source or rule.origin.name try: f = getattr(transformer, user_callback_name) # XXX InlineTransformer is deprecated! diff --git a/tests/test_parser.py b/tests/test_parser.py index 6ac4a98..6b9df3f 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -863,6 +863,38 @@ def _make_parser_test(LEXER, PARSER): x = g.parse("[1]") self.assertSequenceEqual(x.children, [Tree('sep', ['1'])]) + def test_templates_alias(self): + g = _Lark(r""" + start: expr{"C"} + expr{t}: "A" t + | "B" t -> b + """) + x = g.parse("AC") + self.assertSequenceEqual(x.children, [Tree('expr', [])]) + x = g.parse("BC") + self.assertSequenceEqual(x.children, [Tree('b', [])]) + + def test_templates_modifiers(self): + g = _Lark(r""" + start: expr{"B"} + !expr{t}: "A" t + """) + x = g.parse("AB") + self.assertSequenceEqual(x.children, [Tree('expr', ["A", "B"])]) + g = _Lark(r""" + start: _expr{"B"} + !_expr{t}: "A" t + """) + x = g.parse("AB") + self.assertSequenceEqual(x.children, ["A", "B"]) + g = _Lark(r""" + start: expr{b} + b: "B" + ?expr{t}: "A" t + """) + x = g.parse("AB") + self.assertSequenceEqual(x.children, [Tree('b',[])]) + def test_g_regex_flags(self): g = _Lark(""" start: "a" /b+/ C