Browse Source

Merge branch 'chanicpanic-forest-transformer'

tags/gm/2021-09-23T00Z/github.com--lark-parser-lark/0.10.0
Erez Sh 4 years ago
parent
commit
555b268eb2
6 changed files with 582 additions and 175 deletions
  1. +7
    -3
      lark/lark.py
  2. +4
    -1
      lark/parser_frontends.py
  3. +10
    -7
      lark/parsers/earley.py
  4. +329
    -161
      lark/parsers/earley_forest.py
  5. +4
    -3
      lark/parsers/xearley.py
  6. +228
    -0
      tests/test_tree_forest_transformer.py

+ 7
- 3
lark/lark.py View File

@@ -75,6 +75,7 @@ class LarkOptions(Serialize):
- "resolve" - The parser will automatically choose the simplest derivation
(it chooses consistently: greedy for tokens, non-greedy for rules)
- "explicit": The parser will return all derivations wrapped in "_ambig" tree nodes (i.e. a forest).
- "forest": The parser will return the root of the shared packed parse forest.

**=== Misc. / Domain Specific Options ===**

@@ -262,7 +263,7 @@ class Lark(Serialize):

assert self.options.priority in ('auto', None, 'normal', 'invert'), 'invalid priority option specified: {}. options are auto, none, normal, invert.'.format(self.options.priority)
assert self.options.ambiguity not in ('resolve__antiscore_sum', ), 'resolve__antiscore_sum has been replaced with the option priority="invert"'
assert self.options.ambiguity in ('resolve', 'explicit', 'auto', )
assert self.options.ambiguity in ('resolve', 'explicit', 'forest', 'auto', )

# Parse the grammar file and compose the grammars (TODO)
self.grammar = load_grammar(grammar, self.source, re_module)
@@ -317,8 +318,11 @@ class Lark(Serialize):

def _prepare_callbacks(self):
self.parser_class = get_frontend(self.options.parser, self.options.lexer)
self._parse_tree_builder = ParseTreeBuilder(self.rules, self.options.tree_class or Tree, self.options.propagate_positions, self.options.keep_all_tokens, self.options.parser!='lalr' and self.options.ambiguity=='explicit', self.options.maybe_placeholders)
self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer)
self._callbacks = None
# we don't need these callbacks if we aren't building a tree
if self.options.ambiguity != 'forest':
self._parse_tree_builder = ParseTreeBuilder(self.rules, self.options.tree_class or Tree, self.options.propagate_positions, self.options.keep_all_tokens, self.options.parser!='lalr' and self.options.ambiguity=='explicit', self.options.maybe_placeholders)
self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer)

def _build_parser(self):
self._prepare_callbacks()


+ 4
- 1
lark/parser_frontends.py View File

@@ -165,7 +165,8 @@ class Earley(WithLexer):

resolve_ambiguity = options.ambiguity == 'resolve'
debug = options.debug if options else False
self.parser = earley.Parser(parser_conf, self.match, resolve_ambiguity=resolve_ambiguity, debug=debug)
tree_class = options.tree_class or Tree if options.ambiguity != 'forest' else None
self.parser = earley.Parser(parser_conf, self.match, resolve_ambiguity=resolve_ambiguity, debug=debug, tree_class=tree_class)

def match(self, term, token):
return term.name == token.type
@@ -179,11 +180,13 @@ class XEarley(_ParserFrontend):
self._prepare_match(lexer_conf)
resolve_ambiguity = options.ambiguity == 'resolve'
debug = options.debug if options else False
tree_class = options.tree_class or Tree if options.ambiguity != 'forest' else None
self.parser = xearley.Parser(parser_conf,
self.match,
ignore=lexer_conf.ignore,
resolve_ambiguity=resolve_ambiguity,
debug=debug,
tree_class=tree_class,
**kw
)



+ 10
- 7
lark/parsers/earley.py View File

@@ -12,20 +12,22 @@ is better documented here:

from collections import deque

