@@ -20,6 +20,7 @@ class LexerConf(Serialize):
 class ParserConf:
     def __init__(self, rules, callbacks, start):
+        assert isinstance(start, list)
         self.rules = rules
         self.callbacks = callbacks
         self.start = start
@@ -52,7 +52,7 @@ class UnexpectedInput(LarkError):
 class UnexpectedCharacters(LexError, UnexpectedInput):
-    def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None):
+    def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None, token_history=None):
         message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos], line, column)
         self.line = line

@@ -65,6 +65,8 @@ class UnexpectedCharacters(LexError, UnexpectedInput):
         message += '\n\n' + self.get_context(seq)
         if allowed:
             message += '\nExpecting: %s\n' % allowed
+        if token_history:
+            message += '\nPrevious tokens: %s\n' % ', '.join(repr(t) for t in token_history)
         super(UnexpectedCharacters, self).__init__(message)
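Note on the user-visible effect of token_history: a lexer error message now also lists the most recently emitted token, which makes it easier to see where lexing went wrong. A minimal sketch of what a caller would observe; the grammar, input and exact Token repr below are illustrative, not taken from this diff:

    from lark import Lark
    from lark.exceptions import UnexpectedCharacters

    # Hypothetical grammar whose lexer has no rule for '$'
    parser = Lark(r'''
        start: WORD+
        %import common.WORD
        %ignore " "
    ''', parser='lalr')

    try:
        parser.parse('hello $world')
    except UnexpectedCharacters as e:
        print(e)
        # The message now ends with a line along the lines of:
        #   Previous tokens: Token(WORD, 'hello')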
@@ -85,6 +85,9 @@ class LarkOptions(Serialize):
             options[name] = value
+        if isinstance(options['start'], str):
+            options['start'] = [options['start']]
         self.__dict__['options'] = options
         assert self.parser in ('earley', 'lalr', 'cyk', None)
@@ -287,8 +290,8 @@ class Lark(Serialize):
             return self.options.postlex.process(stream)
         return stream
-    def parse(self, text):
+    def parse(self, text, start=None):
         "Parse the given text, according to the options provided. Returns a tree, unless specified otherwise."
-        return self.parser.parse(text)
+        return self.parser.parse(text, start=start)
 ###}
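The user-facing half of the change is here: Lark() now also accepts a list of start rules, and parse() grows an optional start argument to choose between them. A small usage sketch, mirroring the test_multi_start test added at the end of this diff:

    from lark import Lark

    parser = Lark('''
        a: "x" "a"?
        b: "x" "b"?
    ''', start=['a', 'b'])

    print(parser.parse('xa', start='a'))   # -> Tree('a', [])
    print(parser.parse('xb', start='b'))   # -> Tree('b', [])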
@@ -149,6 +149,7 @@ class _Lex:
         newline_types = frozenset(newline_types)
         ignore_types = frozenset(ignore_types)
         line_ctr = LineCounter()
+        last_token = None
         while line_ctr.char_pos < len(stream):
             lexer = self.lexer

@@ -166,6 +167,7 @@ class _Lex:
                         t = lexer.callback[t.type](t)
                         if not isinstance(t, Token):
                             raise ValueError("Callbacks must return a token (returned %r)" % t)
+                    last_token = t
                     yield t
                 else:
                     if type_ in lexer.callback:

@@ -180,7 +182,7 @@ class _Lex:
                 break
             else:
                 allowed = {v for m, tfi in lexer.mres for v in tfi.values()}
-                raise UnexpectedCharacters(stream, line_ctr.char_pos, line_ctr.line, line_ctr.column, allowed=allowed, state=self.state)
+                raise UnexpectedCharacters(stream, line_ctr.char_pos, line_ctr.line, line_ctr.column, allowed=allowed, state=self.state, token_history=last_token and [last_token])
 class UnlessCallback:
@@ -554,7 +554,8 @@ class Grammar:
                             for s in r.expansion
                             if isinstance(s, NonTerminal)
                             and s != r.origin}
-            compiled_rules = [r for r in compiled_rules if r.origin.name==start or r.origin in used_rules]
+            used_rules |= {NonTerminal(s) for s in start}
+            compiled_rules = [r for r in compiled_rules if r.origin in used_rules]
             if len(compiled_rules) == c:
                 break

@@ -690,7 +691,7 @@ class GrammarLoader:
         callback = ParseTreeBuilder(rules, ST).create_callback()
         lexer_conf = LexerConf(terminals, ['WS', 'COMMENT'])
-        parser_conf = ParserConf(rules, callback, 'start')
+        parser_conf = ParserConf(rules, callback, ['start'])
         self.parser = LALR_TraditionalLexer(lexer_conf, parser_conf)
         self.canonize_tree = CanonizeTree()
@@ -44,18 +44,28 @@ def get_frontend(parser, lexer):
     raise ValueError('Unknown parser: %s' % parser)
+class _ParserFrontend(Serialize):
+    def _parse(self, input, start, *args):
+        if start is None:
+            start = self.start
+            if len(start) > 1:
+                raise ValueError("Lark initialized with more than 1 possible start rule. Must specify which start rule to parse", start)
+            start ,= start
+        return self.parser.parse(input, start, *args)
-class WithLexer(Serialize):
+class WithLexer(_ParserFrontend):
     lexer = None
     parser = None
     lexer_conf = None
+    start = None
-    __serialize_fields__ = 'parser', 'lexer_conf'
+    __serialize_fields__ = 'parser', 'lexer_conf', 'start'
     __serialize_namespace__ = LexerConf,
     def __init__(self, lexer_conf, parser_conf, options=None):
         self.lexer_conf = lexer_conf
+        self.start = parser_conf.start
         self.postlex = lexer_conf.postlex
     @classmethod
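_ParserFrontend._parse is also what keeps the old single-start API working: when no start is passed, it falls back to self.start and only complains if more than one start rule was configured. Roughly, assuming a plain single-start grammar:

    from lark import Lark

    # Single start rule: the old call style is unchanged, because _parse()
    # unpacks the lone entry of self.start when start is None.
    parser = Lark('start: "x"', parser='lalr')
    parser.parse('x')   # -> Tree('start', [])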
@@ -73,10 +83,10 @@ class WithLexer(Serialize):
         stream = self.lexer.lex(text)
         return self.postlex.process(stream) if self.postlex else stream
-    def parse(self, text):
+    def parse(self, text, start=None):
         token_stream = self.lex(text)
         sps = self.lexer.set_parser_state
-        return self.parser.parse(token_stream, *[sps] if sps is not NotImplemented else [])
+        return self._parse(token_stream, start, *[sps] if sps is not NotImplemented else [])
     def init_traditional_lexer(self):
         self.lexer = TraditionalLexer(self.lexer_conf.tokens, ignore=self.lexer_conf.ignore, user_callbacks=self.lexer_conf.callbacks)

@@ -135,9 +145,10 @@ class Earley(WithLexer):
         return term.name == token.type
-class XEarley:
+class XEarley(_ParserFrontend):
     def __init__(self, lexer_conf, parser_conf, options=None, **kw):
         self.token_by_name = {t.name:t for t in lexer_conf.tokens}
+        self.start = parser_conf.start
         self._prepare_match(lexer_conf)
         resolve_ambiguity = options.ambiguity == 'resolve'

@@ -167,8 +178,8 @@ class XEarley:
             self.regexps[t.name] = re.compile(regexp)
-    def parse(self, text):
-        return self.parser.parse(text)
+    def parse(self, text, start):
+        return self._parse(text, start)
 class XEarley_CompleteLex(XEarley):
     def __init__(self, *args, **kw):

@@ -187,7 +198,7 @@ class CYK(WithLexer):
         self.callbacks = parser_conf.callbacks
-    def parse(self, text):
+    def parse(self, text, start):
         tokens = list(self.lex(text))
         parse = self._parser.parse(tokens)
         parse = self._transform(parse)
@@ -89,7 +89,7 @@ class Parser(object):
         self.orig_rules = {rule: rule for rule in rules}
         rules = [self._to_rule(rule) for rule in rules]
         self.grammar = to_cnf(Grammar(rules))
-        self.start = NT(start)
+        self.start = NT(start[0])
     def _to_rule(self, lark_rule):
         """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""

@@ -273,8 +273,9 @@ class Parser:
         ## Column is now the final column in the parse.
         assert i == len(columns)-1
-    def parse(self, stream, start_symbol=None):
-        start_symbol = NonTerminal(start_symbol or self.parser_conf.start)
+    def parse(self, stream, start):
+        assert start, start
+        start_symbol = NonTerminal(start)
         columns = [set()]
         to_scan = set()    # The scan buffer. 'Q' in E.Scott's paper.
@@ -109,8 +109,10 @@ class GrammarAnalyzer(object):
     def __init__(self, parser_conf, debug=False):
         self.debug = debug
-        root_rule = Rule(NonTerminal('$root'), [NonTerminal(parser_conf.start), Terminal('$END')])
-        rules = parser_conf.rules + [root_rule]
+        root_rules = {start: Rule(NonTerminal('$root_' + start), [NonTerminal(start), Terminal('$END')])
+                      for start in parser_conf.start}
+        rules = parser_conf.rules + list(root_rules.values())
         self.rules_by_origin = classify(rules, lambda r: r.origin)
         if len(rules) != len(set(rules)):

@@ -122,10 +124,11 @@ class GrammarAnalyzer(object):
                 if not (sym.is_term or sym in self.rules_by_origin):
                     raise GrammarError("Using an undefined rule: %s" % sym)  # TODO test validation
-        self.start_state = self.expand_rule(root_rule.origin)
+        self.start_states = {start: self.expand_rule(root_rule.origin)
+                             for start, root_rule in root_rules.items()}
-        end_rule = RulePtr(root_rule, len(root_rule.expansion))
-        self.end_state = fzset({end_rule})
+        self.end_states = {start: fzset({RulePtr(root_rule, len(root_rule.expansion))})
+                           for start, root_rule in root_rules.items()}
         self.FIRST, self.FOLLOW, self.NULLABLE = calculate_sets(rules)
@@ -29,10 +29,10 @@ Shift = Action('Shift')
 Reduce = Action('Reduce')
 class ParseTable:
-    def __init__(self, states, start_state, end_state):
+    def __init__(self, states, start_states, end_states):
         self.states = states
-        self.start_state = start_state
-        self.end_state = end_state
+        self.start_states = start_states
+        self.end_states = end_states
     def serialize(self, memo):
         tokens = Enumerator()

@@ -47,8 +47,8 @@ class ParseTable:
         return {
             'tokens': tokens.reversed(),
             'states': states,
-            'start_state': self.start_state,
-            'end_state': self.end_state,
+            'start_states': self.start_states,
+            'end_states': self.end_states,
         }
     @classmethod

@@ -59,7 +59,7 @@ class ParseTable:
                     for token, (action, arg) in actions.items()}
             for state, actions in data['states'].items()
         }
-        return cls(states, data['start_state'], data['end_state'])
+        return cls(states, data['start_states'], data['end_states'])
 class IntParseTable(ParseTable):

@@ -76,9 +76,9 @@ class IntParseTable(ParseTable):
             int_states[ state_to_idx[s] ] = la
-        start_state = state_to_idx[parse_table.start_state]
-        end_state = state_to_idx[parse_table.end_state]
-        return cls(int_states, start_state, end_state)
+        start_states = {start:state_to_idx[s] for start, s in parse_table.start_states.items()}
+        end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()}
+        return cls(int_states, start_states, end_states)
 ###}
@@ -124,10 +124,10 @@ class LALR_Analyzer(GrammarAnalyzer):
             self.states[state] = {k.name:v[0] for k, v in lookahead.items()}
-        for _ in bfs([self.start_state], step):
+        for _ in bfs(self.start_states.values(), step):
             pass
-        self._parse_table = ParseTable(self.states, self.start_state, self.end_state)
+        self._parse_table = ParseTable(self.states, self.start_states, self.end_states)
         if self.debug:
             self.parse_table = self._parse_table
@@ -39,19 +39,22 @@ class LALR_Parser(object):
 class _Parser:
     def __init__(self, parse_table, callbacks):
         self.states = parse_table.states
-        self.start_state = parse_table.start_state
-        self.end_state = parse_table.end_state
+        self.start_states = parse_table.start_states
+        self.end_states = parse_table.end_states
         self.callbacks = callbacks
-    def parse(self, seq, set_state=None):
+    def parse(self, seq, start, set_state=None):
         token = None
         stream = iter(seq)
         states = self.states
-        state_stack = [self.start_state]
+        start_state = self.start_states[start]
+        end_state = self.end_states[start]
+        state_stack = [start_state]
         value_stack = []
-        if set_state: set_state(self.start_state)
+        if set_state: set_state(start_state)
         def get_action(token):
             state = state_stack[-1]

@@ -81,7 +84,7 @@ class _Parser:
         for token in stream:
             while True:
                 action, arg = get_action(token)
-                assert arg != self.end_state
+                assert arg != end_state
                 if action is Shift:
                     state_stack.append(arg)

@@ -95,7 +98,7 @@ class _Parser:
         while True:
             _action, arg = get_action(token)
             if _action is Shift:
-                assert arg == self.end_state
+                assert arg == end_state
                 val ,= value_stack
                 return val
             else:
@@ -1523,6 +1523,15 @@ def _make_parser_test(LEXER, PARSER):
             parser3 = Lark.deserialize(d, namespace, m)
             self.assertEqual(parser3.parse('ABC'), Tree('start', [Tree('b', [])]) )
+        def test_multi_start(self):
+            parser = _Lark('''
+                a: "x" "a"?
+                b: "x" "b"?
+            ''', start=['a', 'b'])
+            self.assertEqual(parser.parse('xa', 'a'), Tree('a', []))
+            self.assertEqual(parser.parse('xb', 'b'), Tree('b', []))
     _NAME = "Test" + PARSER.capitalize() + LEXER.capitalize()