@@ -1793,7 +1793,8 @@ def build_expression(tracer):
17931793 np .testing .assert_allclose (outputs [key ], expected [key ])
17941794
17951795
1796- def test_nested_function_calls (ctx_factory ):
1796+ @pytest .mark .parametrize ("should_concatenate_bar" , (False , True ))
1797+ def test_nested_function_calls (ctx_factory , should_concatenate_bar ):
17971798 from functools import partial
17981799
17991800 ctx = ctx_factory ()
@@ -1820,6 +1821,14 @@ def call_bar(tracer, x, y):
18201821 "out2" : call_bar (pt .trace_call , x2 , y2 )}
18211822 )
18221823 result = pt .tag_all_calls_to_be_inlined (result )
1824+ if should_concatenate_bar :
1825+ from pytato .transform .calls import CallsiteCollector
1826+ assert len (CallsiteCollector (())(result )) == 4
1827+ result = pt .concatenate_calls (
1828+ result ,
1829+ lambda x : pt .tags .FunctionIdentifier ("bar" ) in x .call .function .tags )
1830+ assert len (CallsiteCollector (())(result )) == 2
1831+
18231832 expect = pt .make_dict_of_named_arrays ({"out1" : call_bar (ref_tracer , x1 , y1 ),
18241833 "out2" : call_bar (ref_tracer , x2 , y2 )}
18251834 )
@@ -1832,6 +1841,109 @@ def call_bar(tracer, x, y):
18321841 np .testing .assert_allclose (result_out [k ], expect_out [k ])
18331842
18341843
1844+ def test_concatenate_calls_no_nested (ctx_factory ):
1845+ rng = np .random .default_rng (0 )
1846+
1847+ ctx = ctx_factory ()
1848+ cq = cl .CommandQueue (ctx )
1849+
1850+ def foo (x , y ):
1851+ return 3 * x + 4 * y + 42 * pt .sin (x ) + 1729 * pt .tan (y )* pt .maximum (x , y )
1852+
1853+ x1 = pt .make_placeholder ("x1" , (10 , 4 ), np .float64 )
1854+ x2 = pt .make_placeholder ("x2" , (10 , 4 ), np .float64 )
1855+
1856+ y1 = pt .make_placeholder ("y1" , (10 , 4 ), np .float64 )
1857+ y2 = pt .make_placeholder ("y2" , (10 , 4 ), np .float64 )
1858+
1859+ z1 = pt .make_placeholder ("z1" , (10 , 4 ), np .float64 )
1860+ z2 = pt .make_placeholder ("z2" , (10 , 4 ), np .float64 )
1861+
1862+ result = pt .make_dict_of_named_arrays ({"out1" : 2 * pt .trace_call (foo , 2 * x1 , 3 * x2 ),
1863+ "out2" : 4 * pt .trace_call (foo , 4 * y1 , 9 * y2 ),
1864+ "out3" : 6 * pt .trace_call (foo , 7 * z1 , 8 * z2 )
1865+ })
1866+
1867+ concatenated_result = pt .concatenate_calls (
1868+ result , lambda x : pt .tags .FunctionIdentifier ("foo" ) in x .call .function .tags )
1869+
1870+ result = pt .tag_all_calls_to_be_inlined (result )
1871+ concatenated_result = pt .tag_all_calls_to_be_inlined (concatenated_result )
1872+
1873+ assert (pt .analysis .get_num_nodes (pt .inline_calls (result ))
1874+ > pt .analysis .get_num_nodes (pt .inline_calls (concatenated_result )))
1875+
1876+ x1_np , x2_np , y1_np , y2_np , z1_np , z2_np = rng .random ((6 , 10 , 4 ))
1877+
1878+ _ , out_dict1 = pt .generate_loopy (result )(cq ,
1879+ x1 = x1_np , x2 = x2_np ,
1880+ y1 = y1_np , y2 = y2_np ,
1881+ z1 = z1_np , z2 = z2_np )
1882+
1883+ _ , out_dict2 = pt .generate_loopy (concatenated_result )(cq ,
1884+ x1 = x1_np , x2 = x2_np ,
1885+ y1 = y1_np , y2 = y2_np ,
1886+ z1 = z1_np , z2 = z2_np )
1887+ assert out_dict1 .keys () == out_dict2 .keys ()
1888+
1889+ for key in out_dict1 :
1890+ np .testing .assert_allclose (out_dict1 [key ], out_dict2 [key ])
1891+
1892+
1893+ def test_concatenation_via_constant_expressions (ctx_factory ):
1894+
1895+ from pytato .transform .calls import CallsiteCollector
1896+
1897+ rng = np .random .default_rng (0 )
1898+
1899+ ctx = ctx_factory ()
1900+ cq = cl .CommandQueue (ctx )
1901+
1902+ def resampling (coords , iels ):
1903+ return coords [iels ]
1904+
1905+ n_el = 1000
1906+ n_dof = 20
1907+ n_dim = 3
1908+
1909+ n_left_els = 17
1910+ n_right_els = 29
1911+
1912+ coords_dofs_np = rng .random ((n_el , n_dim , n_dof ), np .float64 )
1913+ left_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_left_els )
1914+ right_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_right_els )
1915+
1916+ coords_dofs = pt .make_data_wrapper (coords_dofs_np )
1917+ left_bnd_iels = pt .make_data_wrapper (left_bnd_iels_np )
1918+ right_bnd_iels = pt .make_data_wrapper (right_bnd_iels_np )
1919+
1920+ lcoords = pt .trace_call (resampling , coords_dofs , left_bnd_iels )
1921+ rcoords = pt .trace_call (resampling , coords_dofs , right_bnd_iels )
1922+
1923+ result = pt .make_dict_of_named_arrays ({"lcoords" : lcoords ,
1924+ "rcoords" : rcoords })
1925+ result = pt .tag_all_calls_to_be_inlined (result )
1926+
1927+ assert len (CallsiteCollector (())(result )) == 2
1928+ concated_result = pt .concatenate_calls (
1929+ result ,
1930+ lambda cs : pt .tags .FunctionIdentifier ("resampling" ) in cs .call .function .tags
1931+ )
1932+ assert len (CallsiteCollector (())(concated_result )) == 1
1933+
1934+ _ , out_result = pt .generate_loopy (result )(cq )
1935+ np .testing .assert_allclose (out_result ["lcoords" ],
1936+ coords_dofs_np [left_bnd_iels_np ])
1937+ np .testing .assert_allclose (out_result ["rcoords" ],
1938+ coords_dofs_np [right_bnd_iels_np ])
1939+
1940+ _ , out_concated_result = pt .generate_loopy (result )(cq )
1941+ np .testing .assert_allclose (out_concated_result ["lcoords" ],
1942+ coords_dofs_np [left_bnd_iels_np ])
1943+ np .testing .assert_allclose (out_concated_result ["rcoords" ],
1944+ coords_dofs_np [right_bnd_iels_np ])
1945+
1946+
18351947if __name__ == "__main__" :
18361948 if len (sys .argv ) > 1 :
18371949 exec (sys .argv [1 ])
0 commit comments