from ..tree import Tree
from ..visitors import Transformer_InPlace, v_args
from ..exceptions import UnexpectedEOF, UnexpectedToken
from ..utils import logger
from .grammar_analysis import GrammarAnalyzer
from ..grammar import NonTerminal
from .earley_common import Item, TransitiveItem
from .earley_forest import ForestToTreeVisitor, ForestSumVisitor, SymbolNode, CompleteForestToAmbiguousTreeVisitor
from .earley_forest import ForestSumVisitor, SymbolNode, ForestToParseTree

class Parser:
def __init__(self, parser_conf, term_matcher, resolve_ambiguity=True, debug=False):
def __init__(self, parser_conf, term_matcher, resolve_ambiguity=True, debug=False, tree_class=Tree):
analysis = GrammarAnalyzer(parser_conf)
self.parser_conf = parser_conf
self.resolve_ambiguity = resolve_ambiguity
self.debug = debug
self.tree_class = tree_class

self.FIRST = analysis.FIRST
self.NULLABLE = analysis.NULLABLE
@@ -312,12 +314,13 @@ class Parser:
elif len(solutions) > 1:
assert False, 'Earley should not generate multiple start symbol items!'

# Perform our SPPF -> AST conversion using the right ForestVisitor.
forest_tree_visitor_cls = ForestToTreeVisitor if self.resolve_ambiguity else CompleteForestToAmbiguousTreeVisitor
forest_tree_visitor = forest_tree_visitor_cls(self.callbacks, self.forest_sum_visitor and self.forest_sum_visitor())

return forest_tree_visitor.visit(solutions[0])
if self.tree_class is not None:
# Perform our SPPF -> AST conversion
transformer = ForestToParseTree(self.tree_class, self.callbacks, self.forest_sum_visitor and self.forest_sum_visitor(), self.resolve_ambiguity)
return transformer.transform(solutions[0])

# return the root of the SPPF
return solutions[0]

class ApplyCallbacks(Transformer_InPlace):
def __init__(self, postprocess):


+ 329
- 161
lark/parsers/earley_forest.py View File

@@ -12,10 +12,13 @@ from math import isinf
from collections import deque
from operator import attrgetter
from importlib import import_module
from functools import partial

from ..parse_tree_builder import AmbiguousIntermediateExpander
from ..visitors import Discard
from ..lexer import Token
from ..utils import logger
from ..tree import Tree
from ..exceptions import ParseError

class ForestNode(object):
pass
@@ -125,8 +128,13 @@ class PackedNode(ForestNode):
"""
return self.is_empty, -self.priority, self.rule.order

@property
def children(self):
return [x for x in [self.left, self.right] if x is not None]

def __iter__(self):
return iter([self.left, self.right])
yield self.left
yield self.right

def __eq__(self, other):
if not isinstance(other, PackedNode):
@@ -153,7 +161,12 @@ class ForestVisitor(object):

Use this as a base when you need to walk the forest.
"""
__slots__ = ['result']

def get_cycle_in_path(self, node, path):
index = len(path) - 1
while id(path[index]) != id(node):
index -= 1
return path[index:]

def visit_token_node(self, node): pass
def visit_symbol_node_in(self, node): pass
@@ -161,8 +174,18 @@ class ForestVisitor(object):
def visit_packed_node_in(self, node): pass
def visit_packed_node_out(self, node): pass

def on_cycle(self, node, path):
"""Called when a cycle is encountered. `node` is the node that causes
the cycle. `path` the list of nodes being visited: nodes that have been
entered but not exited. The first element is the root in a forest
visit, and the last element is the node visited most recently.
`path` should be treated as read-only. The utility function
`get_cycle_in_path` may be used to obtain a slice of `path` that only
contains the nodes that make up the cycle."""
pass

def visit(self, root):
self.result = None
# Visiting is a list of IDs of all symbol/intermediate nodes currently in
# the stack. It serves two purposes: to detect when we 'recurse' in and out
# of a symbol/intermediate so that we can process both up and down. Also,
@@ -170,6 +193,10 @@ class ForestVisitor(object):
# to recurse into a node that's already on the stack (infinite recursion).
visiting = set()

