@@ -1752,6 +1752,86 @@ def test_two_rolls(ctx_factory):
17521752 np .testing .assert_allclose (np_out , pt_out )
17531753
17541754
1755+ def test_function_call (ctx_factory ):
1756+ cl_ctx = ctx_factory ()
1757+ cq = cl .CommandQueue (cl_ctx )
1758+
1759+ def f (x ):
1760+ return 2 * x
1761+
1762+ def g (x ):
1763+ return 2 * x , 3 * x
1764+
1765+ def h (x , y ):
1766+ return {"twice" : 2 * x + y , "thrice" : 3 * x + y }
1767+
1768+ def build_expression (tracer ):
1769+ x = pt .arange (500 , dtype = np .float32 )
1770+ twice_x = tracer (f , x )
1771+ twice_x_2 , thrice_x_2 = tracer (g , x )
1772+
1773+ result = tracer (h , x , 2 * x )
1774+ twice_x_3 = result ["twice" ]
1775+ thrice_x_3 = result ["thrice" ]
1776+
1777+ return {"foo" : 3.14 + twice_x_3 ,
1778+ "bar" : 4 * thrice_x_3 ,
1779+ "baz" : 65 * twice_x ,
1780+ "quux" : 7 * twice_x_2 }
1781+
1782+ result1 = pt .tag_all_calls_to_be_inlined (
1783+ pt .make_dict_of_named_arrays (build_expression (pt .trace_call )))
1784+ result2 = pt .make_dict_of_named_arrays (
1785+ build_expression (lambda fn , * args : fn (* args )))
1786+
1787+ _ , outputs = pt .generate_loopy (result1 )(cq , out_host = True )
1788+ _ , expected = pt .generate_loopy (result2 )(cq , out_host = True )
1789+
1790+ assert len (outputs ) == len (expected )
1791+
1792+ for key in outputs .keys ():
1793+ np .testing .assert_allclose (outputs [key ], expected [key ])
1794+
1795+
1796+ def test_nested_function_calls (ctx_factory ):
1797+ from functools import partial
1798+
1799+ ctx = ctx_factory ()
1800+ cq = cl .CommandQueue (ctx )
1801+
1802+ rng = np .random .default_rng (0 )
1803+ ref_tracer = lambda f , * args , identifier : f (* args ) # noqa: E731
1804+
1805+ def foo (tracer , x , y ):
1806+ return 2 * x + 3 * y
1807+
1808+ def bar (tracer , x , y ):
1809+ foo_x_y = tracer (partial (foo , tracer ), x , y , identifier = "foo" )
1810+ return foo_x_y * x * y
1811+
1812+ def call_bar (tracer , x , y ):
1813+ return tracer (partial (bar , tracer ), x , y , identifier = "bar" )
1814+
1815+ x1_np , y1_np = rng .random ((2 , 13 , 29 ))
1816+ x2_np , y2_np = rng .random ((2 , 4 , 29 ))
1817+ x1 , y1 = pt .make_data_wrapper (x1_np ), pt .make_data_wrapper (y1_np )
1818+ x2 , y2 = pt .make_data_wrapper (x2_np ), pt .make_data_wrapper (y2_np )
1819+ result = pt .make_dict_of_named_arrays ({"out1" : call_bar (pt .trace_call , x1 , y1 ),
1820+ "out2" : call_bar (pt .trace_call , x2 , y2 )}
1821+ )
1822+ result = pt .tag_all_calls_to_be_inlined (result )
1823+ expect = pt .make_dict_of_named_arrays ({"out1" : call_bar (ref_tracer , x1 , y1 ),
1824+ "out2" : call_bar (ref_tracer , x2 , y2 )}
1825+ )
1826+
1827+ _ , result_out = pt .generate_loopy (result )(cq )
1828+ _ , expect_out = pt .generate_loopy (expect )(cq )
1829+
1830+ assert result_out .keys () == expect_out .keys ()
1831+ for k in expect_out :
1832+ np .testing .assert_allclose (result_out [k ], expect_out [k ])
1833+
1834+
17551835if __name__ == "__main__" :
17561836 if len (sys .argv ) > 1 :
17571837 exec (sys .argv [1 ])
0 commit comments