From aa7dc19bc343d211d1f5995680cd27d02188f8af Mon Sep 17 00:00:00 2001 From: Erez Sh Date: Sun, 29 Nov 2020 13:28:02 +0200 Subject: [PATCH] Corrections for PR --- lark/common.py | 10 +++++----- lark/exceptions.py | 17 ++++++----------- lark/lexer.py | 18 ++++++++---------- lark/utils.py | 1 + 4 files changed, 20 insertions(+), 26 deletions(-) diff --git a/lark/common.py b/lark/common.py index 30b92eb..467acf8 100644 --- a/lark/common.py +++ b/lark/common.py @@ -12,8 +12,8 @@ class LexerConf(Serialize): def __init__(self, terminals, re_module, ignore=(), postlex=None, callbacks=None, g_regex_flags=0, skip_validation=False, use_bytes=False): self.terminals = terminals - self.terminals_by_names = {t.name: t for t in self.terminals} - assert len(self.terminals) == len(self.terminals_by_names) + self.terminals_by_name = {t.name: t for t in self.terminals} + assert len(self.terminals) == len(self.terminals_by_name) self.ignore = ignore self.postlex = postlex self.callbacks = callbacks or {} @@ -22,14 +22,14 @@ class LexerConf(Serialize): self.skip_validation = skip_validation self.use_bytes = use_bytes self.lexer_type = None - + @property def tokens(self): warn("LexerConf.tokens is deprecated. Use LexerConf.terminals instead", DeprecationWarning) return self.terminals - + def _deserialize(self): - self.terminals_by_names = {t.name: t for t in self.terminals} + self.terminals_by_name = {t.name: t for t in self.terminals} diff --git a/lark/exceptions.py b/lark/exceptions.py index faae832..23e78b9 100644 --- a/lark/exceptions.py +++ b/lark/exceptions.py @@ -1,4 +1,4 @@ -from .utils import STRING_TYPE, logger +from .utils import STRING_TYPE, logger, NO_VALUE ###{standalone @@ -120,12 +120,8 @@ class UnexpectedInput(LarkError): def _format_expected(self, expected): if self._terminals_by_name: - ts = [] - for ter in expected: - ts.append(self._terminals_by_name[ter].user_repr()) - else: - ts = expected - return "Expected one of: \n\t* %s\n" % '\n\t* '.join(ts) + expected = [self._terminals_by_name[t_name].user_repr() for t_name in expected] + return "Expected one of: \n\t* %s\n" % '\n\t* '.join(expected) class UnexpectedEOF(ParseError, UnexpectedInput): @@ -178,7 +174,6 @@ class UnexpectedCharacters(LexError, UnexpectedInput): message += '\nPrevious tokens: %s\n' % ', '.join(repr(t) for t in self.token_history) return message -_not_set_marker = object() class UnexpectedToken(ParseError, UnexpectedInput): """When the parser throws UnexpectedToken, it instantiates a puppet @@ -197,7 +192,7 @@ class UnexpectedToken(ParseError, UnexpectedInput): self.token = token self.expected = expected # XXX deprecate? `accepts` is better - self._accepts = _not_set_marker + self._accepts = NO_VALUE self.considered_rules = considered_rules self.puppet = puppet self._terminals_by_name = terminals_by_name @@ -207,8 +202,8 @@ class UnexpectedToken(ParseError, UnexpectedInput): @property def accepts(self): - if self._accepts is _not_set_marker: - self._accepts = self.puppet and self.puppet.accepts() + if self._accepts is NO_VALUE: + self._accepts = self.puppet and self.puppet.accepts() return self._accepts def __str__(self): diff --git a/lark/lexer.py b/lark/lexer.py index c089e8a..114b4ce 100644 --- a/lark/lexer.py +++ b/lark/lexer.py @@ -92,7 +92,7 @@ class TerminalDef(Serialize): def __repr__(self): return '%s(%r, %r)' % (type(self).__name__, self.name, self.pattern) - + def user_repr(self): if self.name.startswith('__'): # We represent a generated terminal return self.pattern.raw or self.name @@ -317,7 +317,7 @@ class TraditionalLexer(Lexer): self.user_callbacks = conf.callbacks self.g_regex_flags = conf.g_regex_flags self.use_bytes = conf.use_bytes - self.terminals_by_names = conf.terminals_by_names + self.terminals_by_name = conf.terminals_by_name self._mres = None @@ -361,7 +361,7 @@ class TraditionalLexer(Lexer): allowed = {""} raise UnexpectedCharacters(lex_state.text, line_ctr.char_pos, line_ctr.line, line_ctr.column, allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token], - state=parser_state, terminals_by_name=self.terminals_by_names) + state=parser_state, terminals_by_name=self.terminals_by_name) value, type_ = res @@ -403,7 +403,7 @@ class ContextualLexer(Lexer): def __init__(self, conf, states, always_accept=()): terminals = list(conf.terminals) - tokens_by_name = conf.terminals_by_names + terminals_by_name = conf.terminals_by_name trad_conf = copy(conf) trad_conf.terminals = terminals @@ -416,9 +416,8 @@ class ContextualLexer(Lexer): lexer = lexer_by_tokens[key] except KeyError: accepts = set(accepts) | set(conf.ignore) | set(always_accept) - state_tokens = [tokens_by_name[n] for n in accepts if n and n in tokens_by_name] lexer_conf = copy(trad_conf) - lexer_conf.terminals = state_tokens + lexer_conf.terminals = [terminals_by_name[n] for n in accepts if n in terminals_by_name] lexer = TraditionalLexer(lexer_conf) lexer_by_tokens[key] = lexer @@ -440,13 +439,12 @@ class ContextualLexer(Lexer): except UnexpectedCharacters as e: # In the contextual lexer, UnexpectedCharacters can mean that the terminal is defined, but not in the current context. # This tests the input against the global context, to provide a nicer error. - last_token = lexer_state.last_token # self.root_lexer.next_token will change this to the wrong token try: + last_token = lexer_state.last_token # Save last_token. Calling root_lexer.next_token will change this to the wrong token token = self.root_lexer.next_token(lexer_state, parser_state) + raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name) except UnexpectedCharacters: - raise e# Don't raise the exception that the root lexer raise. It has the wrong expected set. - else: - raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_names) + raise e # Raise the original UnexpectedCharacters. The root lexer raises it with the wrong expected set. class LexerThread: """A thread that ties a lexer instance and a lexer state, to be used by the parser""" diff --git a/lark/utils.py b/lark/utils.py index 3b5b8a8..642a59f 100644 --- a/lark/utils.py +++ b/lark/utils.py @@ -13,6 +13,7 @@ logger.setLevel(logging.CRITICAL) Py36 = (sys.version_info[:2] >= (3, 6)) +NO_VALUE = object() def classify(seq, key=None, value=None): d = {}