
multi start, all tests passing except cyk

Erez Shinan 5 years ago
commit 2625d68869
12 changed files with 76 additions and 40 deletions
 1. lark/common.py (+1, -0)
 2. lark/exceptions.py (+3, -1)
 3. lark/lark.py (+5, -2)
 4. lark/lexer.py (+3, -1)
 5. lark/load_grammar.py (+3, -2)
 6. lark/parser_frontends.py (+19, -8)
 7. lark/parsers/cyk.py (+1, -1)
 8. lark/parsers/earley.py (+3, -2)
 9. lark/parsers/grammar_analysis.py (+8, -5)
10. lark/parsers/lalr_analysis.py (+11, -11)
11. lark/parsers/lalr_parser.py (+10, -7)
12. tests/test_parser.py (+9, -0)

lark/common.py (+1, -0)

@@ -20,6 +20,7 @@ class LexerConf(Serialize):
 
 class ParserConf:
     def __init__(self, rules, callbacks, start):
+        assert isinstance(start, list)
         self.rules = rules
         self.callbacks = callbacks
         self.start = start


lark/exceptions.py (+3, -1)

@@ -52,7 +52,7 @@ class UnexpectedInput(LarkError):
 
 
 class UnexpectedCharacters(LexError, UnexpectedInput):
-    def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None):
+    def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None, token_history=None):
         message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos], line, column)
 
         self.line = line
@@ -65,6 +65,8 @@ class UnexpectedCharacters(LexError, UnexpectedInput):
         message += '\n\n' + self.get_context(seq)
         if allowed:
             message += '\nExpecting: %s\n' % allowed
+        if token_history:
+            message += '\nPrevious tokens: %s\n' % ', '.join(repr(t) for t in token_history)
 
         super(UnexpectedCharacters, self).__init__(message)
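
For illustration only, a hypothetical lexer failure carrying a one-token history might render roughly like this (format approximated from the message-building code above):

    No terminal defined for 'b' at line 1 col 3

    xab
      ^

    Expecting: {'A'}

    Previous tokens: Token(A, 'a')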



lark/lark.py (+5, -2)

@@ -85,6 +85,9 @@ class LarkOptions(Serialize):
 
             options[name] = value
 
+        if isinstance(options['start'], str):
+            options['start'] = [options['start']]
+
         self.__dict__['options'] = options
 
         assert self.parser in ('earley', 'lalr', 'cyk', None)
@@ -287,8 +290,8 @@ class Lark(Serialize):
             return self.options.postlex.process(stream)
         return stream
 
-    def parse(self, text):
+    def parse(self, text, start=None):
         "Parse the given text, according to the options provided. Returns a tree, unless specified otherwise."
-        return self.parser.parse(text)
+        return self.parser.parse(text, start=start)
 
 ###}
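
A minimal usage sketch of the new API, mirroring test_multi_start at the bottom of this diff (grammar and expected trees are taken from that test):

    from lark import Lark

    parser = Lark('''
        a: "x" "a"?
        b: "x" "b"?
    ''', start=['a', 'b'])

    parser.parse('xa', start='a')  # -> Tree('a', [])
    parser.parse('xb', start='b')  # -> Tree('b', [])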

lark/lexer.py (+3, -1)

@@ -149,6 +149,7 @@ class _Lex:
         newline_types = frozenset(newline_types)
         ignore_types = frozenset(ignore_types)
         line_ctr = LineCounter()
+        last_token = None
 
         while line_ctr.char_pos < len(stream):
             lexer = self.lexer
@@ -166,6 +167,7 @@ class _Lex:
                         t = lexer.callback[t.type](t)
                         if not isinstance(t, Token):
                             raise ValueError("Callbacks must return a token (returned %r)" % t)
+                    last_token = t
                     yield t
                 else:
                     if type_ in lexer.callback:
@@ -180,7 +182,7 @@ class _Lex:
                 break
             else:
                 allowed = {v for m, tfi in lexer.mres for v in tfi.values()}
-                raise UnexpectedCharacters(stream, line_ctr.char_pos, line_ctr.line, line_ctr.column, allowed=allowed, state=self.state)
+                raise UnexpectedCharacters(stream, line_ctr.char_pos, line_ctr.line, line_ctr.column, allowed=allowed, state=self.state, token_history=last_token and [last_token])


class UnlessCallback:


lark/load_grammar.py (+3, -2)

@@ -554,7 +554,8 @@ class Grammar:
                           for s in r.expansion
                           if isinstance(s, NonTerminal)
                           and s != r.origin}
-            compiled_rules = [r for r in compiled_rules if r.origin.name==start or r.origin in used_rules]
+            used_rules |= {NonTerminal(s) for s in start}
+            compiled_rules = [r for r in compiled_rules if r.origin in used_rules]
             if len(compiled_rules) == c:
                 break
 
@@ -690,7 +691,7 @@ class GrammarLoader:
         callback = ParseTreeBuilder(rules, ST).create_callback()
         lexer_conf = LexerConf(terminals, ['WS', 'COMMENT'])
 
-        parser_conf = ParserConf(rules, callback, 'start')
+        parser_conf = ParserConf(rules, callback, ['start'])
         self.parser = LALR_TraditionalLexer(lexer_conf, parser_conf)
 
         self.canonize_tree = CanonizeTree()
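
Two adjustments here: dead-rule elimination now seeds used_rules with every requested start symbol instead of comparing each origin against a single start name, and the internal grammar-of-grammars parser wraps its 'start' in a list to satisfy the new ParserConf assertion.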


lark/parser_frontends.py (+19, -8)

@@ -44,18 +44,28 @@ def get_frontend(parser, lexer):
         raise ValueError('Unknown parser: %s' % parser)
 
 
+class _ParserFrontend(Serialize):
+    def _parse(self, input, start, *args):
+        if start is None:
+            start = self.start
+            if len(start) > 1:
+                raise ValueError("Lark initialized with more than 1 possible start rule. Must specify which start rule to parse", start)
+            start ,= start
+        return self.parser.parse(input, start, *args)
+
+
-class WithLexer(Serialize):
+class WithLexer(_ParserFrontend):
     lexer = None
     parser = None
     lexer_conf = None
+    start = None
 
-    __serialize_fields__ = 'parser', 'lexer_conf'
+    __serialize_fields__ = 'parser', 'lexer_conf', 'start'
     __serialize_namespace__ = LexerConf,
 
     def __init__(self, lexer_conf, parser_conf, options=None):
         self.lexer_conf = lexer_conf
+        self.start = parser_conf.start
         self.postlex = lexer_conf.postlex
 
     @classmethod
@@ -73,10 +83,10 @@ class WithLexer(Serialize):
         stream = self.lexer.lex(text)
         return self.postlex.process(stream) if self.postlex else stream
 
-    def parse(self, text):
+    def parse(self, text, start=None):
         token_stream = self.lex(text)
         sps = self.lexer.set_parser_state
-        return self.parser.parse(token_stream, *[sps] if sps is not NotImplemented else [])
+        return self._parse(token_stream, start, *[sps] if sps is not NotImplemented else [])
 
     def init_traditional_lexer(self):
         self.lexer = TraditionalLexer(self.lexer_conf.tokens, ignore=self.lexer_conf.ignore, user_callbacks=self.lexer_conf.callbacks)
@@ -135,9 +145,10 @@ class Earley(WithLexer):
         return term.name == token.type
 
 
-class XEarley:
+class XEarley(_ParserFrontend):
     def __init__(self, lexer_conf, parser_conf, options=None, **kw):
         self.token_by_name = {t.name:t for t in lexer_conf.tokens}
+        self.start = parser_conf.start
 
         self._prepare_match(lexer_conf)
         resolve_ambiguity = options.ambiguity == 'resolve'
@@ -167,8 +178,8 @@ class XEarley:
 
         self.regexps[t.name] = re.compile(regexp)
 
-    def parse(self, text):
-        return self.parser.parse(text)
+    def parse(self, text, start):
+        return self._parse(text, start)
 
 class XEarley_CompleteLex(XEarley):
     def __init__(self, *args, **kw):
@@ -187,7 +198,7 @@ class CYK(WithLexer):
 
         self.callbacks = parser_conf.callbacks
 
-    def parse(self, text):
+    def parse(self, text, start):
         tokens = list(self.lex(text))
         parse = self._parser.parse(tokens)
         parse = self._transform(parse)
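
With every frontend now routed through _ParserFrontend._parse, leaving start unspecified is only valid when a single start rule was configured. A sketch of the two call paths, reusing the two-start parser from the example above:

    parser.parse('xa')             # raises ValueError: must specify which start rule to parse
    parser.parse('xa', start='a')  # unambiguous, resolves to the 'a' entry point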


lark/parsers/cyk.py (+1, -1)

@@ -89,7 +89,7 @@ class Parser(object):
         self.orig_rules = {rule: rule for rule in rules}
         rules = [self._to_rule(rule) for rule in rules]
         self.grammar = to_cnf(Grammar(rules))
-        self.start = NT(start)
+        self.start = NT(start[0])
 
     def _to_rule(self, lark_rule):
         """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""


lark/parsers/earley.py (+3, -2)

@@ -273,8 +273,9 @@ class Parser:
         ## Column is now the final column in the parse.
         assert i == len(columns)-1
 
-    def parse(self, stream, start_symbol=None):
-        start_symbol = NonTerminal(start_symbol or self.parser_conf.start)
+    def parse(self, stream, start):
+        assert start, start
+        start_symbol = NonTerminal(start)
 
         columns = [set()]
         to_scan = set()    # The scan buffer. 'Q' in E.Scott's paper.


lark/parsers/grammar_analysis.py (+8, -5)

@@ -109,8 +109,10 @@ class GrammarAnalyzer(object):
     def __init__(self, parser_conf, debug=False):
         self.debug = debug
 
-        root_rule = Rule(NonTerminal('$root'), [NonTerminal(parser_conf.start), Terminal('$END')])
-        rules = parser_conf.rules + [root_rule]
+        root_rules = {start: Rule(NonTerminal('$root_' + start), [NonTerminal(start), Terminal('$END')])
+                      for start in parser_conf.start}
+
+        rules = parser_conf.rules + list(root_rules.values())
         self.rules_by_origin = classify(rules, lambda r: r.origin)
 
         if len(rules) != len(set(rules)):
@@ -122,10 +124,11 @@ class GrammarAnalyzer(object):
             if not (sym.is_term or sym in self.rules_by_origin):
                 raise GrammarError("Using an undefined rule: %s" % sym)  # TODO test validation
 
-        self.start_state = self.expand_rule(root_rule.origin)
+        self.start_states = {start: self.expand_rule(root_rule.origin)
+                             for start, root_rule in root_rules.items()}
 
-        end_rule = RulePtr(root_rule, len(root_rule.expansion))
-        self.end_state = fzset({end_rule})
+        self.end_states = {start: fzset({RulePtr(root_rule, len(root_rule.expansion))})
+                           for start, root_rule in root_rules.items()}
 
         self.FIRST, self.FOLLOW, self.NULLABLE = calculate_sets(rules)
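
In effect, start=['a', 'b'] injects the synthetic rules $root_a: a $END and $root_b: b $END in place of the old single $root, and start_states / end_states map each start name to its own item sets while the rest of the grammar analysis stays shared.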



lark/parsers/lalr_analysis.py (+11, -11)

@@ -29,10 +29,10 @@ Shift = Action('Shift')
 Reduce = Action('Reduce')
 
 class ParseTable:
-    def __init__(self, states, start_state, end_state):
+    def __init__(self, states, start_states, end_states):
         self.states = states
-        self.start_state = start_state
-        self.end_state = end_state
+        self.start_states = start_states
+        self.end_states = end_states
 
     def serialize(self, memo):
         tokens = Enumerator()
@@ -47,8 +47,8 @@ class ParseTable:
         return {
             'tokens': tokens.reversed(),
             'states': states,
-            'start_state': self.start_state,
-            'end_state': self.end_state,
+            'start_states': self.start_states,
+            'end_states': self.end_states,
         }
 
     @classmethod
@@ -59,7 +59,7 @@ class ParseTable:
                        for token, (action, arg) in actions.items()}
             for state, actions in data['states'].items()
         }
-        return cls(states, data['start_state'], data['end_state'])
+        return cls(states, data['start_states'], data['end_states'])
 
 
 class IntParseTable(ParseTable):
@@ -76,9 +76,9 @@ class IntParseTable(ParseTable):
             int_states[ state_to_idx[s] ] = la
 
 
-        start_state = state_to_idx[parse_table.start_state]
-        end_state = state_to_idx[parse_table.end_state]
-        return cls(int_states, start_state, end_state)
+        start_states = {start:state_to_idx[s] for start, s in parse_table.start_states.items()}
+        end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()}
+        return cls(int_states, start_states, end_states)
 
 ###}
 
@@ -124,10 +124,10 @@ class LALR_Analyzer(GrammarAnalyzer):
 
             self.states[state] = {k.name:v[0] for k, v in lookahead.items()}
 
-        for _ in bfs([self.start_state], step):
+        for _ in bfs(self.start_states.values(), step):
             pass
 
-        self._parse_table = ParseTable(self.states, self.start_state, self.end_state)
+        self._parse_table = ParseTable(self.states, self.start_states, self.end_states)
 
         if self.debug:
             self.parse_table = self._parse_table
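
Note that the state table itself remains shared: bfs is seeded with all start states at once, so the LALR automaton is computed a single time and only the per-start entry and exit states differ.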


lark/parsers/lalr_parser.py (+10, -7)

@@ -39,19 +39,22 @@ class LALR_Parser(object):
 class _Parser:
     def __init__(self, parse_table, callbacks):
         self.states = parse_table.states
-        self.start_state = parse_table.start_state
-        self.end_state = parse_table.end_state
+        self.start_states = parse_table.start_states
+        self.end_states = parse_table.end_states
         self.callbacks = callbacks
 
-    def parse(self, seq, set_state=None):
+    def parse(self, seq, start, set_state=None):
         token = None
         stream = iter(seq)
         states = self.states
 
-        state_stack = [self.start_state]
+        start_state = self.start_states[start]
+        end_state = self.end_states[start]
+
+        state_stack = [start_state]
         value_stack = []
 
-        if set_state: set_state(self.start_state)
+        if set_state: set_state(start_state)
 
         def get_action(token):
             state = state_stack[-1]
@@ -81,7 +84,7 @@ class _Parser:
         for token in stream:
             while True:
                 action, arg = get_action(token)
-                assert arg != self.end_state
+                assert arg != end_state
 
                 if action is Shift:
                     state_stack.append(arg)
@@ -95,7 +98,7 @@ class _Parser:
         while True:
             _action, arg = get_action(token)
             if _action is Shift:
-                assert arg == self.end_state
+                assert arg == end_state
                 val ,= value_stack
                 return val
             else:


tests/test_parser.py (+9, -0)

@@ -1523,6 +1523,15 @@ def _make_parser_test(LEXER, PARSER):
             parser3 = Lark.deserialize(d, namespace, m)
             self.assertEqual(parser3.parse('ABC'), Tree('start', [Tree('b', [])]) )
 
+        def test_multi_start(self):
+            parser = _Lark('''
+                a: "x" "a"?
+                b: "x" "b"?
+            ''', start=['a', 'b'])
+
+            self.assertEqual(parser.parse('xa', 'a'), Tree('a', []))
+            self.assertEqual(parser.parse('xb', 'b'), Tree('b', []))
+
 
     _NAME = "Test" + PARSER.capitalize() + LEXER.capitalize()

