88"""
99
1010from kirin import rewrite
11- from kirin .prelude import python_basic
11+ from kirin .prelude import structural , python_basic
1212from kirin .dialects import py , scf , func , ilist , lowering
1313
1414basic_scf = python_basic .union (
@@ -27,8 +27,13 @@ def main():
2727 curr = prev + i + 1
2828 return curr
2929
30+ expected_return_val = main .py_func ()
31+ assert expected_return_val == 6
32+
3033 rewrite .Walk (scf .trim .UnusedYield ()).rewrite (main .code )
31- assert main () == main .py_func () # 6
34+ actual_return_val = main ()
35+
36+ assert actual_return_val == expected_return_val
3237
3338
3439def test_trim_prev_curr_unused_after_loop ():
@@ -60,10 +65,34 @@ def main():
6065 last_prev = prev
6166 return last_prev
6267
63- expected_return_val = main .py_func ()
64- assert expected_return_val == 3
68+ expected = main .py_func ()
69+ assert expected == 3
6570
6671 rewrite .Walk (scf .trim .UnusedYield ()).rewrite (main .code )
67- actual_return_val = main ()
72+ actual = main ()
73+
74+ assert actual == expected
75+
76+
77+ def test_trim_with_lists ():
78+
79+ @structural (fold = False , typeinfer = True )
80+ def mwe ():
81+
82+ result = 0
83+ start = ilist .IList ([0 ])
84+ stop = ilist .IList ([1 ])
85+ for _ in range (10 ):
86+ result = start [0 ] + stop [0 ]
87+ start = stop
88+ stop = ilist .IList ([result ])
89+
90+ return result
91+
92+ expected_return_val = mwe .py_func ()
93+ assert expected_return_val == 89
94+
95+ rewrite .Walk (scf .trim .UnusedYield ()).rewrite (mwe .code )
96+ actual_return_val = mwe ()
6897
6998 assert actual_return_val == expected_return_val
0 commit comments