@@ -2,7 +2,7 @@ | |||||
# This example shows how to write a basic calculator with variables. | # This example shows how to write a basic calculator with variables. | ||||
# | # | ||||
from lark import Lark, InlineTransformer | |||||
from lark import Lark, Transformer, children_args_inline | |||||
try: | try: | ||||
input = raw_input # For Python2 compatibility | input = raw_input # For Python2 compatibility | ||||
@@ -34,7 +34,7 @@ calc_grammar = """ | |||||
%ignore WS_INLINE | %ignore WS_INLINE | ||||
""" | """ | ||||
class CalculateTree(InlineTransformer): | |||||
class CalculateTree(SimpleTransformer): | |||||
from operator import add, sub, mul, truediv as div, neg | from operator import add, sub, mul, truediv as div, neg | ||||
number = float | number = float | ||||
@@ -21,11 +21,10 @@ from ..grammar import NonTerminal | |||||
class Derivation(Tree): | class Derivation(Tree): | ||||
_hash = None | |||||
def __init__(self, rule, items=None): | def __init__(self, rule, items=None): | ||||
Tree.__init__(self, 'drv', items or []) | Tree.__init__(self, 'drv', items or []) | ||||
self.rule = rule | |||||
self.meta.rule = rule | |||||
self._hash = None | |||||
def _pretty_label(self): # Nicer pretty for debugging the parser | def _pretty_label(self): # Nicer pretty for debugging the parser | ||||
return self.rule.origin if self.rule else self.data | return self.rule.origin if self.rule else self.data | ||||
@@ -236,4 +235,4 @@ class ApplyCallbacks(Transformer_InPlace): | |||||
self.postprocess = postprocess | self.postprocess = postprocess | ||||
def drv(self, tree): | def drv(self, tree): | ||||
return self.postprocess[tree.rule](tree.children) | |||||
return self.postprocess[tree.meta.rule](tree.children) |
@@ -16,7 +16,7 @@ def _sum_priority(tree): | |||||
for n in tree.iter_subtrees(): | for n in tree.iter_subtrees(): | ||||
try: | try: | ||||
p += n.rule.options.priority or 0 | |||||
p += n.meta.rule.options.priority or 0 | |||||
except AttributeError: | except AttributeError: | ||||
pass | pass | ||||
@@ -26,8 +26,8 @@ def _compare_priority(tree1, tree2): | |||||
tree1.iter_subtrees() | tree1.iter_subtrees() | ||||
def _compare_drv(tree1, tree2): | def _compare_drv(tree1, tree2): | ||||
rule1 = getattr(tree1, 'rule', None) | |||||
rule2 = getattr(tree2, 'rule', None) | |||||
rule1 = getattr(tree1.meta, 'rule', None) | |||||
rule2 = getattr(tree2.meta, 'rule', None) | |||||
if None == rule1 == rule2: | if None == rule1 == rule2: | ||||
return compare(tree1, tree2) | return compare(tree1, tree2) | ||||
@@ -45,7 +45,7 @@ def _compare_drv(tree1, tree2): | |||||
if c: | if c: | ||||
return c | return c | ||||
c = _compare_rules(tree1.rule, tree2.rule) | |||||
c = _compare_rules(tree1.meta.rule, tree2.meta.rule) | |||||
if c: | if c: | ||||
return c | return c | ||||
@@ -65,7 +65,7 @@ def _standard_resolve_ambig(tree): | |||||
best = max(tree.children, key=key_f) | best = max(tree.children, key=key_f) | ||||
assert best.data == 'drv' | assert best.data == 'drv' | ||||
tree.set('drv', best.children) | tree.set('drv', best.children) | ||||
tree.rule = best.rule # needed for applying callbacks | |||||
tree.meta.rule = best.meta.rule # needed for applying callbacks | |||||
def standard_resolve_ambig(tree): | def standard_resolve_ambig(tree): | ||||
for ambig in tree.find_data('_ambig'): | for ambig in tree.find_data('_ambig'): | ||||
@@ -93,7 +93,7 @@ def _antiscore_sum_resolve_ambig(tree): | |||||
best = min(tree.children, key=_antiscore_sum_drv) | best = min(tree.children, key=_antiscore_sum_drv) | ||||
assert best.data == 'drv' | assert best.data == 'drv' | ||||
tree.set('drv', best.children) | tree.set('drv', best.children) | ||||
tree.rule = best.rule # needed for applying callbacks | |||||
tree.meta.rule = best.meta.rule # needed for applying callbacks | |||||
def antiscore_sum_resolve_ambig(tree): | def antiscore_sum_resolve_ambig(tree): | ||||
for ambig in tree.find_data('_ambig'): | for ambig in tree.find_data('_ambig'): | ||||
@@ -1,4 +1,4 @@ | |||||
from inspect import isclass | |||||
from inspect import isclass, getmembers, getmro | |||||
from functools import wraps | from functools import wraps | ||||
from .utils import smart_decorator | from .utils import smart_decorator | ||||
@@ -16,6 +16,30 @@ class Base: | |||||
"Default operation on tree (for override)" | "Default operation on tree (for override)" | ||||
return tree | return tree | ||||
@classmethod | |||||
def _apply_decorator(cls, decorator): | |||||
mro = getmro(cls) | |||||
assert mro[0] is cls | |||||
libmembers = {name for _cls in mro[1:] for name, _ in getmembers(_cls)} | |||||
for name, value in getmembers(cls): | |||||
if name.startswith('_') or name in libmembers: | |||||
continue | |||||
setattr(cls, name, decorator(value)) | |||||
return cls | |||||
class SimpleBase(Base): | |||||
def _call_userfunc(self, tree): | |||||
# Assumes tree is already transformed | |||||
try: | |||||
f = getattr(self, tree.data) | |||||
except AttributeError: | |||||
return self.__default__(tree) | |||||
else: | |||||
return f(tree.children) | |||||
class Transformer(Base): | class Transformer(Base): | ||||
def _transform_children(self, children): | def _transform_children(self, children): | ||||
for c in children: | for c in children: | ||||
@@ -35,6 +59,7 @@ class Transformer(Base): | |||||
return TransformerChain(self, other) | return TransformerChain(self, other) | ||||
class TransformerChain(object): | class TransformerChain(object): | ||||
def __init__(self, *transformers): | def __init__(self, *transformers): | ||||
self.transformers = transformers | self.transformers = transformers | ||||
@@ -110,8 +135,22 @@ class Interpreter(object): | |||||
def _children_args__func(f): | |||||
@wraps(f) | |||||
def _apply_decorator(obj, decorator): | |||||
try: | |||||
_apply = obj._apply_decorator | |||||
except AttributeError: | |||||
return decorator(obj) | |||||
else: | |||||
return _apply(decorator) | |||||
def _children_args__func(func): | |||||
if getattr(func, '_children_args_decorated', False): | |||||
return func | |||||
@wraps(func) | |||||
def create_decorator(_f, with_self): | def create_decorator(_f, with_self): | ||||
if with_self: | if with_self: | ||||
def f(self, tree): | def f(self, tree): | ||||
@@ -119,55 +158,34 @@ def _children_args__func(f): | |||||
else: | else: | ||||
def f(args): | def f(args): | ||||
return _f(tree.children) | return _f(tree.children) | ||||
f._children_args_decorated = True | |||||
return f | return f | ||||
return smart_decorator(f, create_decorator) | |||||
def _children_args__class(cls): | |||||
def _call_userfunc(self, tree): | |||||
# Assumes tree is already transformed | |||||
try: | |||||
f = getattr(self, tree.data) | |||||
except AttributeError: | |||||
return self.__default__(tree) | |||||
else: | |||||
return f(tree.children) | |||||
cls._call_userfunc = _call_userfunc | |||||
return cls | |||||
return smart_decorator(func, create_decorator) | |||||
def children_args(obj): | def children_args(obj): | ||||
decorator = _children_args__class if isclass(obj) and issubclass(obj, Base) else _children_args__func | |||||
return decorator(obj) | |||||
return _apply_decorator(obj, _children_args__func) | |||||
def _children_args_inline__func(f): | |||||
@wraps(f) | |||||
def _children_args_inline__func(func): | |||||
if getattr(func, '_children_args_decorated', False): | |||||
return func | |||||
@wraps(func) | |||||
def create_decorator(_f, with_self): | def create_decorator(_f, with_self): | ||||
if with_self: | if with_self: | ||||
def f(self, tree): | def f(self, tree): | ||||
return _f(self, *tree.children) | return _f(self, *tree.children) | ||||
else: | else: | ||||
def f(args): | |||||
def f(self, tree): | |||||
print ('##', _f, tree) | |||||
return _f(*tree.children) | return _f(*tree.children) | ||||
f._children_args_decorated = True | |||||
return f | return f | ||||
return smart_decorator(f, create_decorator) | |||||
return smart_decorator(func, create_decorator) | |||||
def _children_args_inline__class(cls): | |||||
def _call_userfunc(self, tree): | |||||
# Assumes tree is already transformed | |||||
try: | |||||
f = getattr(self, tree.data) | |||||
except AttributeError: | |||||
return self.__default__(tree) | |||||
else: | |||||
return f(*tree.children) | |||||
cls._call_userfunc = _call_userfunc | |||||
return cls | |||||
def children_args_inline(obj): | def children_args_inline(obj): | ||||
decorator = _children_args_inline__class if isclass(obj) and issubclass(obj, Base) else _children_args_inline__func | |||||
return decorator(obj) | |||||
return _apply_decorator(obj, _children_args_inline__func) |
@@ -6,7 +6,7 @@ import copy | |||||
import pickle | import pickle | ||||
from lark.tree import Tree | from lark.tree import Tree | ||||
from lark.visitors import Interpreter, visit_children_decor | |||||
from lark.visitors import Transformer, Interpreter, visit_children_decor, children_args_inline, children_args | |||||
class TestTrees(TestCase): | class TestTrees(TestCase): | ||||
@@ -59,6 +59,58 @@ class TestTrees(TestCase): | |||||
self.assertEqual(Interp3().visit(t), list('BCd')) | self.assertEqual(Interp3().visit(t), list('BCd')) | ||||
def test_transformer(self): | |||||
t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) | |||||
class T(Transformer): | |||||
i = children_args_inline(int) | |||||
f = children_args_inline(float) | |||||
sub = lambda self, tree: tree.children[0] - tree.children[1] | |||||
def add(self, tree): | |||||
return sum(tree.children) | |||||
res = T().transform(t) | |||||
self.assertEqual(res, 2.9) | |||||
@children_args_inline | |||||
class T(Transformer): | |||||
i = int | |||||
f = float | |||||
sub = lambda self, a, b: a-b | |||||
def add(self, a, b): | |||||
return a + b | |||||
res = T().transform(t) | |||||
self.assertEqual(res, 2.9) | |||||
@children_args_inline | |||||
class T(Transformer): | |||||
i = int | |||||
f = float | |||||
from operator import sub, add | |||||
res = T().transform(t) | |||||
self.assertEqual(res, 2.9) | |||||
@children_args | |||||
class T(Transformer): | |||||
i = children_args_inline(int) | |||||
f = children_args_inline(float) | |||||
sub = lambda self, values: values[0] - values[1] | |||||
def add(self, values): | |||||
return sum(values) | |||||
res = T().transform(t) | |||||
self.assertEqual(res, 2.9) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||