diff --git a/lark/lark.py b/lark/lark.py index ed1dfbc..a32c214 100644 --- a/lark/lark.py +++ b/lark/lark.py @@ -343,9 +343,7 @@ class Lark(Serialize): rule.options.priority = None # TODO Deprecate lexer_callbacks? - lexer_callbacks = (_get_lexer_callbacks(self.options.transformer, self.terminals) - if self.options.transformer - else {}) + lexer_callbacks = {} lexer_callbacks.update(self.options.lexer_callbacks) self.lexer_conf = LexerConf(self.terminals, re_module, self.ignore_tokens, self.options.postlex, lexer_callbacks, self.options.g_regex_flags, use_bytes=self.options.use_bytes) @@ -375,8 +373,7 @@ class Lark(Serialize): return TraditionalLexer(lexer_conf) def _prepare_callbacks(self): - self.parser_class = get_frontend(self.options.parser, self.options.lexer) - self._callbacks = None + self._callbacks = {} # we don't need these callbacks if we aren't building a tree if self.options.ambiguity != 'forest': self._parse_tree_builder = ParseTreeBuilder( @@ -386,9 +383,11 @@ class Lark(Serialize): self.options.parser != 'lalr' and self.options.ambiguity == 'explicit', self.options.maybe_placeholders ) - self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer) + self._callbacks.update(self._parse_tree_builder.create_callback(self.options.transformer)) + self._callbacks.update(_get_lexer_callbacks(self.options.transformer, self.terminals)) def _build_parser(self): + self.parser_class = get_frontend(self.options.parser, self.options.lexer) self._prepare_callbacks() parser_conf = ParserConf(self.rules, self._callbacks, self.options.start) return self.parser_class(self.lexer_conf, parser_conf, options=self.options) @@ -428,16 +427,21 @@ class Lark(Serialize): self.options = LarkOptions.deserialize(options, memo) self.rules = [Rule.deserialize(r, memo) for r in data['rules']] self.source_path = '' + self.parser_class = get_frontend(self.options.parser, self.options.lexer) + self.lexer_conf = self.parser_class.deserialize_lexer_conf( # We need the terminals list to for _prepare_callbacks + data['parser'], + memo, + self.options) + self.terminals = self.lexer_conf.terminals + self._terminals_dict = {t.name: t for t in self.terminals} self._prepare_callbacks() self.parser = self.parser_class.deserialize( data['parser'], memo, + self.lexer_conf, self._callbacks, self.options, # Not all, but multiple attributes are used ) - self.lexer_conf = self.parser.lexer_conf - self.terminals = self.parser.lexer_conf.terminals - self._terminals_dict = {t.name: t for t in self.terminals} return self @classmethod diff --git a/lark/parser_frontends.py b/lark/parser_frontends.py index 5acbbeb..eea582a 100644 --- a/lark/parser_frontends.py +++ b/lark/parser_frontends.py @@ -38,23 +38,26 @@ class MakeParsingFrontend: parser_conf.parser_type = self.parser_type lexer_conf.lexer_type = self.lexer_type return ParsingFrontend(lexer_conf, parser_conf, options) - + @classmethod - def deserialize(cls, data, memo, callbacks, options): - lexer_conf = LexerConf.deserialize(data['lexer_conf'], memo) - parser_conf = ParserConf.deserialize(data['parser_conf'], memo) - parser = LALR_Parser.deserialize(data['parser'], memo, callbacks, options.debug) - parser_conf.callbacks = callbacks - + def deserialize_lexer_conf(cls, data, memo, options): + # We need lexer_conf earley to have the terminals that we need to produce the callback list for paser_conf + # So we split deserialize into two methods terminals = [item for item in memo.values() if isinstance(item, TerminalDef)] - + lexer_conf = LexerConf.deserialize(data['lexer_conf'], memo) lexer_conf.callbacks = _get_lexer_callbacks(options.transformer, terminals) lexer_conf.re_module = regex if options.regex else re lexer_conf.use_bytes = options.use_bytes lexer_conf.g_regex_flags = options.g_regex_flags lexer_conf.skip_validation = True lexer_conf.postlex = options.postlex + return lexer_conf + @classmethod + def deserialize(cls, data, memo, lexer_conf, callbacks, options): + parser_conf = ParserConf.deserialize(data['parser_conf'], memo) + parser = LALR_Parser.deserialize(data['parser'], memo, callbacks, options.debug) + parser_conf.callbacks = callbacks return ParsingFrontend(lexer_conf, parser_conf, options, parser=parser) diff --git a/lark/parsers/lalr_parser.py b/lark/parsers/lalr_parser.py index 9ca36f0..6271bd5 100644 --- a/lark/parsers/lalr_parser.py +++ b/lark/parsers/lalr_parser.py @@ -129,7 +129,7 @@ class ParserState(object): # shift once and return assert not is_end state_stack.append(arg) - value_stack.append(token) + value_stack.append(token if token.type not in callbacks else callbacks[token.type](token)) return else: # reduce+shift as many times as necessary diff --git a/tests/test_parser.py b/tests/test_parser.py index 8500b01..ef91a92 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -10,7 +10,7 @@ from copy import copy, deepcopy from lark.utils import Py36, isascii -from lark import Token +from lark import Token, Transformer_NonRecursive try: from cStringIO import StringIO as cStringIO @@ -34,7 +34,7 @@ from lark import logger from lark.lark import Lark from lark.exceptions import GrammarError, ParseError, UnexpectedToken, UnexpectedInput, UnexpectedCharacters from lark.tree import Tree -from lark.visitors import Transformer, Transformer_InPlace, v_args +from lark.visitors import Transformer, Transformer_InPlace, v_args, Transformer_InPlaceRecursive from lark.grammar import Rule from lark.lexer import TerminalDef, Lexer, TraditionalLexer from lark.indenter import Indenter @@ -162,6 +162,28 @@ class TestParsers(unittest.TestCase): r = p.parse("x") self.assertEqual( r.children, ["X!"] ) + def test_visit_tokens2(self): + g = """ + start: add+ + add: NUM "+" NUM + NUM: /\d+/ + %ignore " " + """ + text = "1+2 3+4" + expected = Tree('start', [3, 7]) + for base in (Transformer, Transformer_InPlace, Transformer_NonRecursive, Transformer_InPlaceRecursive): + class T(base): + def add(self, children): + return sum(children if isinstance(children, list) else children.children) + + def NUM(self, token): + return int(token) + + + parser = Lark(g, parser='lalr', transformer=T()) + result = parser.parse(text) + self.assertEqual(result, expected) + def test_vargs_meta(self): @v_args(meta=True)