diff --git a/lark/load_grammar.py b/lark/load_grammar.py index dd3a27c..ae7ec32 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -307,6 +307,7 @@ class PrepareAnonTerminals(Transformer_InPlace): self.term_set = {td.name for td in self.terminals} self.term_reverse = {td.pattern: td for td in terminals} self.i = 0 + self.rule_options = None @inline_args @@ -351,7 +352,10 @@ class PrepareAnonTerminals(Transformer_InPlace): self.term_reverse[p] = termdef self.terminals.append(termdef) - return Terminal(term_name, filter_out=isinstance(p, PatternStr)) + filter_out = False if self.rule_options and self.rule_options.keep_all_tokens else isinstance(p, PatternStr) + + return Terminal(term_name, filter_out=filter_out) + class _ReplaceSymbols(Transformer_InPlace): " Helper for ApplyTemplates " @@ -541,7 +545,8 @@ class Grammar: # ================= # 1. Pre-process terminals - transformer = PrepareLiterals() * PrepareSymbols() * PrepareAnonTerminals(terminals) # Adds to terminals + anon_tokens_transf = PrepareAnonTerminals(terminals) + transformer = PrepareLiterals() * PrepareSymbols() * anon_tokens_transf # Adds to terminals # 2. Inline Templates @@ -556,8 +561,10 @@ class Grammar: i += 1 if len(params) != 0: # Dont transform templates continue - ebnf_to_bnf.rule_options = RuleOptions(keep_all_tokens=True) if options.keep_all_tokens else None + rule_options = RuleOptions(keep_all_tokens=True) if options and options.keep_all_tokens else None + ebnf_to_bnf.rule_options = rule_options ebnf_to_bnf.prefix = name + anon_tokens_transf.rule_options = rule_options tree = transformer.transform(rule_tree) res = ebnf_to_bnf.transform(tree) rules.append((name, res, options)) diff --git a/lark/reconstruct.py b/lark/reconstruct.py index baf1d2c..89967b2 100644 --- a/lark/reconstruct.py +++ b/lark/reconstruct.py @@ -86,6 +86,14 @@ def best_from_group(seq, group_key, cmp_key): d[key] = item return list(d.values()) + +def make_recons_rule(origin, expansion, old_expansion): + return Rule(origin, expansion, alias=MakeMatchTree(origin.name, old_expansion)) + +def make_recons_rule_to_term(origin, term): + return make_recons_rule(origin, [Terminal(term.name)], [term]) + + class Reconstructor: """ A Reconstructor that will, given a full parse Tree, generate source code. @@ -100,6 +108,8 @@ class Reconstructor: tokens, rules, _grammar_extra = parser.grammar.compile(parser.options.start) self.write_tokens = WriteTokensTransformer({t.name:t for t in tokens}, term_subs) + self.rules_for_root = defaultdict(list) + self.rules = list(self._build_recons_rules(rules)) self.rules.reverse() @@ -107,9 +117,8 @@ class Reconstructor: self.rules = best_from_group(self.rules, lambda r: r, lambda r: -len(r.expansion)) self.rules.sort(key=lambda r: len(r.expansion)) - callbacks = {rule: rule.alias for rule in self.rules} # TODO pass callbacks through dict, instead of alias? - self.parser = earley.Parser(ParserConf(self.rules, callbacks, parser.options.start), - self._match, resolve_ambiguity=True) + self.parser = parser + self._parser_cache = {} def _build_recons_rules(self, rules): expand1s = {r.origin for r in rules if r.options.expand1} @@ -121,24 +130,36 @@ class Reconstructor: rule_names = {r.origin for r in rules} nonterminals = {sym for sym in rule_names - if sym.name.startswith('_') or sym in expand1s or sym in aliases } + if sym.name.startswith('_') or sym in expand1s or sym in aliases } + seen = set() for r in rules: recons_exp = [sym if sym in nonterminals else Terminal(sym.name) for sym in r.expansion if not is_discarded_terminal(sym)] # Skip self-recursive constructs - if recons_exp == [r.origin]: + if recons_exp == [r.origin] and r.alias is None: continue sym = NonTerminal(r.alias) if r.alias else r.origin + rule = make_recons_rule(sym, recons_exp, r.expansion) - yield Rule(sym, recons_exp, alias=MakeMatchTree(sym.name, r.expansion)) + if sym in expand1s and len(recons_exp) != 1: + self.rules_for_root[sym.name].append(rule) + + if sym.name not in seen: + yield make_recons_rule_to_term(sym, sym) + seen.add(sym.name) + else: + if sym.name.startswith('_') or sym in expand1s: + yield rule + else: + self.rules_for_root[sym.name].append(rule) for origin, rule_aliases in aliases.items(): for alias in rule_aliases: - yield Rule(origin, [Terminal(alias)], alias=MakeMatchTree(origin.name, [NonTerminal(alias)])) - yield Rule(origin, [Terminal(origin.name)], alias=MakeMatchTree(origin.name, [origin])) + yield make_recons_rule_to_term(origin, NonTerminal(alias)) + yield make_recons_rule_to_term(origin, origin) def _match(self, term, token): if isinstance(token, Tree): @@ -149,7 +170,20 @@ class Reconstructor: def _reconstruct(self, tree): # TODO: ambiguity? - unreduced_tree = self.parser.parse(tree.children, tree.data) # find a full derivation + try: + parser = self._parser_cache[tree.data] + except KeyError: + rules = self.rules + best_from_group( + self.rules_for_root[tree.data], lambda r: r, lambda r: -len(r.expansion) + ) + + rules.sort(key=lambda r: len(r.expansion)) + + callbacks = {rule: rule.alias for rule in rules} # TODO pass callbacks through dict, instead of alias? + parser = earley.Parser(ParserConf(rules, callbacks, [tree.data]), self._match, resolve_ambiguity=True) + self._parser_cache[tree.data] = parser + + unreduced_tree = parser.parse(tree.children, tree.data) # find a full derivation assert unreduced_tree.data == tree.data res = self.write_tokens.transform(unreduced_tree) for item in res: diff --git a/tests/test_reconstructor.py b/tests/test_reconstructor.py index ecab499..93c64fe 100644 --- a/tests/test_reconstructor.py +++ b/tests/test_reconstructor.py @@ -69,6 +69,35 @@ class TestReconstructor(TestCase): self.assert_reconstruct(g, code) + def test_keep_tokens(self): + g = """ + start: (NL | stmt)* + stmt: var op var + !op: ("+" | "-" | "*" | "/") + var: WORD + NL: /(\\r?\\n)+\s*/ + """ + common + + code = """ + a+b + """ + + self.assert_reconstruct(g, code) + + def test_expand_rule(self): + g = """ + ?start: (NL | mult_stmt)* + ?mult_stmt: sum_stmt ["*" sum_stmt] + ?sum_stmt: var ["+" var] + var: WORD + NL: /(\\r?\\n)+\s*/ + """ + common + + code = ['a', 'a*b', 'a+b', 'a*b+c', 'a+b*c', 'a+b*c+d'] + + for c in code: + self.assert_reconstruct(g, c) + def test_json_example(self): test_json = ''' {