Skip to content

Commit cf70922

Browse files
kaushikcfdinducer
authored andcommitted
Add regressions for pt.concatenate_calls
1 parent 1b05a12 commit cf70922

File tree

1 file changed

+113
-1
lines changed

1 file changed

+113
-1
lines changed

test/test_codegen.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1902,7 +1902,8 @@ def build_expression(tracer):
19021902
np.testing.assert_allclose(outputs[key], expected[key])
19031903

19041904

1905-
def test_nested_function_calls(ctx_factory):
1905+
@pytest.mark.parametrize("should_concatenate_bar", (False, True))
1906+
def test_nested_function_calls(ctx_factory, should_concatenate_bar):
19061907
from functools import partial
19071908

19081909
ctx = ctx_factory()
@@ -1936,6 +1937,14 @@ def call_bar(tracer, x, y):
19361937
"out2": call_bar(pt.trace_call, x2, y2)}
19371938
)
19381939
result = pt.tag_all_calls_to_be_inlined(result)
1940+
if should_concatenate_bar:
1941+
from pytato.transform.calls import CallsiteCollector
1942+
assert len(CallsiteCollector(())(result)) == 4
1943+
result = pt.concatenate_calls(
1944+
result,
1945+
lambda x: pt.tags.FunctionIdentifier("bar") in x.call.function.tags)
1946+
assert len(CallsiteCollector(())(result)) == 2
1947+
19391948
expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1),
19401949
"out2": call_bar(ref_tracer, x2, y2)}
19411950
)
@@ -1948,6 +1957,109 @@ def call_bar(tracer, x, y):
19481957
np.testing.assert_allclose(result_out[k], expect_out[k])
19491958

19501959

1960+
def test_concatenate_calls_no_nested(ctx_factory):
1961+
rng = np.random.default_rng(0)
1962+
1963+
ctx = ctx_factory()
1964+
cq = cl.CommandQueue(ctx)
1965+
1966+
def foo(x, y):
1967+
return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y)
1968+
1969+
x1 = pt.make_placeholder("x1", (10, 4), np.float64)
1970+
x2 = pt.make_placeholder("x2", (10, 4), np.float64)
1971+
1972+
y1 = pt.make_placeholder("y1", (10, 4), np.float64)
1973+
y2 = pt.make_placeholder("y2", (10, 4), np.float64)
1974+
1975+
z1 = pt.make_placeholder("z1", (10, 4), np.float64)
1976+
z2 = pt.make_placeholder("z2", (10, 4), np.float64)
1977+
1978+
result = pt.make_dict_of_named_arrays({"out1": 2*pt.trace_call(foo, 2*x1, 3*x2),
1979+
"out2": 4*pt.trace_call(foo, 4*y1, 9*y2),
1980+
"out3": 6*pt.trace_call(foo, 7*z1, 8*z2)
1981+
})
1982+
1983+
concatenated_result = pt.concatenate_calls(
1984+
result, lambda x: pt.tags.FunctionIdentifier("foo") in x.call.function.tags)
1985+
1986+
result = pt.tag_all_calls_to_be_inlined(result)
1987+
concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result)
1988+
1989+
assert (pt.analysis.get_num_nodes(pt.inline_calls(result))
1990+
> pt.analysis.get_num_nodes(pt.inline_calls(concatenated_result)))
1991+
1992+
x1_np, x2_np, y1_np, y2_np, z1_np, z2_np = rng.random((6, 10, 4))
1993+
1994+
_, out_dict1 = pt.generate_loopy(result)(cq,
1995+
x1=x1_np, x2=x2_np,
1996+
y1=y1_np, y2=y2_np,
1997+
z1=z1_np, z2=z2_np)
1998+
1999+
_, out_dict2 = pt.generate_loopy(concatenated_result)(cq,
2000+
x1=x1_np, x2=x2_np,
2001+
y1=y1_np, y2=y2_np,
2002+
z1=z1_np, z2=z2_np)
2003+
assert out_dict1.keys() == out_dict2.keys()
2004+
2005+
for key in out_dict1:
2006+
np.testing.assert_allclose(out_dict1[key], out_dict2[key])
2007+
2008+
2009+
def test_concatenation_via_constant_expressions(ctx_factory):
2010+
2011+
from pytato.transform.calls import CallsiteCollector
2012+
2013+
rng = np.random.default_rng(0)
2014+
2015+
ctx = ctx_factory()
2016+
cq = cl.CommandQueue(ctx)
2017+
2018+
def resampling(coords, iels):
2019+
return coords[iels]
2020+
2021+
n_el = 1000
2022+
n_dof = 20
2023+
n_dim = 3
2024+
2025+
n_left_els = 17
2026+
n_right_els = 29
2027+
2028+
coords_dofs_np = rng.random((n_el, n_dim, n_dof), np.float64)
2029+
left_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_left_els)
2030+
right_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_right_els)
2031+
2032+
coords_dofs = pt.make_data_wrapper(coords_dofs_np)
2033+
left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
2034+
right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)
2035+
2036+
lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
2037+
rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)
2038+
2039+
result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
2040+
"rcoords": rcoords})
2041+
result = pt.tag_all_calls_to_be_inlined(result)
2042+
2043+
assert len(CallsiteCollector(())(result)) == 2
2044+
concated_result = pt.concatenate_calls(
2045+
result,
2046+
lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
2047+
)
2048+
assert len(CallsiteCollector(())(concated_result)) == 1
2049+
2050+
_, out_result = pt.generate_loopy(result)(cq)
2051+
np.testing.assert_allclose(out_result["lcoords"],
2052+
coords_dofs_np[left_bnd_iels_np])
2053+
np.testing.assert_allclose(out_result["rcoords"],
2054+
coords_dofs_np[right_bnd_iels_np])
2055+
2056+
_, out_concated_result = pt.generate_loopy(result)(cq)
2057+
np.testing.assert_allclose(out_concated_result["lcoords"],
2058+
coords_dofs_np[left_bnd_iels_np])
2059+
np.testing.assert_allclose(out_concated_result["rcoords"],
2060+
coords_dofs_np[right_bnd_iels_np])
2061+
2062+
19512063
if __name__ == "__main__":
19522064
if len(sys.argv) > 1:
19532065
exec(sys.argv[1])

0 commit comments

Comments
 (0)