Skip to content

Commit 018be91

Browse files
committed
test push indirections
1 parent 8d36c1f commit 018be91

File tree

1 file changed

+142
-1
lines changed

1 file changed

+142
-1
lines changed

test/test_codegen.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa
4545

4646
import pytato as pt
47-
from testlib import assert_allclose_to_numpy, get_random_pt_dag
47+
from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref
4848
import pymbolic.primitives as p
4949

5050

@@ -1921,6 +1921,147 @@ def test_pad(ctx_factory):
19211921
np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array)
19221922

19231923

1924+
def _evaluator_for_indirection_folding(cl_ctx, dictofarys):
1925+
from immutables import Map
1926+
cq = cl.CommandQueue(cl_ctx)
1927+
_, out_dict = pt.generate_loopy(dictofarys)(cq)
1928+
return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()})
1929+
1930+
1931+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
1932+
def test_push_indirections_0(ctx_factory, fold_constant_idxs):
1933+
from testlib import (are_all_indexees_materialized_nodes,
1934+
are_all_indexer_arrays_datawrappers)
1935+
1936+
cl_ctx = cl.create_some_context()
1937+
rng = np.random.default_rng(0)
1938+
x_np = rng.random((10, 4))
1939+
map1_np = rng.integers(0, 10, size=17)
1940+
map2_np = rng.integers(0, 17, size=29)
1941+
1942+
x = pt.make_data_wrapper(x_np)
1943+
map1 = pt.make_data_wrapper(map1_np)
1944+
map2 = pt.make_data_wrapper(map2_np)
1945+
1946+
y = 3.14 * ((42*((2*x)[map1]))[map2])
1947+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
1948+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
1949+
)
1950+
1951+
if fold_constant_idxs:
1952+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
1953+
y_transformed = pt.fold_constant_indirections(
1954+
y_transformed,
1955+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
1956+
doa)
1957+
)
1958+
assert are_all_indexer_arrays_datawrappers(y_transformed)
1959+
1960+
auto_test_vs_ref(cl_ctx, y, y_transformed)
1961+
assert are_all_indexees_materialized_nodes(y_transformed)
1962+
1963+
1964+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
1965+
def test_push_indirections_1(ctx_factory, fold_constant_idxs):
1966+
from testlib import (are_all_indexees_materialized_nodes,
1967+
are_all_indexer_arrays_datawrappers)
1968+
1969+
cl_ctx = cl.create_some_context()
1970+
rng = np.random.default_rng(0)
1971+
x_np = rng.random((100, 4))
1972+
map1_np = rng.integers(0, 20, size=17)
1973+
1974+
x = pt.make_data_wrapper(x_np)
1975+
map1 = pt.make_data_wrapper(map1_np)
1976+
1977+
y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1])
1978+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
1979+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
1980+
)
1981+
1982+
if fold_constant_idxs:
1983+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
1984+
y_transformed = pt.fold_constant_indirections(
1985+
y_transformed,
1986+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
1987+
doa)
1988+
)
1989+
assert are_all_indexer_arrays_datawrappers(y_transformed)
1990+
1991+
auto_test_vs_ref(cl_ctx, y, y_transformed)
1992+
assert are_all_indexees_materialized_nodes(y_transformed)
1993+
1994+
1995+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
1996+
def test_push_indirections_2(ctx_factory, fold_constant_idxs):
1997+
from testlib import (are_all_indexees_materialized_nodes,
1998+
are_all_indexer_arrays_datawrappers)
1999+
2000+
cl_ctx = cl.create_some_context()
2001+
rng = np.random.default_rng(0)
2002+
x_np = rng.random((100, 10))
2003+
map1_np = rng.integers(0, 20, size=17)
2004+
map2_np = rng.integers(0, 4, size=29)
2005+
2006+
x = pt.make_data_wrapper(x_np)
2007+
map1 = pt.make_data_wrapper(map1_np)
2008+
map2 = pt.make_data_wrapper(map2_np)
2009+
2010+
y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7]
2011+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
2012+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
2013+
)
2014+
2015+
if fold_constant_idxs:
2016+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
2017+
y_transformed = pt.fold_constant_indirections(
2018+
y_transformed,
2019+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
2020+
doa)
2021+
)
2022+
assert are_all_indexer_arrays_datawrappers(y_transformed)
2023+
2024+
auto_test_vs_ref(cl_ctx, y, y_transformed)
2025+
assert are_all_indexees_materialized_nodes(y_transformed)
2026+
2027+
2028+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
2029+
def test_push_indirections_3(ctx_factory, fold_constant_idxs):
2030+
from testlib import (are_all_indexees_materialized_nodes,
2031+
are_all_indexer_arrays_datawrappers)
2032+
2033+
cl_ctx = cl.create_some_context()
2034+
rng = np.random.default_rng(0)
2035+
x_np = rng.random((10, 4))
2036+
map1_np = rng.integers(0, 10, size=17)
2037+
map2_np = rng.integers(0, 17, size=29)
2038+
map3_np = rng.integers(0, 4, size=60)
2039+
map4_np = rng.integers(0, 60, size=22)
2040+
2041+
x = pt.make_data_wrapper(x_np)
2042+
map1 = pt.make_data_wrapper(map1_np)
2043+
map2 = pt.make_data_wrapper(map2_np)
2044+
map3 = pt.make_data_wrapper(map3_np)
2045+
map4 = pt.make_data_wrapper(map4_np)
2046+
2047+
y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4])
2048+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
2049+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
2050+
)
2051+
2052+
if fold_constant_idxs:
2053+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
2054+
y_transformed = pt.fold_constant_indirections(
2055+
y_transformed,
2056+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
2057+
doa)
2058+
)
2059+
assert are_all_indexer_arrays_datawrappers(y_transformed)
2060+
2061+
auto_test_vs_ref(cl_ctx, y, y_transformed)
2062+
assert are_all_indexees_materialized_nodes(y_transformed)
2063+
2064+
19242065
if __name__ == "__main__":
19252066
if len(sys.argv) > 1:
19262067
exec(sys.argv[1])

0 commit comments

Comments
 (0)