@@ -1793,7 +1793,8 @@ def build_expression(tracer):
1793
1793
np .testing .assert_allclose (outputs [key ], expected [key ])
1794
1794
1795
1795
1796
- def test_nested_function_calls (ctx_factory ):
1796
+ @pytest .mark .parametrize ("should_concatenate_bar" , (False , True ))
1797
+ def test_nested_function_calls (ctx_factory , should_concatenate_bar ):
1797
1798
from functools import partial
1798
1799
1799
1800
ctx = ctx_factory ()
@@ -1820,6 +1821,14 @@ def call_bar(tracer, x, y):
1820
1821
"out2" : call_bar (pt .trace_call , x2 , y2 )}
1821
1822
)
1822
1823
result = pt .tag_all_calls_to_be_inlined (result )
1824
+ if should_concatenate_bar :
1825
+ from pytato .transform .calls import CallsiteCollector
1826
+ assert len (CallsiteCollector (())(result )) == 4
1827
+ result = pt .concatenate_calls (
1828
+ result ,
1829
+ lambda x : pt .tags .FunctionIdentifier ("bar" ) in x .call .function .tags )
1830
+ assert len (CallsiteCollector (())(result )) == 2
1831
+
1823
1832
expect = pt .make_dict_of_named_arrays ({"out1" : call_bar (ref_tracer , x1 , y1 ),
1824
1833
"out2" : call_bar (ref_tracer , x2 , y2 )}
1825
1834
)
@@ -1832,6 +1841,109 @@ def call_bar(tracer, x, y):
1832
1841
np .testing .assert_allclose (result_out [k ], expect_out [k ])
1833
1842
1834
1843
1844
def test_concatenate_calls_no_nested(ctx_factory):
    """Concatenate three non-nested call sites of ``foo`` and compare results.

    Builds a DAG with three traced calls to the same function on distinct
    placeholders, concatenates the call sites tagged as ``foo`` and checks
    that:

    * after inlining, the concatenated DAG has strictly fewer nodes than the
      original one, and
    * executing both DAGs on the same random inputs yields matching outputs.
    """
    queue = cl.CommandQueue(ctx_factory())
    rng = np.random.default_rng(0)

    # NOTE: the name 'foo' is matched by the FunctionIdentifier predicate
    # below, so it must not be renamed.
    def foo(x, y):
        return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y)

    arg_names = ("x1", "x2", "y1", "y2", "z1", "z2")
    ph = {name: pt.make_placeholder(name, (10, 4), np.float64)
          for name in arg_names}

    result = pt.make_dict_of_named_arrays({
        "out1": 2*pt.trace_call(foo, 2*ph["x1"], 3*ph["x2"]),
        "out2": 4*pt.trace_call(foo, 4*ph["y1"], 9*ph["y2"]),
        "out3": 6*pt.trace_call(foo, 7*ph["z1"], 8*ph["z2"]),
    })

    def is_foo_callsite(callsite):
        return pt.tags.FunctionIdentifier("foo") in callsite.call.function.tags

    concatenated_result = pt.concatenate_calls(result, is_foo_callsite)

    result = pt.tag_all_calls_to_be_inlined(result)
    concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result)

    # Concatenation should leave fewer nodes behind once calls are inlined.
    num_nodes_plain = pt.analysis.get_num_nodes(pt.inline_calls(result))
    num_nodes_concatenated = pt.analysis.get_num_nodes(
        pt.inline_calls(concatenated_result))
    assert num_nodes_plain > num_nodes_concatenated

    # One (10, 4) random array per placeholder, passed by placeholder name.
    inputs = dict(zip(arg_names, rng.random((6, 10, 4))))

    _, out_dict1 = pt.generate_loopy(result)(queue, **inputs)
    _, out_dict2 = pt.generate_loopy(concatenated_result)(queue, **inputs)

    assert out_dict1.keys() == out_dict2.keys()

    for key in out_dict1:
        np.testing.assert_allclose(out_dict1[key], out_dict2[key])
1891
+
1892
+
1893
def test_concatenation_via_constant_expressions(ctx_factory):
    """Concatenate two calls to ``resampling`` that share a constant argument.

    Both call sites index the same data-wrapped ``coords_dofs`` array with
    differently sized index arrays. The test checks that the two call sites
    collapse into one after :func:`pt.concatenate_calls`, and that both the
    original and the concatenated DAG evaluate to the NumPy reference
    ``coords_dofs_np[iels_np]``.
    """
    from pytato.transform.calls import CallsiteCollector

    rng = np.random.default_rng(0)

    ctx = ctx_factory()
    cq = cl.CommandQueue(ctx)

    # NOTE: the name 'resampling' is matched by the FunctionIdentifier
    # predicate below, so it must not be renamed.
    def resampling(coords, iels):
        return coords[iels]

    nel = 1000
    ndof = 20
    ndim = 3

    nleft_els = 17
    nright_els = 29

    coords_dofs_np = rng.random((nel, ndim, ndof), np.float64)
    left_bnd_iels_np = rng.integers(low=0, high=nel, size=nleft_els)
    right_bnd_iels_np = rng.integers(low=0, high=nel, size=nright_els)

    coords_dofs = pt.make_data_wrapper(coords_dofs_np)
    left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
    right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)

    lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
    rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)

    result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
                                           "rcoords": rcoords})
    result = pt.tag_all_calls_to_be_inlined(result)

    assert len(CallsiteCollector(())(result)) == 2
    concated_result = pt.concatenate_calls(
        result,
        lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
    )
    assert len(CallsiteCollector(())(concated_result)) == 1

    # Run the same numerical checks on BOTH DAGs. Previously the second
    # evaluation re-ran 'result', so the concatenated DAG was never executed.
    for dag in (result, concated_result):
        _, out = pt.generate_loopy(dag)(cq)
        np.testing.assert_allclose(out["lcoords"],
                                   coords_dofs_np[left_bnd_iels_np])
        np.testing.assert_allclose(out["rcoords"],
                                   coords_dofs_np[right_bnd_iels_np])
1945
+
1946
+
1835
1947
if __name__ == "__main__" :
1836
1948
if len (sys .argv ) > 1 :
1837
1949
exec (sys .argv [1 ])
0 commit comments