# a list of nodes that are currently being visited
# used for the `on_cycle` callback
path = []

# We do not use recursion here to walk the Forest due to the limited
# stack size in python. Therefore input_stack is essentially our stack.
input_stack = deque([root])
@@ -180,7 +207,11 @@ class ForestVisitor(object):
vpni = getattr(self, 'visit_packed_node_in')
vsno = getattr(self, 'visit_symbol_node_out')
vsni = getattr(self, 'visit_symbol_node_in')
vino = getattr(self, 'visit_intermediate_node_out', vsno)
vini = getattr(self, 'visit_intermediate_node_in', vsni)
vtn = getattr(self, 'visit_token_node')
oc = getattr(self, 'on_cycle')

while input_stack:
current = next(reversed(input_stack))
try:
@@ -196,7 +227,8 @@ class ForestVisitor(object):
continue

if id(next_node) in visiting:
raise ParseError("Infinite recursion in grammar, in rule '%s'!" % next_node.s.name)
oc(next_node, path)
continue

input_stack.append(next_node)
continue
@@ -208,25 +240,134 @@ class ForestVisitor(object):

current_id = id(current)
if current_id in visiting:
if isinstance(current, PackedNode): vpno(current)
else: vsno(current)
if isinstance(current, PackedNode):
vpno(current)
elif current.is_intermediate:
vino(current)
else:
vsno(current)
input_stack.pop()
path.pop()
visiting.remove(current_id)
continue
else:
visiting.add(current_id)
if isinstance(current, PackedNode): next_node = vpni(current)
else: next_node = vsni(current)
path.append(current)
if isinstance(current, PackedNode):
next_node = vpni(current)
elif current.is_intermediate:
next_node = vini(current)
else:
next_node = vsni(current)
if next_node is None:
continue

if id(next_node) in visiting:
raise ParseError("Infinite recursion in grammar!")
if not isinstance(next_node, ForestNode) and \
not isinstance(next_node, Token):
next_node = iter(next_node)
elif id(next_node) in visiting:
oc(next_node, path)
continue

input_stack.append(next_node)
continue

return self.result
class ForestTransformer(ForestVisitor):
"""The base class for a bottom-up forest transformation.
Transformations are applied via inheritance and overriding of the
following methods:

transform_symbol_node
transform_intermediate_node
transform_packed_node
transform_token_node

`transform_token_node` receives a Token as an argument.
All other methods receive the node that is being transformed and
a list of the results of the transformations of that node's children.
The return value of these methods are the resulting transformations.

If `Discard` is raised in a transformation, no data from that node
will be passed to its parent's transformation.
"""

def __init__(self):
# results of transformations
self.data = dict()
# used to track parent nodes
self.node_stack = deque()

def transform(self, root):
"""Perform a transformation on a Forest."""
self.node_stack.append('result')
self.data['result'] = []
self.visit(root)
assert len(self.data['result']) <= 1
if self.data['result']:
return self.data['result'][0]

def transform_symbol_node(self, node, data):
return node

def transform_intermediate_node(self, node, data):
return node

def transform_packed_node(self, node, data):
return node

def transform_token_node(self, node):
return node

def visit_symbol_node_in(self, node):
self.node_stack.append(id(node))
self.data[id(node)] = []
return node.children

def visit_packed_node_in(self, node):
self.node_stack.append(id(node))
self.data[id(node)] = []
return node.children

def visit_token_node(self, node):
try:
transformed = self.transform_token_node(node)
except Discard:
pass
else:
self.data[self.node_stack[-1]].append(transformed)

def visit_symbol_node_out(self, node):
self.node_stack.pop()
try:
transformed = self.transform_symbol_node(node, self.data[id(node)])
except Discard:
pass
else:
self.data[self.node_stack[-1]].append(transformed)
finally:
del self.data[id(node)]

def visit_intermediate_node_out(self, node):
self.node_stack.pop()
try:
transformed = self.transform_intermediate_node(node, self.data[id(node)])
except Discard:
pass
else:
self.data[self.node_stack[-1]].append(transformed)
finally:
del self.data[id(node)]

