@@ -17,9 +17,6 @@ class LexerConf(Serialize):
         self.skip_validation = skip_validation
         self.use_bytes = use_bytes

-    def _deserialize(self):
-        self.callbacks = {} # TODO
-
 ###}

 class ParserConf:
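The removed `_deserialize` stub only reset `callbacks` to an empty dict; it can go because `WithLexer.deserialize` (further down) now rebuilds the lexer callbacks from the transformer at load time.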
@@ -11,7 +11,7 @@ from .common import LexerConf, ParserConf
 from .lexer import Lexer, TraditionalLexer, TerminalDef, UnexpectedToken
 from .parse_tree_builder import ParseTreeBuilder
-from .parser_frontends import get_frontend
+from .parser_frontends import get_frontend, _get_lexer_callbacks
 from .grammar import Rule

 import re

@@ -278,12 +278,10 @@ class Lark(Serialize):
                     rule.options.priority = None

         # TODO Deprecate lexer_callbacks?
-        lexer_callbacks = dict(self.options.lexer_callbacks)
-        if self.options.transformer:
-            t = self.options.transformer
-            for term in self.terminals:
-                if hasattr(t, term.name):
-                    lexer_callbacks[term.name] = getattr(t, term.name)
+        lexer_callbacks = (_get_lexer_callbacks(self.options.transformer, self.terminals)
+                           if self.options.transformer
+                           else {})
+        lexer_callbacks.update(self.options.lexer_callbacks)

         self.lexer_conf = LexerConf(self.terminals, re_module, self.ignore_tokens, self.options.postlex, lexer_callbacks, self.options.g_regex_flags, use_bytes=self.options.use_bytes)

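Beyond the refactor into `_get_lexer_callbacks`, this hunk flips the merge order: the transformer's terminal methods are collected first and `self.options.lexer_callbacks` is applied on top, so an explicit `lexer_callbacks` entry now wins over a transformer method of the same name. A minimal sketch of the effect (the grammar and class names here are illustrative, not part of this patch):

```python
from lark import Lark, Token, Transformer

class UpperCaser(Transformer):
    # Method name matches the WORD terminal, so it is installed as a lexer callback.
    def WORD(self, tok):
        return Token('WORD', tok.value.upper())

grammar = r"""
start: WORD+
WORD: /[a-z]+/
%ignore " "
"""

# Both sources define a callback for WORD; after this change the explicit
# lexer_callbacks entry takes precedence, so WORD tokens arrive reversed,
# not upper-cased.
parser = Lark(grammar, parser='lalr', transformer=UpperCaser(),
              lexer_callbacks={'WORD': lambda tok: Token('WORD', tok.value[::-1])})
```

Before this patch the merge ran the other way around, so a transformer method silently shadowed the user-supplied callback.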
@@ -344,7 +342,14 @@ class Lark(Serialize):
         self.rules = [Rule.deserialize(r, memo) for r in data['rules']]
         self.source = '<deserialized>'
         self._prepare_callbacks()
-        self.parser = self.parser_class.deserialize(data['parser'], memo, self._callbacks, self.options.postlex, re_module)
+        self.parser = self.parser_class.deserialize(
+            data['parser'],
+            memo,
+            self._callbacks,
+            self.options.postlex,
+            self.options.transformer,
+            re_module
+        )
         return self

     @classmethod
@@ -1,6 +1,6 @@
 from .utils import get_regexp_width, Serialize
 from .parsers.grammar_analysis import GrammarAnalyzer
-from .lexer import TraditionalLexer, ContextualLexer, Lexer, Token
+from .lexer import TraditionalLexer, ContextualLexer, Lexer, Token, TerminalDef
 from .parsers import earley, xearley, cyk
 from .parsers.lalr_parser import LALR_Parser
 from .grammar import Rule
@@ -58,6 +58,15 @@ class _ParserFrontend(Serialize):
         return self.parser.parse(input, start, *args)


+def _get_lexer_callbacks(transformer, terminals):
+    result = {}
+    for terminal in terminals:
+        callback = getattr(transformer, terminal.name, None)
+        if callback is not None:
+            result[terminal.name] = callback
+    return result
+
+
 class WithLexer(_ParserFrontend):
     lexer = None
     parser = None
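The new helper walks the terminal definitions and picks up any same-named attribute on the transformer; a `None` transformer falls through harmlessly because `getattr(None, name, None)` is `None`. A quick sketch of the contract (assuming the private import path stays as in this patch):

```python
from lark.lexer import TerminalDef, PatternRE
from lark.parser_frontends import _get_lexer_callbacks

class T:
    def WORD(self, token):     # matches a terminal name -> collected
        return token
    def helper(self, token):   # no terminal named 'helper' -> ignored
        return token

terminals = [TerminalDef('WORD', PatternRE(r'[a-z]+')),
             TerminalDef('NUMBER', PatternRE(r'[0-9]+'))]

assert set(_get_lexer_callbacks(T(), terminals)) == {'WORD'}
assert _get_lexer_callbacks(None, terminals) == {}  # no transformer, no callbacks
```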
@@ -73,13 +82,18 @@ class WithLexer(_ParserFrontend):
         self.postlex = lexer_conf.postlex

     @classmethod
-    def deserialize(cls, data, memo, callbacks, postlex, re_module):
+    def deserialize(cls, data, memo, callbacks, postlex, transformer, re_module):
         inst = super(WithLexer, cls).deserialize(data, memo)
+
         inst.postlex = postlex
         inst.parser = LALR_Parser.deserialize(inst.parser, memo, callbacks)
+
+        terminals = [item for item in memo.values() if isinstance(item, TerminalDef)]
+        inst.lexer_conf.callbacks = _get_lexer_callbacks(transformer, terminals)
         inst.lexer_conf.re_module = re_module
         inst.lexer_conf.skip_validation = True
         inst.init_lexer()
+
         return inst

     def _serialize(self, data, memo):
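Recovering the terminals from `memo` is the notable design choice here: the serialization memo maps ids to reconstructed objects of mixed types, and there is no dedicated terminal list available at this point, so an `isinstance` filter pulls them out. An illustrative sketch with a hand-built memo (not the real payload):

```python
from lark.lexer import TerminalDef, PatternStr

memo = {
    0: TerminalDef('LPAR', PatternStr('(')),
    1: 'some other memoized object',  # stand-in for rules etc.
    2: TerminalDef('RPAR', PatternStr(')')),
}

terminals = [item for item in memo.values() if isinstance(item, TerminalDef)]
assert [t.name for t in terminals] == ['LPAR', 'RPAR']
```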
@@ -229,4 +243,3 @@ class CYK(WithLexer):

     def _apply_callback(self, tree):
         return self.callbacks[tree.rule](tree.children)
-
@@ -106,6 +106,33 @@ class TestStandalone(TestCase):
         x = l.parse('(\n)\n')
         self.assertEqual(x, Tree('start', []))

+    def test_transformer(self):
+        grammar = r"""
+            start: some_rule "(" SOME_TERMINAL ")"
+            some_rule: SOME_TERMINAL
+            SOME_TERMINAL: /[A-Za-z_][A-Za-z0-9_]*/
+        """
+        context = self._create_standalone(grammar)
+        _Lark = context["Lark_StandAlone"]
+        _Token = context["Token"]
+        _Tree = context["Tree"]
+
+        class MyTransformer(context["Transformer"]):
+            def SOME_TERMINAL(self, token):
+                return _Token("SOME_TERMINAL", "token is transformed")
+
+            def some_rule(self, children):
+                return _Tree("rule_is_transformed", [])
+
+        parser = _Lark(transformer=MyTransformer())
+        self.assertEqual(
+            parser.parse("FOO(BAR)"),
+            _Tree("start", [
+                _Tree("rule_is_transformed", []),
+                _Token("SOME_TERMINAL", "token is transformed")
+            ])
+        )
+
 if __name__ == '__main__':
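Outside the test harness, the same flow applies to a generated standalone module. A hedged sketch, assuming the test's grammar was compiled with `python -m lark.tools.standalone grammar.lark > my_parser.py` (the module name is hypothetical):

```python
import my_parser  # the generated standalone module (hypothetical name)

class MyTransformer(my_parser.Transformer):
    # Runs inside the standalone lexer now that the transformer survives
    # deserialization; returns a replacement token.
    def SOME_TERMINAL(self, token):
        return my_parser.Token("SOME_TERMINAL", token.upper())

parser = my_parser.Lark_StandAlone(transformer=MyTransformer())
print(parser.parse("foo(bar)"))
```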