Skip to content

Commit 9ae9474

Browse files
committed
test push indirections
1 parent f02ba1c commit 9ae9474

File tree

1 file changed

+142
-1
lines changed

1 file changed

+142
-1
lines changed

test/test_codegen.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa
4545

4646
import pytato as pt
47-
from testlib import assert_allclose_to_numpy, get_random_pt_dag
47+
from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref
4848
import pymbolic.primitives as p
4949

5050

@@ -2002,6 +2002,147 @@ def call_bar(tracer, x, y):
20022002
np.testing.assert_allclose(result_out[k], expect_out[k])
20032003

20042004

2005+
def _evaluator_for_indirection_folding(cl_ctx, dictofarys):
2006+
from immutables import Map
2007+
cq = cl.CommandQueue(cl_ctx)
2008+
_, out_dict = pt.generate_loopy(dictofarys)(cq)
2009+
return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()})
2010+
2011+
2012+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
2013+
def test_push_indirections_0(ctx_factory, fold_constant_idxs):
2014+
from testlib import (are_all_indexees_materialized_nodes,
2015+
are_all_indexer_arrays_datawrappers)
2016+
2017+
cl_ctx = cl.create_some_context()
2018+
rng = np.random.default_rng(0)
2019+
x_np = rng.random((10, 4))
2020+
map1_np = rng.integers(0, 10, size=17)
2021+
map2_np = rng.integers(0, 17, size=29)
2022+
2023+
x = pt.make_data_wrapper(x_np)
2024+
map1 = pt.make_data_wrapper(map1_np)
2025+
map2 = pt.make_data_wrapper(map2_np)
2026+
2027+
y = 3.14 * ((42*((2*x)[map1]))[map2])
2028+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
2029+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
2030+
)
2031+
2032+
if fold_constant_idxs:
2033+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
2034+
y_transformed = pt.fold_constant_indirections(
2035+
y_transformed,
2036+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
2037+
doa)
2038+
)
2039+
assert are_all_indexer_arrays_datawrappers(y_transformed)
2040+
2041+
auto_test_vs_ref(cl_ctx, y, y_transformed)
2042+
assert are_all_indexees_materialized_nodes(y_transformed)
2043+
2044+
2045+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
2046+
def test_push_indirections_1(ctx_factory, fold_constant_idxs):
2047+
from testlib import (are_all_indexees_materialized_nodes,
2048+
are_all_indexer_arrays_datawrappers)
2049+
2050+
cl_ctx = cl.create_some_context()
2051+
rng = np.random.default_rng(0)
2052+
x_np = rng.random((100, 4))
2053+
map1_np = rng.integers(0, 20, size=17)
2054+
2055+
x = pt.make_data_wrapper(x_np)
2056+
map1 = pt.make_data_wrapper(map1_np)
2057+
2058+
y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1])
2059+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
2060+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
2061+
)
2062+
2063+
if fold_constant_idxs:
2064+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
2065+
y_transformed = pt.fold_constant_indirections(
2066+
y_transformed,
2067+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
2068+
doa)
2069+
)
2070+
assert are_all_indexer_arrays_datawrappers(y_transformed)
2071+
2072+
auto_test_vs_ref(cl_ctx, y, y_transformed)
2073+
assert are_all_indexees_materialized_nodes(y_transformed)
2074+
2075+
2076+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
2077+
def test_push_indirections_2(ctx_factory, fold_constant_idxs):
2078+
from testlib import (are_all_indexees_materialized_nodes,
2079+
are_all_indexer_arrays_datawrappers)
2080+
2081+
cl_ctx = cl.create_some_context()
2082+
rng = np.random.default_rng(0)
2083+
x_np = rng.random((100, 10))
2084+
map1_np = rng.integers(0, 20, size=17)
2085+
map2_np = rng.integers(0, 4, size=29)
2086+
2087+
x = pt.make_data_wrapper(x_np)
2088+
map1 = pt.make_data_wrapper(map1_np)
2089+
map2 = pt.make_data_wrapper(map2_np)
2090+
2091+
y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7]
2092+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
2093+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
2094+
)
2095+
2096+
if fold_constant_idxs:
2097+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
2098+
y_transformed = pt.fold_constant_indirections(
2099+
y_transformed,
2100+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
2101+
doa)
2102+
)
2103+
assert are_all_indexer_arrays_datawrappers(y_transformed)
2104+
2105+
auto_test_vs_ref(cl_ctx, y, y_transformed)
2106+
assert are_all_indexees_materialized_nodes(y_transformed)
2107+
2108+
2109+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
2110+
def test_push_indirections_3(ctx_factory, fold_constant_idxs):
2111+
from testlib import (are_all_indexees_materialized_nodes,
2112+
are_all_indexer_arrays_datawrappers)
2113+
2114+
cl_ctx = cl.create_some_context()
2115+
rng = np.random.default_rng(0)
2116+
x_np = rng.random((10, 4))
2117+
map1_np = rng.integers(0, 10, size=17)
2118+
map2_np = rng.integers(0, 17, size=29)
2119+
map3_np = rng.integers(0, 4, size=60)
2120+
map4_np = rng.integers(0, 60, size=22)
2121+
2122+
x = pt.make_data_wrapper(x_np)
2123+
map1 = pt.make_data_wrapper(map1_np)
2124+
map2 = pt.make_data_wrapper(map2_np)
2125+
map3 = pt.make_data_wrapper(map3_np)
2126+
map4 = pt.make_data_wrapper(map4_np)
2127+
2128+
y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4])
2129+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
2130+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
2131+
)
2132+
2133+
if fold_constant_idxs:
2134+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
2135+
y_transformed = pt.fold_constant_indirections(
2136+
y_transformed,
2137+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
2138+
doa)
2139+
)
2140+
assert are_all_indexer_arrays_datawrappers(y_transformed)
2141+
2142+
auto_test_vs_ref(cl_ctx, y, y_transformed)
2143+
assert are_all_indexees_materialized_nodes(y_transformed)
2144+
2145+
20052146
if __name__ == "__main__":
20062147
if len(sys.argv) > 1:
20072148
exec(sys.argv[1])

0 commit comments

Comments
 (0)