Skip to content

Commit ae67637

Browse files
committed
Add regressions pt.inline_calls
1 parent 3a05f3f commit ae67637

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

test/test_codegen.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,6 +1752,86 @@ def test_two_rolls(ctx_factory):
17521752
np.testing.assert_allclose(np_out, pt_out)
17531753

17541754

1755+
def test_function_call(ctx_factory):
1756+
cl_ctx = ctx_factory()
1757+
cq = cl.CommandQueue(cl_ctx)
1758+
1759+
def f(x):
1760+
return 2*x
1761+
1762+
def g(x):
1763+
return 2*x, 3*x
1764+
1765+
def h(x, y):
1766+
return {"twice": 2*x+y, "thrice": 3*x+y}
1767+
1768+
def build_expression(tracer):
1769+
x = pt.arange(500, dtype=np.float32)
1770+
twice_x = tracer(f, x)
1771+
twice_x_2, thrice_x_2 = tracer(g, x)
1772+
1773+
result = tracer(h, x, 2*x)
1774+
twice_x_3 = result["twice"]
1775+
thrice_x_3 = result["thrice"]
1776+
1777+
return {"foo": 3.14 + twice_x_3,
1778+
"bar": 4 * thrice_x_3,
1779+
"baz": 65 * twice_x,
1780+
"quux": 7 * twice_x_2}
1781+
1782+
result1 = pt.tag_all_calls_to_be_inlined(
1783+
pt.make_dict_of_named_arrays(build_expression(pt.trace_call)))
1784+
result2 = pt.make_dict_of_named_arrays(
1785+
build_expression(lambda fn, *args: fn(*args)))
1786+
1787+
_, outputs = pt.generate_loopy(result1)(cq, out_host=True)
1788+
_, expected = pt.generate_loopy(result2)(cq, out_host=True)
1789+
1790+
assert len(outputs) == len(expected)
1791+
1792+
for key in outputs.keys():
1793+
np.testing.assert_allclose(outputs[key], expected[key])
1794+
1795+
1796+
def test_nested_function_calls(ctx_factory):
1797+
from functools import partial
1798+
1799+
ctx = ctx_factory()
1800+
cq = cl.CommandQueue(ctx)
1801+
1802+
rng = np.random.default_rng(0)
1803+
ref_tracer = lambda f, *args, identifier: f(*args) # noqa: E731
1804+
1805+
def foo(tracer, x, y):
1806+
return 2*x + 3*y
1807+
1808+
def bar(tracer, x, y):
1809+
foo_x_y = tracer(partial(foo, tracer), x, y, identifier="foo")
1810+
return foo_x_y * x * y
1811+
1812+
def call_bar(tracer, x, y):
1813+
return tracer(partial(bar, tracer), x, y, identifier="bar")
1814+
1815+
x1_np, y1_np = rng.random((2, 13, 29))
1816+
x2_np, y2_np = rng.random((2, 4, 29))
1817+
x1, y1 = pt.make_data_wrapper(x1_np), pt.make_data_wrapper(y1_np)
1818+
x2, y2 = pt.make_data_wrapper(x2_np), pt.make_data_wrapper(y2_np)
1819+
result = pt.make_dict_of_named_arrays({"out1": call_bar(pt.trace_call, x1, y1),
1820+
"out2": call_bar(pt.trace_call, x2, y2)}
1821+
)
1822+
result = pt.tag_all_calls_to_be_inlined(result)
1823+
expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1),
1824+
"out2": call_bar(ref_tracer, x2, y2)}
1825+
)
1826+
1827+
_, result_out = pt.generate_loopy(result)(cq)
1828+
_, expect_out = pt.generate_loopy(expect)(cq)
1829+
1830+
assert result_out.keys() == expect_out.keys()
1831+
for k in expect_out:
1832+
np.testing.assert_allclose(result_out[k], expect_out[k])
1833+
1834+
17551835
if __name__ == "__main__":
17561836
if len(sys.argv) > 1:
17571837
exec(sys.argv[1])

0 commit comments

Comments
 (0)