def visit_packed_node_out(self, node):
self.node_stack.pop()
try:
transformed = self.transform_packed_node(node, self.data[id(node)])
except Discard:
pass
else:
self.data[self.node_stack[-1]].append(transformed)
finally:
del self.data[id(node)]

class ForestSumVisitor(ForestVisitor):
"""
@@ -245,7 +386,8 @@ class ForestSumVisitor(ForestVisitor):
final tree.
"""
def visit_packed_node_in(self, node):
return iter([node.left, node.right])
yield node.left
yield node.right

def visit_symbol_node_in(self, node):
return iter(node.children)
@@ -259,178 +401,203 @@ class ForestSumVisitor(ForestVisitor):
def visit_symbol_node_out(self, node):
node.priority = max(child.priority for child in node.children)

class ForestToTreeVisitor(ForestVisitor):
class PackedData():
"""Used in transformationss of packed nodes to distinguish the data
that comes from the left child and the right child.
"""
A Forest visitor which converts an SPPF forest to an unambiguous AST.

The implementation in this visitor walks only the first ambiguous child
of each symbol node. When it finds an ambiguous symbol node it first
calls the forest_sum_visitor implementation to sort the children
into preference order using the algorithms defined there; so the first
child should always be the highest preference. The forest_sum_visitor
implementation should be another ForestVisitor which sorts the children
according to some priority mechanism.

def __init__(self, node, data):
self.left = None
self.right = None
if data:
if node.left:
self.left = data[0]
if len(data) > 1 and node.right:
self.right = data[1]
elif node.right:
self.right = data[0]

class ForestToParseTree(ForestTransformer):
"""Used by the earley parser when ambiguity equals 'resolve' or
'explicit'. Transforms an SPPF into an (ambiguous) parse tree.

tree_class: The Tree class to use for construction
callbacks: A dictionary of rules to functions that output a tree
prioritizer: A ForestVisitor that manipulates the priorities of
ForestNodes
resolve_ambiguity: If True, ambiguities will be resolved based on
priorities. Otherwise, `_ambig` nodes will be in the resulting
tree.
"""
__slots__ = ['forest_sum_visitor', 'callbacks', 'output_stack']
def __init__(self, callbacks, forest_sum_visitor = None):
assert callbacks
self.forest_sum_visitor = forest_sum_visitor
def __init__(self, tree_class=Tree, callbacks=dict(), prioritizer=ForestSumVisitor(), resolve_ambiguity=True):
super(ForestToParseTree, self).__init__()
self.tree_class = tree_class
self.callbacks = callbacks
self.prioritizer = prioritizer
self.resolve_ambiguity = resolve_ambiguity
self._on_cycle_retreat = False

def on_cycle(self, node, path):
logger.warning("Cycle encountered in the SPPF at node: %s. "
"As infinite ambiguities cannot be represented in a tree, "
"this family of derivations will be discarded.", node)
if self.resolve_ambiguity:
# TODO: choose a different path if cycle is encountered
logger.warning("At this time, using ambiguity resolution for SPPFs "
"with cycles may result in None being returned.")
self._on_cycle_retreat = True

def _check_cycle(self, node):
if self._on_cycle_retreat:
raise Discard()

def visit(self, root):
self.output_stack = deque()
return super(ForestToTreeVisitor, self).visit(root)
def _collapse_ambig(self, children):
new_children = []
for child in children:
if hasattr(child, 'data') and child.data == '_ambig':
new_children += child.children
else:
new_children.append(child)
return new_children

def visit_token_node(self, node):
self.output_stack[-1].append(node)
def _call_rule_func(self, node, data):
# called when transforming children of symbol nodes
# data is a list of trees or tokens that correspond to the
# symbol's rule expansion
return self.callbacks[node.rule](data)

def _call_ambig_func(self, node, data):
# called when transforming a symbol node
# data is a list of trees where each tree's data is
# equal to the name of the symbol or one of its aliases.
if len(data) > 1:
return self.tree_class('_ambig', data)
elif data:
return data[0]
raise Discard()

