@@ -54,7 +54,7 @@ class LarkOptions(object): | |||||
_defaults = { | _defaults = { | ||||
'debug': False, | 'debug': False, | ||||
'keep_all_tokens': False, | 'keep_all_tokens': False, | ||||
'tree_class': Tree, | |||||
'tree_class': None, | |||||
'cache_grammar': False, | 'cache_grammar': False, | ||||
'postlex': None, | 'postlex': None, | ||||
'parser': 'earley', | 'parser': 'earley', | ||||
@@ -97,6 +97,7 @@ class LarkOptions(object): | |||||
def __getattr__(self, name): | def __getattr__(self, name): | ||||
return self.options[name] | return self.options[name] | ||||
def __setattr__(self, name, value): | def __setattr__(self, name, value): | ||||
assert name in self.options | |||||
self.options[name] = value | self.options[name] = value | ||||
def serialize(self): | def serialize(self): | ||||
@@ -227,7 +228,7 @@ class Lark: | |||||
def _prepare_callbacks(self): | def _prepare_callbacks(self): | ||||
self.parser_class = get_frontend(self.options.parser, self.options.lexer) | self.parser_class = get_frontend(self.options.parser, self.options.lexer) | ||||
self._parse_tree_builder = ParseTreeBuilder(self.rules, self.options.tree_class, self.options.propagate_positions, self.options.keep_all_tokens, self.options.parser!='lalr' and self.options.ambiguity=='explicit', self.options.maybe_placeholders) | |||||
self._parse_tree_builder = ParseTreeBuilder(self.rules, self.options.tree_class or Tree, self.options.propagate_positions, self.options.keep_all_tokens, self.options.parser!='lalr' and self.options.ambiguity=='explicit', self.options.maybe_placeholders) | |||||
self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer) | self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer) | ||||
def _build_parser(self): | def _build_parser(self): | ||||
@@ -35,6 +35,16 @@ class Pattern(object): | |||||
value = ('(?%s)' % f) + value | value = ('(?%s)' % f) + value | ||||
return value | return value | ||||
@classmethod | |||||
def deserialize(cls, data): | |||||
class_ = { | |||||
's': PatternStr, | |||||
're': PatternRE, | |||||
}[data[0]] | |||||
value, flags = data[1:] | |||||
return class_(value, frozenset(flags)) | |||||
class PatternStr(Pattern): | class PatternStr(Pattern): | ||||
def to_regexp(self): | def to_regexp(self): | ||||
return self._get_flags(re.escape(self.value)) | return self._get_flags(re.escape(self.value)) | ||||
@@ -44,6 +54,9 @@ class PatternStr(Pattern): | |||||
return len(self.value) | return len(self.value) | ||||
max_width = min_width | max_width = min_width | ||||
def serialize(self): | |||||
return ['s', self.value, list(self.flags)] | |||||
class PatternRE(Pattern): | class PatternRE(Pattern): | ||||
def to_regexp(self): | def to_regexp(self): | ||||
return self._get_flags(self.value) | return self._get_flags(self.value) | ||||
@@ -55,6 +68,9 @@ class PatternRE(Pattern): | |||||
def max_width(self): | def max_width(self): | ||||
return get_regexp_width(self.to_regexp())[1] | return get_regexp_width(self.to_regexp())[1] | ||||
def serialize(self): | |||||
return ['re', self.value, list(self.flags)] | |||||
class TerminalDef(object): | class TerminalDef(object): | ||||
def __init__(self, name, pattern, priority=1): | def __init__(self, name, pattern, priority=1): | ||||
assert isinstance(pattern, Pattern), pattern | assert isinstance(pattern, Pattern), pattern | ||||
@@ -66,11 +82,12 @@ class TerminalDef(object): | |||||
return '%s(%r, %r)' % (type(self).__name__, self.name, self.pattern) | return '%s(%r, %r)' % (type(self).__name__, self.name, self.pattern) | ||||
def serialize(self): | def serialize(self): | ||||
return [self.name, self.pattern, self.priority] | |||||
return [self.name, self.pattern.serialize(), self.priority] | |||||
@classmethod | @classmethod | ||||
def deserialize(cls, data): | def deserialize(cls, data): | ||||
return cls(*data) | |||||
name, pattern, priority = data | |||||
return cls(name, Pattern.deserialize(pattern), priority) | |||||
@@ -46,7 +46,10 @@ class WithLexer(object): | |||||
} | } | ||||
@classmethod | @classmethod | ||||
def deserialize(cls, data, callbacks): | def deserialize(cls, data, callbacks): | ||||
class_ = globals()[data['type']] # XXX unsafe | |||||
class_ = { | |||||
'LALR_TraditionalLexer': LALR_TraditionalLexer, | |||||
'LALR_ContextualLexer': LALR_ContextualLexer, | |||||
}[data['type']] # XXX unsafe | |||||
parser = lalr_parser.Parser.deserialize(data['parser'], callbacks) | parser = lalr_parser.Parser.deserialize(data['parser'], callbacks) | ||||
assert parser | assert parser | ||||
inst = class_.__new__(class_) | inst = class_.__new__(class_) | ||||
@@ -5,28 +5,11 @@ | |||||
from ..exceptions import UnexpectedToken | from ..exceptions import UnexpectedToken | ||||
from ..lexer import Token | from ..lexer import Token | ||||
from ..grammar import Rule | from ..grammar import Rule | ||||
from ..utils import Enumerator | |||||
from .lalr_analysis import LALR_Analyzer, Shift, Reduce, IntParseTable | from .lalr_analysis import LALR_Analyzer, Shift, Reduce, IntParseTable | ||||
class Enumerator: | |||||
def __init__(self): | |||||
self.enums = {} | |||||
def get(self, item): | |||||
if item not in self.enums: | |||||
self.enums[item] = len(self.enums) | |||||
return self.enums[item] | |||||
def __len__(self): | |||||
return len(self.enums) | |||||
def reversed(self): | |||||
r = {v: k for k, v in self.enums.items()} | |||||
assert len(r) == len(self.enums) | |||||
return r | |||||
class Parser(object): | class Parser(object): | ||||
def __init__(self, parser_conf, debug=False): | def __init__(self, parser_conf, debug=False): | ||||
assert all(r.options is None or r.options.priority is None | assert all(r.options is None or r.options.priority is None | ||||
@@ -128,3 +128,22 @@ def get_regexp_width(regexp): | |||||
return sre_parse.parse(regexp).getwidth() | return sre_parse.parse(regexp).getwidth() | ||||
except sre_constants.error: | except sre_constants.error: | ||||
raise ValueError(regexp) | raise ValueError(regexp) | ||||
class Enumerator: | |||||
def __init__(self): | |||||
self.enums = {} | |||||
def get(self, item): | |||||
if item not in self.enums: | |||||
self.enums[item] = len(self.enums) | |||||
return self.enums[item] | |||||
def __len__(self): | |||||
return len(self.enums) | |||||
def reversed(self): | |||||
r = {v: k for k, v in self.enums.items()} | |||||
assert len(r) == len(self.enums) | |||||
return r | |||||