diff --git a/tests/test_trees.py b/tests/test_trees.py index 730b80b..82bf6c9 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -264,10 +264,7 @@ class TestTrees(TestCase): main = sum start = list def module__main(self, children): - prod = 1 - for child in children: - prod *= child - return prod + return sum(children) class T2(Transformer): A = int @@ -277,25 +274,21 @@ class TestTrees(TestCase): class T3(Transformer): def main(self, children): - prod = 1 - for child in children: - prod *= child - return prod + return sum(children) class T4(Transformer): - def other_aspect(self, children): - pass + main = sum + t1_res = T1().transform(tree) composed_res = merge_transformers(T2(), module=T3()).transform(tree) self.assertEqual(t1_res, composed_res) + + composed_res2 = merge_transformers(T2(), module=T4()).transform(tree) + self.assertEqual(t1_res, composed_res2) + with self.assertRaises(AttributeError): merge_transformers(T1(), module=T3()) - try: - composed = merge_transformers(T1(), module=T4()) - except AttributeError: - self.fail("Should be able to add classes that do not conflict") - if __name__ == '__main__': unittest.main()