def transform_symbol_node(self, node, data):
self._check_cycle(node)
data = self._collapse_ambig(data)
return self._call_ambig_func(node, data)

def transform_intermediate_node(self, node, data):
self._check_cycle(node)
if len(data) > 1:
children = [self.tree_class('_inter', c) for c in data]
return self.tree_class('_iambig', children)
return data[0]

def transform_packed_node(self, node, data):
self._check_cycle(node)
children = []
assert len(data) <= 2
data = PackedData(node, data)
if data.left is not None:
if node.left.is_intermediate and isinstance(data.left, list):
children += data.left
else:
children.append(data.left)
if data.right is not None:
children.append(data.right)
if node.parent.is_intermediate:
return children
return self._call_rule_func(node, children)

def visit_symbol_node_in(self, node):
if self.forest_sum_visitor and node.is_ambiguous and isinf(node.priority):
self.forest_sum_visitor.visit(node)
return next(iter(node.children))
self._on_cycle_retreat = False
super(ForestToParseTree, self).visit_symbol_node_in(node)
if self.prioritizer and node.is_ambiguous and isinf(node.priority):
self.prioritizer.visit(node)
if self.resolve_ambiguity:
return node.children[0]
return node.children

def visit_packed_node_in(self, node):
if not node.parent.is_intermediate:
self.output_stack.append([])
return iter([node.left, node.right])

def visit_packed_node_out(self, node):
if not node.parent.is_intermediate:
result = self.callbacks[node.rule](self.output_stack.pop())
if self.output_stack:
self.output_stack[-1].append(result)
else:
self.result = result

class ForestToAmbiguousTreeVisitor(ForestToTreeVisitor):
"""
A Forest visitor which converts an SPPF forest to an ambiguous AST.

Because of the fundamental disparity between what can be stored in
an SPPF and what can be stored in a Tree; this implementation is not
complete. It correctly deals with ambiguities that occur on symbol nodes only,
and cannot deal with ambiguities that occur on intermediate nodes.

Usually, most parsers can be rewritten to avoid intermediate node
ambiguities. Also, this implementation could be fixed, however
the code to handle intermediate node ambiguities is messy and
would not be performant. It is much better not to use this and
instead to correctly disambiguate the forest and only store unambiguous
parses in Trees. It is here just to provide some parity with the
old ambiguity='explicit'.

This is mainly used by the test framework, to make it simpler to write
tests ensuring the SPPF contains the right results.
"""
def __init__(self, callbacks, forest_sum_visitor = ForestSumVisitor):
super(ForestToAmbiguousTreeVisitor, self).__init__(callbacks, forest_sum_visitor)
self._on_cycle_retreat = False
return super(ForestToParseTree, self).visit_packed_node_in(node)

def visit_token_node(self, node):
self.output_stack[-1].children.append(node)
self._on_cycle_retreat = False
return super(ForestToParseTree, self).visit_token_node(node)

def visit_symbol_node_in(self, node):
if node.is_ambiguous:
if self.forest_sum_visitor and isinf(node.priority):
self.forest_sum_visitor.visit(node)
if node.is_intermediate:
# TODO Support ambiguous intermediate nodes!
logger.warning("Ambiguous intermediate node in the SPPF: %s. "
"Lark does not currently process these ambiguities; resolving with the first derivation.", node)
return next(iter(node.children))
else:
self.output_stack.append(Tree('_ambig', []))

return iter(node.children)

def visit_symbol_node_out(self, node):
if not node.is_intermediate and node.is_ambiguous:
result = self.output_stack.pop()
if self.output_stack:
self.output_stack[-1].children.append(result)
else:
self.result = result
def handles_ambiguity(func):
"""Decorator for methods of subclasses of TreeForestTransformer.
Denotes that the method should receive a list of transformed derivations."""
func.handles_ambiguity = True
return func

