@@ -31,6 +31,7 @@ class LarkOptions:
     lexer_callbacks: Dict[str, Callable[[Token], Token]]
     cache: Union[bool, str]
     g_regex_flags: int
+    use_bytes: bool


 class Lark:
@@ -55,7 +56,8 @@ class Lark:
         propagate_positions: bool = False,
         maybe_placeholders: bool = False,
         lexer_callbacks: Optional[Dict[str, Callable[[Token], Token]]] = None,
-        g_regex_flags: int = ...
+        g_regex_flags: int = ...,
+        use_bytes: bool = False,
     ):
         ...
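
These stub hunks publish the new `use_bytes` flag in the typed public API. A minimal usage sketch, not taken from the PR (the grammar is a hypothetical one-liner): the point is only that with `use_bytes=True` the lexer operates over a `bytes` input, so `parse()` receives bytes rather than str.

    from lark import Lark

    # Hypothetical grammar; with use_bytes=True the input must be bytes.
    parser = Lark(r'start: "hello"', parser='lalr', use_bytes=True)
    tree = parser.parse(b"hello")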
@@ -4,10 +4,10 @@ from .lexer import TerminalDef

 ###{standalone

 class LexerConf(Serialize):
-    __serialize_fields__ = 'tokens', 'ignore', 'g_regex_flags'
+    __serialize_fields__ = 'tokens', 'ignore', 'g_regex_flags', 'use_bytes'
     __serialize_namespace__ = TerminalDef,

-    def __init__(self, tokens, re_module, ignore=(), postlex=None, callbacks=None, g_regex_flags=0, skip_validation=False):
+    def __init__(self, tokens, re_module, ignore=(), postlex=None, callbacks=None, g_regex_flags=0, skip_validation=False, use_bytes=False):
         self.tokens = tokens  # TODO should be terminals
         self.ignore = ignore
         self.postlex = postlex
@@ -15,6 +15,7 @@ class LexerConf(Serialize):
         self.g_regex_flags = g_regex_flags
         self.re_module = re_module
         self.skip_validation = skip_validation
+        self.use_bytes = use_bytes

     def _deserialize(self):
         self.callbacks = {}  # TODO
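
Listing 'use_bytes' in `__serialize_fields__` is what lets the flag survive `Serialize`-based round-trips (the cache option and the standalone generator); an unlisted field would silently reset to its default after deserialization. A rough sketch of the invariant, assuming `Serialize.serialize()` returns a plain dict of the listed fields:

    import re

    conf = LexerConf([], re_module=re, use_bytes=True)
    data = conf.serialize()           # assumption: a dict of the listed fields
    assert data['use_bytes'] is True  # the flag is now part of the payload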
@@ -28,9 +28,14 @@ class UnexpectedInput(LarkError):
         pos = self.pos_in_stream
         start = max(pos - span, 0)
         end = pos + span
-        before = text[start:pos].rsplit('\n', 1)[-1]
-        after = text[pos:end].split('\n', 1)[0]
-        return before + after + '\n' + ' ' * len(before) + '^\n'
+        if not isinstance(text, bytes):
+            before = text[start:pos].rsplit('\n', 1)[-1]
+            after = text[pos:end].split('\n', 1)[0]
+            return before + after + '\n' + ' ' * len(before) + '^\n'
+        else:
+            before = text[start:pos].rsplit(b'\n', 1)[-1]
+            after = text[pos:end].split(b'\n', 1)[0]
+            return (before + after + b'\n' + b' ' * len(before) + b'^\n').decode("ascii", "backslashreplace")

     def match_examples(self, parse_fn, examples, token_type_match_fallback=False):
         """ Given a parser instance and a dictionary mapping some label with
@@ -67,7 +72,11 @@ class UnexpectedInput(LarkError):

 class UnexpectedCharacters(LexError, UnexpectedInput):
     def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None, token_history=None):
-        message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos], line, column)
+        if isinstance(seq, bytes):
+            message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos:lex_pos+1].decode("ascii", "backslashreplace"), line, column)
+        else:
+            message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos], line, column)

         self.line = line
         self.column = column
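
Note the `seq[lex_pos:lex_pos+1]` slice in the bytes branch: indexing a bytes object yields an int, so only a slice produces a one-byte bytes value that can be decoded for the message.

    seq = b"abc"
    seq[1]    # 98 -- an int; '%s' would render it as "98"
    seq[1:2]  # b'b' -- a bytes object that can be decoded for display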
@@ -105,6 +105,7 @@ class LarkOptions(Serialize):
         'maybe_placeholders': False,
         'edit_terminals': None,
         'g_regex_flags': 0,
+        'use_bytes': False,
     }

     def __init__(self, options_dict):
@@ -252,7 +253,7 @@ class Lark(Serialize):
             for t in self.terminals:
                 self.options.edit_terminals(t)

-        self._terminals_dict = {t.name:t for t in self.terminals}
+        self._terminals_dict = {t.name: t for t in self.terminals}

         # If the user asked to invert the priorities, negate them all here.
         # This replaces the old 'resolve__antiscore_sum' option.
@@ -276,7 +277,7 @@ class Lark(Serialize):
                 if hasattr(t, term.name):
                     lexer_callbacks[term.name] = getattr(t, term.name)

-        self.lexer_conf = LexerConf(self.terminals, re_module, self.ignore_tokens, self.options.postlex, lexer_callbacks, self.options.g_regex_flags)
+        self.lexer_conf = LexerConf(self.terminals, re_module, self.ignore_tokens, self.options.postlex, lexer_callbacks, self.options.g_regex_flags, use_bytes=self.options.use_bytes)

         if self.options.parser:
             self.parser = self._build_parser()
@@ -230,7 +230,7 @@ class CallChain:

-def _create_unless(terminals, g_regex_flags, re_):
+def _create_unless(terminals, g_regex_flags, re_, use_bytes):
     tokens_by_type = classify(terminals, lambda t: type(t.pattern))
     assert len(tokens_by_type) <= 2, tokens_by_type.keys()
     embedded_strs = set()
@@ -247,31 +247,34 @@ def _create_unless(terminals, g_regex_flags, re_):
                 if strtok.pattern.flags <= retok.pattern.flags:
                     embedded_strs.add(strtok)
         if unless:
-            callback[retok.name] = UnlessCallback(build_mres(unless, g_regex_flags, re_, match_whole=True))
+            callback[retok.name] = UnlessCallback(build_mres(unless, g_regex_flags, re_, match_whole=True, use_bytes=use_bytes))

     terminals = [t for t in terminals if t not in embedded_strs]
     return terminals, callback


-def _build_mres(terminals, max_size, g_regex_flags, match_whole, re_):
+def _build_mres(terminals, max_size, g_regex_flags, match_whole, re_, use_bytes):
     # Python sets an unreasonable group limit (currently 100) in its re module
     # Worse, the only way to know we reached it is by catching an AssertionError!
     # This function recursively tries less and less groups until it's successful.
     postfix = '$' if match_whole else ''
     mres = []
     while terminals:
+        pattern = u'|'.join(u'(?P<%s>%s)' % (t.name, t.pattern.to_regexp() + postfix) for t in terminals[:max_size])
+        if use_bytes:
+            pattern = pattern.encode()
         try:
-            mre = re_.compile(u'|'.join(u'(?P<%s>%s)'%(t.name, t.pattern.to_regexp()+postfix) for t in terminals[:max_size]), g_regex_flags)
+            mre = re_.compile(pattern, g_regex_flags)
         except AssertionError:  # Yes, this is what Python provides us.. :/
-            return _build_mres(terminals, max_size//2, g_regex_flags, match_whole, re_)
+            return _build_mres(terminals, max_size//2, g_regex_flags, match_whole, re_, use_bytes)

         # terms_from_name = {t.name: t for t in terminals[:max_size]}
         mres.append((mre, {i:n for n,i in mre.groupindex.items()} ))
         terminals = terminals[max_size:]
     return mres

-def build_mres(terminals, g_regex_flags, re_, match_whole=False):
-    return _build_mres(terminals, len(terminals), g_regex_flags, match_whole, re_)
+def build_mres(terminals, g_regex_flags, re_, use_bytes, match_whole=False):
+    return _build_mres(terminals, len(terminals), g_regex_flags, match_whole, re_, use_bytes)
@@ -321,12 +324,13 @@ class TraditionalLexer(Lexer):
         self.terminals = terminals
         self.user_callbacks = conf.callbacks
         self.g_regex_flags = conf.g_regex_flags
+        self.use_bytes = conf.use_bytes

         self._mres = None
         # self.build(g_regex_flags)

     def _build(self):
-        terminals, self.callback = _create_unless(self.terminals, self.g_regex_flags, re_=self.re)
+        terminals, self.callback = _create_unless(self.terminals, self.g_regex_flags, re_=self.re, use_bytes=self.use_bytes)
         assert all(self.callback.values())

         for type_, f in self.user_callbacks.items():
@@ -336,7 +340,7 @@ class TraditionalLexer(Lexer):
             else:
                 self.callback[type_] = f

-        self._mres = build_mres(terminals, self.g_regex_flags, self.re)
+        self._mres = build_mres(terminals, self.g_regex_flags, self.re, self.use_bytes)

     @property
     def mres(self):
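
For context, `_mres` starts out as `None` and the `mres` property triggers `_build()` on first access, so the use_bytes decision is deferred until the lexer is actually used. A generic sketch of that lazy pattern (only the `None` sentinel and the property are visible in this diff; the rest is assumed):

    class LazyBuild:
        def __init__(self):
            self._mres = None

        def _build(self):
            self._mres = ['compiled', 'regexes']  # stand-in for build_mres()

        @property
        def mres(self):
            if self._mres is None:  # build on demand, then memoize
                self._build()
            return self._mres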
@@ -365,7 +369,8 @@ class ContextualLexer(Lexer): | |||||
assert t.name not in tokens_by_name, t | assert t.name not in tokens_by_name, t | ||||
tokens_by_name[t.name] = t | tokens_by_name[t.name] = t | ||||
trad_conf = type(conf)(terminals, conf.re_module, conf.ignore, callbacks=conf.callbacks, g_regex_flags=conf.g_regex_flags, skip_validation=conf.skip_validation) | |||||
trad_conf = copy(conf) | |||||
trad_conf.tokens = terminals | |||||
lexer_by_tokens = {} | lexer_by_tokens = {} | ||||
self.lexers = {} | self.lexers = {} | ||||
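
Replacing the hand-built `type(conf)(...)` call with a shallow copy means ContextualLexer can no longer drift out of sync when LexerConf grows a parameter, which is exactly how `use_bytes` would otherwise have been dropped here (this assumes `copy` is `copy.copy`, imported at module level). The pattern in isolation, with a stand-in config class:

    from copy import copy

    class Conf:
        def __init__(self):
            self.tokens = []
            self.use_bytes = True  # any future field is carried along too

    conf = Conf()
    trad_conf = copy(conf)         # shallow copy keeps every attribute
    trad_conf.tokens = ['A', 'B']  # override only what differs per state
    assert trad_conf.use_bytes == conf.use_bytes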