Skip to content

Commit 66e173d

Browse files
committed
adds a regression test for concatenate_calls
1 parent a558381 commit 66e173d

File tree

1 file changed

+113
-1
lines changed

1 file changed

+113
-1
lines changed

test/test_codegen.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1793,7 +1793,8 @@ def build_expression(tracer):
17931793
np.testing.assert_allclose(outputs[key], expected[key])
17941794

17951795

1796-
def test_nested_function_calls(ctx_factory):
1796+
@pytest.mark.parametrize("should_concatenate_bar", (False, True))
1797+
def test_nested_function_calls(ctx_factory, should_concatenate_bar):
17971798
from functools import partial
17981799

17991800
ctx = ctx_factory()
@@ -1820,6 +1821,14 @@ def call_bar(tracer, x, y):
18201821
"out2": call_bar(pt.trace_call, x2, y2)}
18211822
)
18221823
result = pt.tag_all_calls_to_be_inlined(result)
1824+
if should_concatenate_bar:
1825+
from pytato.transform.calls import CallsiteCollector
1826+
assert len(CallsiteCollector(())(result)) == 4
1827+
result = pt.concatenate_calls(
1828+
result,
1829+
lambda x: pt.tags.FunctionIdentifier("bar") in x.call.function.tags)
1830+
assert len(CallsiteCollector(())(result)) == 2
1831+
18231832
expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1),
18241833
"out2": call_bar(ref_tracer, x2, y2)}
18251834
)
@@ -1832,6 +1841,109 @@ def call_bar(tracer, x, y):
18321841
np.testing.assert_allclose(result_out[k], expect_out[k])
18331842

18341843

1844+
def test_concatenate_calls_no_nested(ctx_factory):
    """Check ``pt.concatenate_calls`` on three sibling (non-nested) calls.

    Three traced calls to ``foo`` are placed in one dict of named arrays and
    the call sites tagged with ``FunctionIdentifier("foo")`` are concatenated.
    The test then asserts that

    * the inlined concatenated expression has strictly fewer nodes than the
      inlined original expression, and
    * both expressions produce the same outputs when run through
      ``pt.generate_loopy``.
    """
    shape = (10, 4)
    arg_names = ("x1", "x2", "y1", "y2", "z1", "z2")

    rng = np.random.default_rng(0)
    queue = cl.CommandQueue(ctx_factory())

    # NOTE: the name of this function is what the ``FunctionIdentifier("foo")``
    # predicate below matches against, so it must stay ``foo``.
    def foo(x, y):
        return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y)

    def is_foo_callsite(callsite):
        return pt.tags.FunctionIdentifier("foo") in callsite.call.function.tags

    x1, x2, y1, y2, z1, z2 = (pt.make_placeholder(name, shape, np.float64)
                              for name in arg_names)

    outputs = {
        "out1": 2*pt.trace_call(foo, 2*x1, 3*x2),
        "out2": 4*pt.trace_call(foo, 4*y1, 9*y2),
        "out3": 6*pt.trace_call(foo, 7*z1, 8*z2),
    }
    plain = pt.make_dict_of_named_arrays(outputs)

    # Concatenate first, then mark the calls of both variants for inlining.
    concatenated = pt.concatenate_calls(plain, is_foo_callsite)
    plain = pt.tag_all_calls_to_be_inlined(plain)
    concatenated = pt.tag_all_calls_to_be_inlined(concatenated)

    n_nodes_plain = pt.analysis.get_num_nodes(pt.inline_calls(plain))
    n_nodes_concatenated = pt.analysis.get_num_nodes(
        pt.inline_calls(concatenated))
    assert n_nodes_plain > n_nodes_concatenated

    # One random (10, 4) array per placeholder, in ``arg_names`` order.
    inputs = dict(zip(arg_names, rng.random((6, *shape))))

    _, plain_out = pt.generate_loopy(plain)(queue, **inputs)
    _, concatenated_out = pt.generate_loopy(concatenated)(queue, **inputs)

    assert plain_out.keys() == concatenated_out.keys()

    for name in plain_out:
        np.testing.assert_allclose(plain_out[name], concatenated_out[name])
1891+
1892+
1893+
def test_concatenation_via_constant_expressions(ctx_factory):
    """Check ``pt.concatenate_calls`` when call arguments are data wrappers.

    ``resampling`` is traced twice on the same ``coords_dofs`` data wrapper
    with two index arrays of different lengths (17 and 29).  The test asserts
    that the two call sites collapse into one after concatenation, and that
    both the original and the concatenated expressions evaluate to the
    NumPy reference ``coords_dofs_np[iels_np]``.
    """
    from pytato.transform.calls import CallsiteCollector

    rng = np.random.default_rng(0)

    ctx = ctx_factory()
    cq = cl.CommandQueue(ctx)

    # NOTE: the function name is matched by ``FunctionIdentifier("resampling")``
    # in the concatenation predicate below.
    def resampling(coords, iels):
        return coords[iels]

    Nel = 1000
    Ndof = 20
    Ndim = 3

    Nleft_els = 17
    Nright_els = 29

    coords_dofs_np = rng.random((Nel, Ndim, Ndof), np.float64)
    left_bnd_iels_np = rng.integers(low=0, high=Nel, size=Nleft_els)
    right_bnd_iels_np = rng.integers(low=0, high=Nel, size=Nright_els)

    coords_dofs = pt.make_data_wrapper(coords_dofs_np)
    left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
    right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)

    lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
    rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)

    result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
                                           "rcoords": rcoords})
    result = pt.tag_all_calls_to_be_inlined(result)

    assert len(CallsiteCollector(())(result)) == 2
    concated_result = pt.concatenate_calls(
        result,
        lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
    )
    assert len(CallsiteCollector(())(concated_result)) == 1

    # Reference run: the un-concatenated expression.
    _, out_result = pt.generate_loopy(result)(cq)
    np.testing.assert_allclose(out_result["lcoords"],
                               coords_dofs_np[left_bnd_iels_np])
    np.testing.assert_allclose(out_result["rcoords"],
                               coords_dofs_np[right_bnd_iels_np])

    # Run the *concatenated* expression.  Previously ``result`` was executed
    # a second time here, so the concatenated graph was never evaluated.
    _, out_concated_result = pt.generate_loopy(concated_result)(cq)
    np.testing.assert_allclose(out_concated_result["lcoords"],
                               coords_dofs_np[left_bnd_iels_np])
    np.testing.assert_allclose(out_concated_result["rcoords"],
                               coords_dofs_np[right_bnd_iels_np])
1945+
1946+
18351947
if __name__ == "__main__":
18361948
if len(sys.argv) > 1:
18371949
exec(sys.argv[1])

0 commit comments

Comments
 (0)