def visit_packed_node_in(self, node):
if not node.parent.is_intermediate:
self.output_stack.append(Tree('drv', []))
return iter([node.left, node.right])
class TreeForestTransformer(ForestToParseTree):
"""A ForestTransformer with a tree-Transformer-like interface.
By default, it will construct a tree.

def visit_packed_node_out(self, node):
if not node.parent.is_intermediate:
result = self.callbacks[node.rule](self.output_stack.pop().children)
if self.output_stack:
self.output_stack[-1].children.append(result)
else:
self.result = result
Methods provided via inheritance are called based on the rule/symbol
names of nodes in the forest.

class CompleteForestToAmbiguousTreeVisitor(ForestToTreeVisitor):
"""
An augmented version of ForestToAmbiguousTreeVisitor that is designed to
handle ambiguous intermediate nodes as well as ambiguous symbol nodes.
Methods that act on rules will receive a list of the results of the
transformations of the rule's children. By default, trees and tokens.

On the way down:
Methods that act on tokens will receive a Token.

- When an ambiguous intermediate node is encountered, an '_iambig' node
is inserted into the tree.
- Each possible derivation of an ambiguous intermediate node is represented
by an '_inter' node added as a child of the corresponding '_iambig' node.
Alternatively, methods that act on rules may be annotated with
`handles_ambiguity`. In this case, the function will receive a list
of all the transformations of all the derivations of the rule.
By default, a list of trees where each tree.data is equal to the
rule name or one of its aliases.

On the way up, these nodes are propagated up the tree and collapsed
into a single '_ambig' node for the nearest symbol node ancestor.
This is achieved by the AmbiguousIntermediateExpander contained in
the callbacks.
Non-tree transformations are made possible by override of
`__default__`, `__default_token__`, and `__default_ambig__`.
"""

def _collapse_ambig(self, children):
new_children = []
for child in children:
if child.data == '_ambig':
new_children += child.children
else:
new_children.append(child)
return new_children
def __init__(self, tree_class=Tree, prioritizer=ForestSumVisitor(), resolve_ambiguity=True):
super(TreeForestTransformer, self).__init__(tree_class, dict(), prioritizer, resolve_ambiguity)

def visit_token_node(self, node):
self.output_stack[-1].children.append(node)
def __default__(self, name, data):
"""Default operation on tree (for override).

def visit_symbol_node_in(self, node):
if node.is_ambiguous:
if self.forest_sum_visitor and isinf(node.priority):
self.forest_sum_visitor.visit(node)
if node.is_intermediate:
self.output_stack.append(Tree('_iambig', []))
else:
self.output_stack.append(Tree('_ambig', []))
return iter(node.children)
Returns a tree with name with data as children.
"""
return self.tree_class(name, data)

def visit_symbol_node_out(self, node):
if node.is_ambiguous:
result = self.output_stack.pop()
if not node.is_intermediate:
result = Tree('_ambig', self._collapse_ambig(result.children))
if self.output_stack:
self.output_stack[-1].children.append(result)
else:
self.result = result
def __default_ambig__(self, name, data):
"""Default operation on ambiguous rule (for override).

def visit_packed_node_in(self, node):
if not node.parent.is_intermediate:
self.output_stack.append(Tree('drv', []))
elif node.parent.is_ambiguous:
self.output_stack.append(Tree('_inter', []))
return iter([node.left, node.right])
Wraps data in an '_ambig_ node if it contains more than
one element.'
"""
if len(data) > 1:
return self.tree_class('_ambig', data)
elif data:
return data[0]
raise Discard()

def visit_packed_node_out(self, node):
if not node.parent.is_intermediate:
result = self.callbacks[node.rule](self.output_stack.pop().children)
elif node.parent.is_ambiguous:
result = self.output_stack.pop()
else:
return
if self.output_stack:
self.output_stack[-1].children.append(result)
else:
self.result = result
def __default_token__(self, node):
"""Default operation on Token (for override).

