diff --git a/lark/load_grammar.py b/lark/load_grammar.py index dbf4a1f..569e67d 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -9,7 +9,7 @@ import pkgutil from ast import literal_eval from numbers import Integral -from .utils import bfs, Py36, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique +from .utils import bfs, Py36, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique, small_factors from .lexer import Token, TerminalDef, PatternStr, PatternRE from .parse_tree_builder import ParseTreeBuilder @@ -196,6 +196,26 @@ class EBNF_to_BNF(Transformer_InPlace): self.rules_by_expr[expr] = t return t + def _add_repeat_rule(self, a, b, target, atom): + if (a, b, target, atom) in self.rules_by_expr: + return self.rules_by_expr[(a, b, target, atom)] + new_name = '__%s_a%d_b%d_%d' % (self.prefix, a, b, self.i) + self.i += 1 + t = NonTerminal(new_name) + tree = ST('expansions', [ST('expansion', [target] * a + [atom] * b)]) + self.new_rules.append((new_name, tree, self.rule_options)) + self.rules_by_expr[(a, b, target, atom)] = t + return t + + def _generate_repeats(self, rule, mn, mx): + factors = small_factors(mn) + target = rule + for a, b in factors: + target = self._add_repeat_rule(a, b, target, rule) + + # return ST('expansions', [ST('expansion', [rule] * n) for n in range(mn, mx + 1)]) + return ST('expansions', [ST('expansion', [target] + [rule] * n) for n in range(0, mx - mn + 1)]) + def expr(self, rule, op, *args): if op.value == '?': empty = ST('expansion', []) @@ -220,7 +240,7 @@ class EBNF_to_BNF(Transformer_InPlace): mn, mx = map(int, args) if mx < mn or mn < 0: raise GrammarError("Bad Range for %s (%d..%d isn't allowed)" % (rule, mn, mx)) - return ST('expansions', [ST('expansion', [rule] * n) for n in range(mn, mx+1)]) + return self._generate_repeats(rule, mn, mx) assert False, op def maybe(self, rule): diff --git a/lark/utils.py b/lark/utils.py index ea78801..a3a077f 100644 --- a/lark/utils.py +++ b/lark/utils.py @@ -359,3 +359,33 @@ def _serialize(value, memo): return {key:_serialize(elem, memo) for key, elem in value.items()} # assert value is None or isinstance(value, (int, float, str, tuple)), value return value + + +def small_factors(n): + """ + Splits n up into smaller factors and summands <= 10. + Returns a list of [(a, b), ...] + so that the following code returns n: + + n = 1 + for a, b in values: + n = n * a + b + + Currently, we also keep a + b <= 10, but that might change + """ + assert n > 0 + if n < 10: + return [(n, 0)] + # TODO: Think of better algorithms (Prime factors should minimize the number of steps) + for a in range(10, 1, -1): + b = n % a + if a + b > 10: + continue + r = n // a + assert r * a + b == n # Sanity check + if r <= 10: + return [(r, 0), (a, b)] + else: + return [*small_factors(r), (a, b)] + # This should be unreachable, since 2 + 1 <= 10 + assert False, "Failed to factorize %s" % n