diff --git a/lark/load_grammar.py b/lark/load_grammar.py index 16f69ac..080cc37 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -525,8 +525,6 @@ class Grammar: # ================= # Compile Rules # ================= - - # TODO: add templates # 1. Pre-process terminals transformer = PrepareLiterals() * PrepareSymbols() * PrepareAnonTerminals(terminals) # Adds to terminals @@ -640,7 +638,7 @@ def import_from_grammar_into_namespace(grammar, namespace, aliases): imported_terms = dict(grammar.term_defs) imported_rules = {n:(n,deepcopy(t),o) for n,t,o in grammar.rule_defs} - imported_temps = {n:(n,deepcopy(t),o) for n,t,o in grammar.temp_defs} + imported_temps = {n:(n,p,deepcopy(t),o) for n,p,t,o in grammar.temp_defs} term_defs = [] rule_defs = [] @@ -649,12 +647,13 @@ def import_from_grammar_into_namespace(grammar, namespace, aliases): def rule_dependencies(symbol): if symbol.type != 'RULE': return [] - try: - _, tree, _ = imported_rules[symbol] - except KeyError: + if symbol in imported_rules: + return _find_used_symbols(imported_rules[symbol][1]) + elif symbol in imported_temps: + return _find_used_symbols(imported_temps[symbol][2]) - set(imported_temps[symbol][1]) + else: raise GrammarError("Missing symbol '%s' in grammar %s" % (symbol, namespace)) - return _find_used_symbols(tree) def get_namespace_name(name): @@ -671,14 +670,24 @@ def import_from_grammar_into_namespace(grammar, namespace, aliases): term_defs.append([get_namespace_name(symbol), imported_terms[symbol]]) else: assert symbol.type == 'RULE' - rule = imported_rules[symbol] - for t in rule[1].iter_subtrees(): - for i, c in enumerate(t.children): - if isinstance(c, Token) and c.type in ('RULE', 'TERMINAL'): - t.children[i] = Token(c.type, get_namespace_name(c)) - rule_defs.append((get_namespace_name(symbol), rule[1], rule[2])) - - return term_defs, rule_defs + if symbol in imported_rules: + rule = imported_rules[symbol] + for t in rule[1].iter_subtrees(): + for i, c in enumerate(t.children): + if isinstance(c, Token) and c.type in ('RULE', 'TERMINAL'): + t.children[i] = Token(c.type, get_namespace_name(c)) + rule_defs.append((get_namespace_name(symbol), rule[1], rule[2])) + else: + temp = imported_temps[symbol] + for t in temp[2].iter_subtrees(): + for i, c in enumerate(t.children): + if isinstance(c, Token) and c.type in ('RULE', 'TERMINAL'): + t.children[i] = Token(c.type, get_namespace_name(c)) + params = [('%s__%s' if p[0]!='_' else '_%s__%s' ) % (namespace, p) for p in temp[1]] + temp_defs.append((get_namespace_name(symbol), params, temp[2], temp[3])) + + + return term_defs, rule_defs, temp_defs diff --git a/tests/grammars/templates.lark b/tests/grammars/templates.lark new file mode 100644 index 0000000..1631188 --- /dev/null +++ b/tests/grammars/templates.lark @@ -0,0 +1 @@ +sep{item, delim}: item (delim item)* \ No newline at end of file diff --git a/tests/test_parser.py b/tests/test_parser.py index a664bcc..5a0313d 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -835,6 +835,13 @@ def _make_parser_test(LEXER, PARSER): x = g.parse("[1]") self.assertSequenceEqual(x.children,['1']) + def test_templates_import(self): + g = _Lark_open("test_templates_import.lark", rel_to=__file__) + x = g.parse("[1, 2, 3, 4]") + self.assertSequenceEqual(x.children,['1', '2', '3', '4']) + x = g.parse("[1]") + self.assertSequenceEqual(x.children,['1']) + def test_token_collision_WS(self): g = _Lark(r"""start: "Hello" NAME NAME: /\w/+ diff --git a/tests/test_templates_import.lark b/tests/test_templates_import.lark new file mode 100644 index 0000000..a1272b8 --- /dev/null +++ b/tests/test_templates_import.lark @@ -0,0 +1,4 @@ +start: "[" sep{NUMBER, ","} "]" +NUMBER: /\d+/ +%ignore " " +%import .grammars.templates.sep \ No newline at end of file