From 19a9c9c2064e71565e54c90f8126e410439e91bb Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sat, 14 Oct 2017 14:21:28 +0300 Subject: [PATCH] Towards an introspectable tree-builder. Also added tests. --- lark/lark.py | 2 +- lark/parse_tree_builder.py | 113 ++++++++++++++++++++----------------- tests/test_parser.py | 94 +++++++++++++++++++++++++++++- 3 files changed, 156 insertions(+), 53 deletions(-) diff --git a/lark/lark.py b/lark/lark.py index 18d9959..90a06fc 100644 --- a/lark/lark.py +++ b/lark/lark.py @@ -66,7 +66,7 @@ class LarkOptions(object): assert self.parser in ('earley', 'lalr', None) if self.parser == 'earley' and self.transformer: - raise ValueError('Cannot specify an auto-transformer when using the Earley algorithm.' + raise ValueError('Cannot specify an embedded transformer when using the Earley algorithm.' 'Please use your transformer on the resulting parse tree, or use a different algorithm (i.e. lalr)') if o: diff --git a/lark/parse_tree_builder.py b/lark/parse_tree_builder.py index 601372e..af553cc 100644 --- a/lark/parse_tree_builder.py +++ b/lark/parse_tree_builder.py @@ -2,23 +2,48 @@ from .common import is_terminal, GrammarError from .utils import suppress from .lexer import Token -class Callback(object): - pass +class NodeBuilder: + def __init__(self, tree_class, name): + self.tree_class = tree_class + self.name = name + def __call__(self, children): + return self.tree_class(self.name, children) -def create_expand1_tree_builder_function(tree_builder): - def expand1(children): +class Expand1: + def __init__(self, node_builder): + self.node_builder = node_builder + + def __call__(self, children): if len(children) == 1: return children[0] else: - return tree_builder(children) - return expand1 + return self.node_builder(children) + +class TokenWrapper: + "Used for fixing the results of scanless parsing" -def create_token_wrapper(tree_builder, name): - def join_children(children): - children = [Token(name, ''.join(children))] - return tree_builder(children) - return join_children + def __init__(self, node_builder, token_name): + self.node_builder = node_builder + self.token_name = token_name + + def __call__(self, children): + return self.node_builder( [Token(self.token_name, ''.join(children))] ) + +class ChildFilter: + def __init__(self, node_builder, to_include): + self.node_builder = node_builder + self.to_include = to_include + + def __call__(self, children): + filtered = [] + for i, to_expand in self.to_include: + if to_expand: + filtered += children[i].children + else: + filtered.append(children[i]) + + return self.node_builder(filtered) def create_rule_handler(expansion, usermethod, keep_all_tokens, filter_out): # if not keep_all_tokens: @@ -29,32 +54,26 @@ def create_rule_handler(expansion, usermethod, keep_all_tokens, filter_out): ] if len(to_include) < len(expansion) or any(to_expand for i, to_expand in to_include): - def _build_ast(match): - children = [] - for i, to_expand in to_include: - if to_expand: - children += match[i].children - else: - children.append(match[i]) - - return usermethod(children) - return _build_ast + return ChildFilter(usermethod, to_include) # else, if no filtering required.. return usermethod -def propagate_positions_wrapper(f): - def _f(args): - res = f(args) +class PropagatePositions: + def __init__(self, node_builder): + self.node_builder = node_builder - if args: - for a in args: + def __call__(self, children): + res = self.node_builder(children) + + if children: + for a in children: with suppress(AttributeError): res.line = a.line res.column = a.column break - for a in reversed(args): + for a in reversed(children): with suppress(AttributeError): res.end_line = a.end_line res.end_col = a.end_col @@ -62,7 +81,9 @@ def propagate_positions_wrapper(f): return res - return _f + +class Callback(object): + pass class ParseTreeBuilder: def __init__(self, tree_class, propagate_positions=False, keep_all_tokens=False): @@ -70,14 +91,6 @@ class ParseTreeBuilder: self.propagate_positions = propagate_positions self.always_keep_all_tokens = keep_all_tokens - def _create_tree_builder_function(self, name): - tree_class = self.tree_class - def tree_builder_f(children): - return tree_class(name, children) - return tree_builder_f - - - def create_tree_builder(self, rules, transformer): callback = Callback() new_rules = [] @@ -93,36 +106,34 @@ class ParseTreeBuilder: expand1 = options.expand1 if options else False create_token = options.create_token if options else False - _origin = origin - for expansion, alias in expansions: if alias and origin.startswith('_'): - raise Exception("Rule %s is marked for expansion (it starts with an underscore) and isn't allowed to have aliases (alias=%s)" % (origin, alias)) + raise Exception("Rule %s is marked for expansion (it starts with an underscore) and isn't allowed to have aliases (alias=%s)" % (origin, alias)) + + elif not alias: + alias = origin try: - f = transformer._get_func(alias or _origin) + f = transformer._get_func(alias) except AttributeError: - if alias: - f = self._create_tree_builder_function(alias) - else: - f = self._create_tree_builder_function(_origin) - if expand1: - f = create_expand1_tree_builder_function(f) + f = NodeBuilder(self.tree_class, alias) - if create_token: - f = create_token_wrapper(f, create_token) + if expand1: + f = Expand1(f) + if create_token: + f = TokenWrapper(f, create_token) alias_handler = create_rule_handler(expansion, f, keep_all_tokens, filter_out) if self.propagate_positions: - alias_handler = propagate_positions_wrapper(alias_handler) + alias_handler = PropagatePositions(alias_handler) - callback_name = 'autoalias_%s_%s' % (_origin, '_'.join(expansion)) + callback_name = 'autoalias_%s_%s' % (origin, '_'.join(expansion)) if hasattr(callback, callback_name): raise GrammarError("Rule expansion '%s' already exists in rule %s" % (' '.join(expansion), origin)) setattr(callback, callback_name, alias_handler) - new_rules.append(( _origin, expansion, callback_name, options )) + new_rules.append(( origin, expansion, callback_name, options )) return new_rules, callback diff --git a/tests/test_parser.py b/tests/test_parser.py index 9fa05eb..3799fce 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -19,7 +19,7 @@ logging.basicConfig(level=logging.INFO) from lark.lark import Lark from lark.common import GrammarError, ParseError from lark.lexer import LexError -from lark.tree import Tree +from lark.tree import Tree, Transformer __path__ = os.path.dirname(__file__) def _read(n, *args): @@ -53,6 +53,98 @@ class TestParsers(unittest.TestCase): l = Lark(g, parser='earley', lexer='dynamic') self.assertRaises(ParseError, l.parse, 'a') + def test_propagate_positions(self): + g = Lark("""start: a + a: "a" + """, propagate_positions=True) + + r = g.parse('a') + self.assertEqual( r.children[0].line, 1 ) + + def test_expand1(self): + + g = Lark("""start: a + ?a: b + b: "x" + """) + + r = g.parse('x') + self.assertEqual( r.children[0].data, "b" ) + + g = Lark("""start: a + ?a: b -> c + b: "x" + """) + + r = g.parse('x') + self.assertEqual( r.children[0].data, "b" ) + + + g = Lark("""start: a + ?a: b b -> c + b: "x" + """) + r = g.parse('xx') + self.assertEqual( r.children[0].data, "c" ) + + def test_embedded_transformer(self): + class T(Transformer): + def a(self, children): + return "" + def b(self, children): + return "" + def c(self, children): + return "" + + # Test regular + g = Lark("""start: a + a : "x" + """, parser='lalr') + r = T().transform(g.parse("x")) + self.assertEqual( r.children, [""] ) + + + g = Lark("""start: a + a : "x" + """, parser='lalr', transformer=T()) + r = g.parse("x") + self.assertEqual( r.children, [""] ) + + + # Test Expand1 + g = Lark("""start: a + ?a : b + b : "x" + """, parser='lalr') + r = T().transform(g.parse("x")) + self.assertEqual( r.children, [""] ) + + + g = Lark("""start: a + ?a : b + b : "x" + """, parser='lalr', transformer=T()) + r = g.parse("x") + self.assertEqual( r.children, [""] ) + + # Test Expand1 -> Alias + g = Lark("""start: a + ?a : b b -> c + b : "x" + """, parser='lalr') + r = T().transform(g.parse("xx")) + self.assertEqual( r.children, [""] ) + + + g = Lark("""start: a + ?a : b b -> c + b : "x" + """, parser='lalr', transformer=T()) + r = g.parse("xx") + self.assertEqual( r.children, [""] ) + + + def _make_full_earley_test(LEXER): class _TestFullEarley(unittest.TestCase):