diff --git a/lark-stubs/visitors.pyi b/lark-stubs/visitors.pyi index 86f15d8..3a934ee 100644 --- a/lark-stubs/visitors.pyi +++ b/lark-stubs/visitors.pyi @@ -9,6 +9,7 @@ _R = TypeVar('_R') _FUNC = Callable[..., _T] _DECORATED = Union[_FUNC, type] + class Transformer(ABC, Generic[_T]): def __init__(self, visit_tokens: bool = True) -> None: @@ -38,6 +39,14 @@ class Transformer_InPlace(Transformer): pass +class Transformer_NonRecursive(Transformer): + pass + + +class Transformer_InPlaceRecursive(Transformer): + pass + + class VisitorBase: pass @@ -73,10 +82,10 @@ _InterMethod = Callable[[Type[Interpreter], _T], _R] def v_args( - inline: bool = False, - meta: bool = False, - tree: bool = False, - wrapper: Callable = None + inline: bool = False, + meta: bool = False, + tree: bool = False, + wrapper: Callable = None ) -> Callable[[_DECORATED], _DECORATED]: ... diff --git a/lark/visitors.py b/lark/visitors.py index 7e3bae4..44f0c99 100644 --- a/lark/visitors.py +++ b/lark/visitors.py @@ -218,6 +218,8 @@ class Transformer_NonRecursive(Transformer): else: args = [] stack.append(self._call_userfunc(x, args)) + elif self.__visit_tokens__ and isinstance(x, Token): + stack.append(self._call_userfunc_token(x)) else: stack.append(x) diff --git a/tests/test_trees.py b/tests/test_trees.py index 905ad5a..c7f9787 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -8,7 +8,8 @@ import functools from lark.tree import Tree from lark.lexer import Token -from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard +from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard, Transformer_InPlace, \ + Transformer_InPlaceRecursive, Transformer_NonRecursive class TestTrees(TestCase): @@ -232,6 +233,20 @@ class TestTrees(TestCase): x = MyTransformer().transform( t ) self.assertEqual(x, t2) + + def test_transformer_variants(self): + tree = Tree('start', [Tree('add', [Token('N', '1'), Token('N', '2')]), Tree('add', [Token('N', '3'), Token('N', '4')])]) + for base in (Transformer, Transformer_InPlace, Transformer_NonRecursive, Transformer_InPlaceRecursive): + class T(base): + def add(self, children): + return sum(children) + + def N(self, token): + return int(token) + + copied = copy.deepcopy(tree) + result = T().transform(copied) + self.assertEqual(result, Tree('start', [3, 7])) if __name__ == '__main__':