diff --git a/lark/load_grammar.py b/lark/load_grammar.py index dbf4a1f..d1d06cc 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -9,7 +9,7 @@ import pkgutil from ast import literal_eval from numbers import Integral -from .utils import bfs, Py36, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique +from .utils import bfs, Py36, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique, small_factors from .lexer import Token, TerminalDef, PatternStr, PatternRE from .parse_tree_builder import ParseTreeBuilder @@ -175,27 +175,136 @@ RULES = { } +# Value 5 keeps the number of states in the lalr parser somewhat minimal +# It isn't optimal, but close to it. See PR #949 +SMALL_FACTOR_THRESHOLD = 5 +# The Threshold whether repeat via ~ are split up into different rules +# 50 is chosen since it keeps the number of states low and therefore lalr analysis time low, +# while not being to overaggressive and unnecessarily creating rules that might create shift/reduce conflicts. +# (See PR #949) +REPEAT_BREAK_THRESHOLD = 50 + + @inline_args class EBNF_to_BNF(Transformer_InPlace): def __init__(self): self.new_rules = [] - self.rules_by_expr = {} + self.rules_cache = {} self.prefix = 'anon' self.i = 0 self.rule_options = None - def _add_recurse_rule(self, type_, expr): - if expr in self.rules_by_expr: - return self.rules_by_expr[expr] - - new_name = '__%s_%s_%d' % (self.prefix, type_, self.i) + def _name_rule(self, inner): + new_name = '__%s_%s_%d' % (self.prefix, inner, self.i) self.i += 1 - t = NonTerminal(new_name) - tree = ST('expansions', [ST('expansion', [expr]), ST('expansion', [t, expr])]) - self.new_rules.append((new_name, tree, self.rule_options)) - self.rules_by_expr[expr] = t + return new_name + + def _add_rule(self, key, name, expansions): + t = NonTerminal(name) + self.new_rules.append((name, expansions, self.rule_options)) + self.rules_cache[key] = t return t + def _add_recurse_rule(self, type_, expr): + try: + return self.rules_cache[expr] + except KeyError: + new_name = self._name_rule(type_) + t = NonTerminal(new_name) + tree = ST('expansions', [ + ST('expansion', [expr]), + ST('expansion', [t, expr]) + ]) + return self._add_rule(expr, new_name, tree) + + def _add_repeat_rule(self, a, b, target, atom): + """Generate a rule that repeats target ``a`` times, and repeats atom ``b`` times. + + When called recursively (into target), it repeats atom for x(n) times, where: + x(0) = 1 + x(n) = a(n) * x(n-1) + b + + Example rule when a=3, b=4: + + new_rule: target target target atom atom atom atom + + """ + key = (a, b, target, atom) + try: + return self.rules_cache[key] + except KeyError: + new_name = self._name_rule('repeat_a%d_b%d' % (a, b)) + tree = ST('expansions', [ST('expansion', [target] * a + [atom] * b)]) + return self._add_rule(key, new_name, tree) + + def _add_repeat_opt_rule(self, a, b, target, target_opt, atom): + """Creates a rule that matches atom 0 to (a*n+b)-1 times. + + When target matches n times atom, and target_opt 0 to n-1 times target_opt, + + First we generate target * i followed by target_opt, for i from 0 to a-1 + These match 0 to n*a - 1 times atom + + Then we generate target * a followed by atom * i, for i from 0 to b-1 + These match n*a to n*a + b-1 times atom + + The created rule will not have any shift/reduce conflicts so that it can be used with lalr + + Example rule when a=3, b=4: + + new_rule: target_opt + | target target_opt + | target target target_opt + + | target target target + | target target target atom + | target target target atom atom + | target target target atom atom atom + + """ + key = (a, b, target, atom, "opt") + try: + return self.rules_cache[key] + except KeyError: + new_name = self._name_rule('repeat_a%d_b%d_opt' % (a, b)) + tree = ST('expansions', [ + ST('expansion', [target]*i + [target_opt]) for i in range(a) + ] + [ + ST('expansion', [target]*a + [atom]*i) for i in range(b) + ]) + return self._add_rule(key, new_name, tree) + + def _generate_repeats(self, rule, mn, mx): + """Generates a rule tree that repeats ``rule`` exactly between ``mn`` to ``mx`` times. + """ + # For a small number of repeats, we can take the naive approach + if mx < REPEAT_BREAK_THRESHOLD: + return ST('expansions', [ST('expansion', [rule] * n) for n in range(mn, mx + 1)]) + + # For large repeat values, we break the repetition into sub-rules. + # We treat ``rule~mn..mx`` as ``rule~mn rule~0..(diff=mx-mn)``. + # We then use small_factors to split up mn and diff up into values [(a, b), ...] + # This values are used with the help of _add_repeat_rule and _add_repeat_rule_opt + # to generate a complete rule/expression that matches the corresponding number of repeats + mn_target = rule + for a, b in small_factors(mn, SMALL_FACTOR_THRESHOLD): + mn_target = self._add_repeat_rule(a, b, mn_target, rule) + if mx == mn: + return mn_target + + diff = mx - mn + 1 # We add one because _add_repeat_opt_rule generates rules that match one less + diff_factors = small_factors(diff, SMALL_FACTOR_THRESHOLD) + diff_target = rule # Match rule 1 times + diff_opt_target = ST('expansion', []) # match rule 0 times (e.g. up to 1 -1 times) + for a, b in diff_factors[:-1]: + diff_opt_target = self._add_repeat_opt_rule(a, b, diff_target, diff_opt_target, rule) + diff_target = self._add_repeat_rule(a, b, diff_target, rule) + + a, b = diff_factors[-1] + diff_opt_target = self._add_repeat_opt_rule(a, b, diff_target, diff_opt_target, rule) + + return ST('expansions', [ST('expansion', [mn_target] + [diff_opt_target])]) + def expr(self, rule, op, *args): if op.value == '?': empty = ST('expansion', []) @@ -220,7 +329,9 @@ class EBNF_to_BNF(Transformer_InPlace): mn, mx = map(int, args) if mx < mn or mn < 0: raise GrammarError("Bad Range for %s (%d..%d isn't allowed)" % (rule, mn, mx)) - return ST('expansions', [ST('expansion', [rule] * n) for n in range(mn, mx+1)]) + + return self._generate_repeats(rule, mn, mx) + assert False, op def maybe(self, rule): diff --git a/lark/utils.py b/lark/utils.py index ea78801..2938591 100644 --- a/lark/utils.py +++ b/lark/utils.py @@ -187,7 +187,7 @@ def get_regexp_width(expr): return 1, sre_constants.MAXREPEAT else: return 0, sre_constants.MAXREPEAT - + ###} @@ -288,7 +288,7 @@ except ImportError: class FS: exists = os.path.exists - + @staticmethod def open(name, mode="r", **kwargs): if atomicwrites and "w" in mode: @@ -359,3 +359,29 @@ def _serialize(value, memo): return {key:_serialize(elem, memo) for key, elem in value.items()} # assert value is None or isinstance(value, (int, float, str, tuple)), value return value + + + + +def small_factors(n, max_factor): + """ + Splits n up into smaller factors and summands <= max_factor. + Returns a list of [(a, b), ...] + so that the following code returns n: + + n = 1 + for a, b in values: + n = n * a + b + + Currently, we also keep a + b <= max_factor, but that might change + """ + assert n >= 0 + assert max_factor > 2 + if n <= max_factor: + return [(n, 0)] + + for a in range(max_factor, 1, -1): + r, b = divmod(n, a) + if a + b <= max_factor: + return small_factors(r, max_factor) + [(a, b)] + assert False, "Failed to factorize %s" % n diff --git a/tests/test_grammar.py b/tests/test_grammar.py index a643117..3ae65f2 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -3,7 +3,7 @@ from __future__ import absolute_import import sys from unittest import TestCase, main -from lark import Lark, Token, Tree +from lark import Lark, Token, Tree, ParseError, UnexpectedInput from lark.load_grammar import GrammarError, GRAMMAR_ERRORS, find_grammar_errors from lark.load_grammar import FromPackageLoader @@ -198,6 +198,53 @@ class TestGrammar(TestCase): x = find_grammar_errors(text) assert [e.line for e, _s in find_grammar_errors(text)] == [2, 6] + def test_ranged_repeat_terms(self): + g = u"""!start: AAA + AAA: "A"~3 + """ + l = Lark(g, parser='lalr') + self.assertEqual(l.parse(u'AAA'), Tree('start', ["AAA"])) + self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AA') + self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAAA') + + g = u"""!start: AABB CC + AABB: "A"~0..2 "B"~2 + CC: "C"~1..2 + """ + l = Lark(g, parser='lalr') + self.assertEqual(l.parse(u'AABBCC'), Tree('start', ['AABB', 'CC'])) + self.assertEqual(l.parse(u'BBC'), Tree('start', ['BB', 'C'])) + self.assertEqual(l.parse(u'ABBCC'), Tree('start', ['ABB', 'CC'])) + self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAAB') + self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAABBB') + self.assertRaises((ParseError, UnexpectedInput), l.parse, u'ABB') + self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAAABB') + + def test_ranged_repeat_large(self): + g = u"""!start: "A"~60 + """ + l = Lark(g, parser='lalr') + self.assertGreater(len(l.rules), 1, "Expected that more than one rule will be generated") + self.assertEqual(l.parse(u'A' * 60), Tree('start', ["A"] * 60)) + self.assertRaises(ParseError, l.parse, u'A' * 59) + self.assertRaises((ParseError, UnexpectedInput), l.parse, u'A' * 61) + + g = u"""!start: "A"~15..100 + """ + l = Lark(g, parser='lalr') + for i in range(0, 110): + if 15 <= i <= 100: + self.assertEqual(l.parse(u'A' * i), Tree('start', ['A']*i)) + else: + self.assertRaises(UnexpectedInput, l.parse, u'A' * i) + + # 8191 is a Mersenne prime + g = u"""start: "A"~8191 + """ + l = Lark(g, parser='lalr') + self.assertEqual(l.parse(u'A' * 8191), Tree('start', [])) + self.assertRaises(UnexpectedInput, l.parse, u'A' * 8190) + self.assertRaises(UnexpectedInput, l.parse, u'A' * 8192) if __name__ == '__main__': diff --git a/tests/test_parser.py b/tests/test_parser.py index 8fec82d..9eb7b26 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -2204,27 +2204,7 @@ def _make_parser_test(LEXER, PARSER): self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAAABB') - def test_ranged_repeat_terms(self): - g = u"""!start: AAA - AAA: "A"~3 - """ - l = _Lark(g) - self.assertEqual(l.parse(u'AAA'), Tree('start', ["AAA"])) - self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AA') - self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAAA') - g = u"""!start: AABB CC - AABB: "A"~0..2 "B"~2 - CC: "C"~1..2 - """ - l = _Lark(g) - self.assertEqual(l.parse(u'AABBCC'), Tree('start', ['AABB', 'CC'])) - self.assertEqual(l.parse(u'BBC'), Tree('start', ['BB', 'C'])) - self.assertEqual(l.parse(u'ABBCC'), Tree('start', ['ABB', 'CC'])) - self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAAB') - self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAABBB') - self.assertRaises((ParseError, UnexpectedInput), l.parse, u'ABB') - self.assertRaises((ParseError, UnexpectedInput), l.parse, u'AAAABB') @unittest.skipIf(PARSER=='earley', "Priority not handled correctly right now") # TODO XXX def test_priority_vs_embedded(self):