Returns node
"""
return node

def transform_token_node(self, node):
return getattr(self, node.type, self.__default_token__)(node)

def _call_rule_func(self, node, data):
name = node.rule.alias or node.rule.options.template_source or node.rule.origin.name
user_func = getattr(self, name, self.__default__)
if user_func == self.__default__ or hasattr(user_func, 'handles_ambiguity'):
user_func = partial(self.__default__, name)
if not self.resolve_ambiguity:
wrapper = partial(AmbiguousIntermediateExpander, self.tree_class)
user_func = wrapper(user_func)
return user_func(data)

def _call_ambig_func(self, node, data):
name = node.s.name
user_func = getattr(self, name, self.__default_ambig__)
if user_func == self.__default_ambig__ or not hasattr(user_func, 'handles_ambiguity'):
user_func = partial(self.__default_ambig__, name)
return user_func(data)

class ForestToPyDotVisitor(ForestVisitor):
"""
@@ -466,7 +633,8 @@ class ForestToPyDotVisitor(ForestVisitor):
graph_node_shape = "diamond"
graph_node = self.pydot.Node(graph_node_id, style=graph_node_style, fillcolor="#{:06x}".format(graph_node_color), shape=graph_node_shape, label=graph_node_label)
self.graph.add_node(graph_node)
return iter([node.left, node.right])
yield node.left
yield node.right

def visit_packed_node_out(self, node):
graph_node_id = str(id(node))


+ 4
- 3
lark/parsers/xearley.py View File

@@ -16,6 +16,7 @@ Earley's power in parsing any CFG.

from collections import defaultdict

from ..tree import Tree
from ..exceptions import UnexpectedCharacters
from ..lexer import Token
from ..grammar import Terminal
@@ -24,8 +25,8 @@ from .earley_forest import SymbolNode


class Parser(BaseParser):
def __init__(self, parser_conf, term_matcher, resolve_ambiguity=True, ignore = (), complete_lex = False, debug=False):
BaseParser.__init__(self, parser_conf, term_matcher, resolve_ambiguity, debug)
def __init__(self, parser_conf, term_matcher, resolve_ambiguity=True, ignore = (), complete_lex = False, debug=False, tree_class=Tree):
BaseParser.__init__(self, parser_conf, term_matcher, resolve_ambiguity, debug, tree_class)
self.ignore = [Terminal(t) for t in ignore]
self.complete_lex = complete_lex

@@ -148,4 +149,4 @@ class Parser(BaseParser):

## Column is now the final column in the parse.
assert i == len(columns)-1
return to_scan
return to_scan

+ 228
- 0
tests/test_tree_forest_transformer.py View File

@@ -0,0 +1,228 @@
from __future__ import absolute_import

import unittest

from lark import Lark
from lark.lexer import Token
from lark.tree import Tree
from lark.visitors import Visitor, Transformer, Discard
from lark.parsers.earley_forest import TreeForestTransformer, handles_ambiguity

class TestTreeForestTransformer(unittest.TestCase):

grammar = """
start: ab bc cd
!ab: "A" "B"?
!bc: "B"? "C"?
!cd: "C"? "D"
"""

parser = Lark(grammar, parser='earley', ambiguity='forest')
forest = parser.parse("ABCD")

def test_identity_resolve_ambiguity(self):
l = Lark(self.grammar, parser='earley', ambiguity='resolve')
tree1 = l.parse("ABCD")
tree2 = TreeForestTransformer(resolve_ambiguity=True).transform(self.forest)
self.assertEqual(tree1, tree2)

def test_identity_explicit_ambiguity(self):
l = Lark(self.grammar, parser='earley', ambiguity='explicit')
tree1 = l.parse("ABCD")
tree2 = TreeForestTransformer(resolve_ambiguity=False).transform(self.forest)
self.assertEqual(tree1, tree2)

def test_tree_class(self):

class CustomTree(Tree):
pass

class TreeChecker(Visitor):
def __default__(self, tree):
assert isinstance(tree, CustomTree)

tree = TreeForestTransformer(resolve_ambiguity=False, tree_class=CustomTree).transform(self.forest)
TreeChecker().visit(tree)

def test_token_calls(self):

visited = [False] * 4

