This repo contains code to mirror other repos. It also contains the code that is getting mirrored.
Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 

295 řádků
8.5 KiB

  1. from __future__ import absolute_import
  2. import unittest
  3. from unittest import TestCase
  4. import copy
  5. import pickle
  6. import functools
  7. from lark.tree import Tree
  8. from lark.lexer import Token
  9. from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard, Transformer_InPlace, \
  10. Transformer_InPlaceRecursive, Transformer_NonRecursive, merge_transformers
  11. class TestTrees(TestCase):
  12. def setUp(self):
  13. self.tree1 = Tree('a', [Tree(x, y) for x, y in zip('bcd', 'xyz')])
  14. def test_deepcopy(self):
  15. assert self.tree1 == copy.deepcopy(self.tree1)
  16. def test_pickle(self):
  17. s = copy.deepcopy(self.tree1)
  18. data = pickle.dumps(s, protocol=pickle.HIGHEST_PROTOCOL)
  19. assert pickle.loads(data) == s
  20. def test_repr_runnable(self):
  21. assert self.tree1 == eval(repr(self.tree1))
  22. def test_iter_subtrees(self):
  23. expected = [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z'),
  24. Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')])]
  25. nodes = list(self.tree1.iter_subtrees())
  26. self.assertEqual(nodes, expected)
  27. def test_iter_subtrees_topdown(self):
  28. expected = [Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')]),
  29. Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')]
  30. nodes = list(self.tree1.iter_subtrees_topdown())
  31. self.assertEqual(nodes, expected)
  32. def test_visitor(self):
  33. class Visitor1(Visitor):
  34. def __init__(self):
  35. self.nodes=[]
  36. def __default__(self,tree):
  37. self.nodes.append(tree)
  38. class Visitor1_Recursive(Visitor_Recursive):
  39. def __init__(self):
  40. self.nodes=[]
  41. def __default__(self,tree):
  42. self.nodes.append(tree)
  43. visitor1=Visitor1()
  44. visitor1_recursive=Visitor1_Recursive()
  45. expected_top_down = [Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')]),
  46. Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')]
  47. expected_botton_up= [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z'),
  48. Tree('a', [Tree('b', 'x'), Tree('c', 'y'), Tree('d', 'z')])]
  49. visitor1.visit(self.tree1)
  50. self.assertEqual(visitor1.nodes,expected_botton_up)
  51. visitor1_recursive.visit(self.tree1)
  52. self.assertEqual(visitor1_recursive.nodes,expected_botton_up)
  53. visitor1.nodes=[]
  54. visitor1_recursive.nodes=[]
  55. visitor1.visit_topdown(self.tree1)
  56. self.assertEqual(visitor1.nodes,expected_top_down)
  57. visitor1_recursive.visit_topdown(self.tree1)
  58. self.assertEqual(visitor1_recursive.nodes,expected_top_down)
  59. def test_interp(self):
  60. t = Tree('a', [Tree('b', []), Tree('c', []), 'd'])
  61. class Interp1(Interpreter):
  62. def a(self, tree):
  63. return self.visit_children(tree) + ['e']
  64. def b(self, tree):
  65. return 'B'
  66. def c(self, tree):
  67. return 'C'
  68. self.assertEqual(Interp1().visit(t), list('BCde'))
  69. class Interp2(Interpreter):
  70. @visit_children_decor
  71. def a(self, values):
  72. return values + ['e']
  73. def b(self, tree):
  74. return 'B'
  75. def c(self, tree):
  76. return 'C'
  77. self.assertEqual(Interp2().visit(t), list('BCde'))
  78. class Interp3(Interpreter):
  79. def b(self, tree):
  80. return 'B'
  81. def c(self, tree):
  82. return 'C'
  83. self.assertEqual(Interp3().visit(t), list('BCd'))
  84. def test_transformer(self):
  85. t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])])
  86. class T(Transformer):
  87. i = v_args(inline=True)(int)
  88. f = v_args(inline=True)(float)
  89. sub = lambda self, values: values[0] - values[1]
  90. def add(self, values):
  91. return sum(values)
  92. res = T().transform(t)
  93. self.assertEqual(res, 2.9)
  94. @v_args(inline=True)
  95. class T(Transformer):
  96. i = int
  97. f = float
  98. sub = lambda self, a, b: a-b
  99. def add(self, a, b):
  100. return a + b
  101. res = T().transform(t)
  102. self.assertEqual(res, 2.9)
  103. @v_args(inline=True)
  104. class T(Transformer):
  105. i = int
  106. f = float
  107. from operator import sub, add
  108. res = T().transform(t)
  109. self.assertEqual(res, 2.9)
  110. def test_vargs(self):
  111. @v_args()
  112. class MyTransformer(Transformer):
  113. @staticmethod
  114. def integer(args):
  115. return 1 # some code here
  116. @classmethod
  117. def integer2(cls, args):
  118. return 2 # some code here
  119. hello = staticmethod(lambda args: 'hello')
  120. x = MyTransformer().transform( Tree('integer', [2]))
  121. self.assertEqual(x, 1)
  122. x = MyTransformer().transform( Tree('integer2', [2]))
  123. self.assertEqual(x, 2)
  124. x = MyTransformer().transform( Tree('hello', [2]))
  125. self.assertEqual(x, 'hello')
  126. def test_inline_static(self):
  127. @v_args(inline=True)
  128. class T(Transformer):
  129. @staticmethod
  130. def test(a, b):
  131. return a + b
  132. x = T().transform(Tree('test', ['a', 'b']))
  133. self.assertEqual(x, 'ab')
  134. def test_vargs_override(self):
  135. t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])])
  136. @v_args(inline=True)
  137. class T(Transformer):
  138. i = int
  139. f = float
  140. sub = lambda self, a, b: a-b
  141. not_a_method = {'other': 'stuff'}
  142. @v_args(inline=False)
  143. def add(self, values):
  144. return sum(values)
  145. res = T().transform(t)
  146. self.assertEqual(res, 2.9)
  147. def test_partial(self):
  148. tree = Tree("start", [Tree("a", ["test1"]), Tree("b", ["test2"])])
  149. def test(prefix, s, postfix):
  150. return prefix + s.upper() + postfix
  151. @v_args(inline=True)
  152. class T(Transformer):
  153. a = functools.partial(test, "@", postfix="!")
  154. b = functools.partial(lambda s: s + "!")
  155. res = T().transform(tree)
  156. assert res.children == ["@TEST1!", "test2!"]
  157. def test_discard(self):
  158. class MyTransformer(Transformer):
  159. def a(self, args):
  160. return 1 # some code here
  161. def b(cls, args):
  162. raise Discard()
  163. t = Tree('root', [
  164. Tree('b', []),
  165. Tree('a', []),
  166. Tree('b', []),
  167. Tree('c', []),
  168. Tree('b', []),
  169. ])
  170. t2 = Tree('root', [1, Tree('c', [])])
  171. x = MyTransformer().transform( t )
  172. self.assertEqual(x, t2)
  173. def test_transformer_variants(self):
  174. tree = Tree('start', [Tree('add', [Token('N', '1'), Token('N', '2')]), Tree('add', [Token('N', '3'), Token('N', '4')])])
  175. for base in (Transformer, Transformer_InPlace, Transformer_NonRecursive, Transformer_InPlaceRecursive):
  176. class T(base):
  177. def add(self, children):
  178. return sum(children)
  179. def N(self, token):
  180. return int(token)
  181. copied = copy.deepcopy(tree)
  182. result = T().transform(copied)
  183. self.assertEqual(result, Tree('start', [3, 7]))
  184. def test_merge_transformers(self):
  185. tree = Tree('start', [
  186. Tree('main', [
  187. Token("A", '1'), Token("B", '2')
  188. ]),
  189. Tree("module__main", [
  190. Token("A", "2"), Token("B", "3")
  191. ])
  192. ])
  193. class T1(Transformer):
  194. A = int
  195. B = int
  196. main = sum
  197. start = list
  198. def module__main(self, children):
  199. return sum(children)
  200. class T2(Transformer):
  201. A = int
  202. B = int
  203. main = sum
  204. start = list
  205. class T3(Transformer):
  206. def main(self, children):
  207. return sum(children)
  208. class T4(Transformer):
  209. main = sum
  210. t1_res = T1().transform(tree)
  211. composed_res = merge_transformers(T2(), module=T3()).transform(tree)
  212. self.assertEqual(t1_res, composed_res)
  213. composed_res2 = merge_transformers(T2(), module=T4()).transform(tree)
  214. self.assertEqual(t1_res, composed_res2)
  215. with self.assertRaises(AttributeError):
  216. merge_transformers(T1(), module=T3())
  217. if __name__ == '__main__':
  218. unittest.main()