@@ -1902,7 +1902,8 @@ def build_expression(tracer):
1902
1902
np .testing .assert_allclose (outputs [key ], expected [key ])
1903
1903
1904
1904
1905
- def test_nested_function_calls (ctx_factory ):
1905
+ @pytest .mark .parametrize ("should_concatenate_bar" , (False , True ))
1906
+ def test_nested_function_calls (ctx_factory , should_concatenate_bar ):
1906
1907
from functools import partial
1907
1908
1908
1909
ctx = ctx_factory ()
@@ -1936,6 +1937,14 @@ def call_bar(tracer, x, y):
1936
1937
"out2" : call_bar (pt .trace_call , x2 , y2 )}
1937
1938
)
1938
1939
result = pt .tag_all_calls_to_be_inlined (result )
1940
+ if should_concatenate_bar :
1941
+ from pytato .transform .calls import CallsiteCollector
1942
+ assert len (CallsiteCollector (())(result )) == 4
1943
+ result = pt .concatenate_calls (
1944
+ result ,
1945
+ lambda x : pt .tags .FunctionIdentifier ("bar" ) in x .call .function .tags )
1946
+ assert len (CallsiteCollector (())(result )) == 2
1947
+
1939
1948
expect = pt .make_dict_of_named_arrays ({"out1" : call_bar (ref_tracer , x1 , y1 ),
1940
1949
"out2" : call_bar (ref_tracer , x2 , y2 )}
1941
1950
)
@@ -1948,6 +1957,109 @@ def call_bar(tracer, x, y):
1948
1957
np .testing .assert_allclose (result_out [k ], expect_out [k ])
1949
1958
1950
1959
1960
+ def test_concatenate_calls_no_nested (ctx_factory ):
1961
+ rng = np .random .default_rng (0 )
1962
+
1963
+ ctx = ctx_factory ()
1964
+ cq = cl .CommandQueue (ctx )
1965
+
1966
+ def foo (x , y ):
1967
+ return 3 * x + 4 * y + 42 * pt .sin (x ) + 1729 * pt .tan (y )* pt .maximum (x , y )
1968
+
1969
+ x1 = pt .make_placeholder ("x1" , (10 , 4 ), np .float64 )
1970
+ x2 = pt .make_placeholder ("x2" , (10 , 4 ), np .float64 )
1971
+
1972
+ y1 = pt .make_placeholder ("y1" , (10 , 4 ), np .float64 )
1973
+ y2 = pt .make_placeholder ("y2" , (10 , 4 ), np .float64 )
1974
+
1975
+ z1 = pt .make_placeholder ("z1" , (10 , 4 ), np .float64 )
1976
+ z2 = pt .make_placeholder ("z2" , (10 , 4 ), np .float64 )
1977
+
1978
+ result = pt .make_dict_of_named_arrays ({"out1" : 2 * pt .trace_call (foo , 2 * x1 , 3 * x2 ),
1979
+ "out2" : 4 * pt .trace_call (foo , 4 * y1 , 9 * y2 ),
1980
+ "out3" : 6 * pt .trace_call (foo , 7 * z1 , 8 * z2 )
1981
+ })
1982
+
1983
+ concatenated_result = pt .concatenate_calls (
1984
+ result , lambda x : pt .tags .FunctionIdentifier ("foo" ) in x .call .function .tags )
1985
+
1986
+ result = pt .tag_all_calls_to_be_inlined (result )
1987
+ concatenated_result = pt .tag_all_calls_to_be_inlined (concatenated_result )
1988
+
1989
+ assert (pt .analysis .get_num_nodes (pt .inline_calls (result ))
1990
+ > pt .analysis .get_num_nodes (pt .inline_calls (concatenated_result )))
1991
+
1992
+ x1_np , x2_np , y1_np , y2_np , z1_np , z2_np = rng .random ((6 , 10 , 4 ))
1993
+
1994
+ _ , out_dict1 = pt .generate_loopy (result )(cq ,
1995
+ x1 = x1_np , x2 = x2_np ,
1996
+ y1 = y1_np , y2 = y2_np ,
1997
+ z1 = z1_np , z2 = z2_np )
1998
+
1999
+ _ , out_dict2 = pt .generate_loopy (concatenated_result )(cq ,
2000
+ x1 = x1_np , x2 = x2_np ,
2001
+ y1 = y1_np , y2 = y2_np ,
2002
+ z1 = z1_np , z2 = z2_np )
2003
+ assert out_dict1 .keys () == out_dict2 .keys ()
2004
+
2005
+ for key in out_dict1 :
2006
+ np .testing .assert_allclose (out_dict1 [key ], out_dict2 [key ])
2007
+
2008
+
2009
+ def test_concatenation_via_constant_expressions (ctx_factory ):
2010
+
2011
+ from pytato .transform .calls import CallsiteCollector
2012
+
2013
+ rng = np .random .default_rng (0 )
2014
+
2015
+ ctx = ctx_factory ()
2016
+ cq = cl .CommandQueue (ctx )
2017
+
2018
+ def resampling (coords , iels ):
2019
+ return coords [iels ]
2020
+
2021
+ n_el = 1000
2022
+ n_dof = 20
2023
+ n_dim = 3
2024
+
2025
+ n_left_els = 17
2026
+ n_right_els = 29
2027
+
2028
+ coords_dofs_np = rng .random ((n_el , n_dim , n_dof ), np .float64 )
2029
+ left_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_left_els )
2030
+ right_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_right_els )
2031
+
2032
+ coords_dofs = pt .make_data_wrapper (coords_dofs_np )
2033
+ left_bnd_iels = pt .make_data_wrapper (left_bnd_iels_np )
2034
+ right_bnd_iels = pt .make_data_wrapper (right_bnd_iels_np )
2035
+
2036
+ lcoords = pt .trace_call (resampling , coords_dofs , left_bnd_iels )
2037
+ rcoords = pt .trace_call (resampling , coords_dofs , right_bnd_iels )
2038
+
2039
+ result = pt .make_dict_of_named_arrays ({"lcoords" : lcoords ,
2040
+ "rcoords" : rcoords })
2041
+ result = pt .tag_all_calls_to_be_inlined (result )
2042
+
2043
+ assert len (CallsiteCollector (())(result )) == 2
2044
+ concated_result = pt .concatenate_calls (
2045
+ result ,
2046
+ lambda cs : pt .tags .FunctionIdentifier ("resampling" ) in cs .call .function .tags
2047
+ )
2048
+ assert len (CallsiteCollector (())(concated_result )) == 1
2049
+
2050
+ _ , out_result = pt .generate_loopy (result )(cq )
2051
+ np .testing .assert_allclose (out_result ["lcoords" ],
2052
+ coords_dofs_np [left_bnd_iels_np ])
2053
+ np .testing .assert_allclose (out_result ["rcoords" ],
2054
+ coords_dofs_np [right_bnd_iels_np ])
2055
+
2056
+ _ , out_concated_result = pt .generate_loopy (result )(cq )
2057
+ np .testing .assert_allclose (out_concated_result ["lcoords" ],
2058
+ coords_dofs_np [left_bnd_iels_np ])
2059
+ np .testing .assert_allclose (out_concated_result ["rcoords" ],
2060
+ coords_dofs_np [right_bnd_iels_np ])
2061
+
2062
+
1951
2063
if __name__ == "__main__" :
1952
2064
if len (sys .argv ) > 1 :
1953
2065
exec (sys .argv [1 ])
0 commit comments