class CustomTransformer(TreeForestTransformer):
def A(self, node):
assert node.type == 'A'
visited[0] = True
def B(self, node):
assert node.type == 'B'
visited[1] = True
def C(self, node):
assert node.type == 'C'
visited[2] = True
def D(self, node):
assert node.type == 'D'
visited[3] = True

tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
assert visited == [True] * 4

def test_default_token(self):

token_count = [0]

class CustomTransformer(TreeForestTransformer):
def __default_token__(self, node):
token_count[0] += 1
assert isinstance(node, Token)

tree = CustomTransformer(resolve_ambiguity=True).transform(self.forest)
self.assertEqual(token_count[0], 4)

def test_rule_calls(self):

visited_start = [False]
visited_ab = [False]
visited_bc = [False]
visited_cd = [False]

class CustomTransformer(TreeForestTransformer):
def start(self, data):
visited_start[0] = True
def ab(self, data):
visited_ab[0] = True
def bc(self, data):
visited_bc[0] = True
def cd(self, data):
visited_cd[0] = True

tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
self.assertTrue(visited_start[0])
self.assertTrue(visited_ab[0])
self.assertTrue(visited_bc[0])
self.assertTrue(visited_cd[0])

def test_default_rule(self):

rule_count = [0]

class CustomTransformer(TreeForestTransformer):
def __default__(self, name, data):
rule_count[0] += 1

tree = CustomTransformer(resolve_ambiguity=True).transform(self.forest)
self.assertEqual(rule_count[0], 4)

def test_default_ambig(self):

ambig_count = [0]

class CustomTransformer(TreeForestTransformer):
def __default_ambig__(self, name, data):
if len(data) > 1:
ambig_count[0] += 1

tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
self.assertEqual(ambig_count[0], 1)

def test_handles_ambiguity(self):

class CustomTransformer(TreeForestTransformer):
@handles_ambiguity
def start(self, data):
assert isinstance(data, list)
assert len(data) == 4
for tree in data:
assert tree.data == 'start'
return 'handled'

@handles_ambiguity
def ab(self, data):
assert isinstance(data, list)
assert len(data) == 1
assert data[0].data == 'ab'

tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
self.assertEqual(tree, 'handled')

def test_discard(self):

class CustomTransformer(TreeForestTransformer):
def bc(self, data):
raise Discard()

def D(self, node):
raise Discard()

class TreeChecker(Transformer):
def bc(self, children):
assert False

def D(self, token):
assert False

tree = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
TreeChecker(visit_tokens=True).transform(tree)

def test_aliases(self):

visited_ambiguous = [False]
visited_full = [False]

class CustomTransformer(TreeForestTransformer):
@handles_ambiguity
def start(self, data):
for tree in data:
assert tree.data == 'ambiguous' or tree.data == 'full'

def ambiguous(self, data):
visited_ambiguous[0] = True
assert len(data) == 3
assert data[0].data == 'ab'
assert data[1].data == 'bc'
assert data[2].data == 'cd'
return self.tree_class('ambiguous', data)

def full(self, data):
visited_full[0] = True
assert len(data) == 1
assert data[0].data == 'abcd'
return self.tree_class('full', data)

grammar = """
start: ab bc cd -> ambiguous
| abcd -> full
!ab: "A" "B"?
!bc: "B"? "C"?
!cd: "C"? "D"
!abcd: "ABCD"
"""

l = Lark(grammar, parser='earley', ambiguity='forest')
forest = l.parse('ABCD')
tree = CustomTransformer(resolve_ambiguity=False).transform(forest)
self.assertTrue(visited_ambiguous[0])
self.assertTrue(visited_full[0])

def test_transformation(self):

class CustomTransformer(TreeForestTransformer):
def __default__(self, name, data):
result = []
for item in data:
if isinstance(item, list):
result += item
else:
result.append(item)
return result

def __default_token__(self, node):
return node.lower()

def __default_ambig__(self, name, data):
return data[0]

result = CustomTransformer(resolve_ambiguity=False).transform(self.forest)
expected = ['a', 'b', 'c', 'd']
self.assertEqual(result, expected)

if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save