Skip to content

Commit c330b7e

Browse files
kaushikcfdinducer
authored andcommitted
test push indirections
1 parent 6b075f6 commit c330b7e

File tree

1 file changed

+142
-1
lines changed

1 file changed

+142
-1
lines changed

test/test_codegen.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa
4545

4646
import pytato as pt
47-
from testlib import assert_allclose_to_numpy, get_random_pt_dag
47+
from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref
4848
import pymbolic.primitives as p
4949

5050

@@ -1854,6 +1854,147 @@ def test_pad(ctx_factory):
18541854
np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array)
18551855

18561856

1857+
def _evaluator_for_indirection_folding(cl_ctx, dictofarys):
1858+
from immutables import Map
1859+
cq = cl.CommandQueue(cl_ctx)
1860+
_, out_dict = pt.generate_loopy(dictofarys)(cq)
1861+
return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()})
1862+
1863+
1864+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
1865+
def test_push_indirections_0(ctx_factory, fold_constant_idxs):
1866+
from testlib import (are_all_indexees_materialized_nodes,
1867+
are_all_indexer_arrays_datawrappers)
1868+
1869+
cl_ctx = cl.create_some_context()
1870+
rng = np.random.default_rng(0)
1871+
x_np = rng.random((10, 4))
1872+
map1_np = rng.integers(0, 10, size=17)
1873+
map2_np = rng.integers(0, 17, size=29)
1874+
1875+
x = pt.make_data_wrapper(x_np)
1876+
map1 = pt.make_data_wrapper(map1_np)
1877+
map2 = pt.make_data_wrapper(map2_np)
1878+
1879+
y = 3.14 * ((42*((2*x)[map1]))[map2])
1880+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
1881+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
1882+
)
1883+
1884+
if fold_constant_idxs:
1885+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
1886+
y_transformed = pt.fold_constant_indirections(
1887+
y_transformed,
1888+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
1889+
doa)
1890+
)
1891+
assert are_all_indexer_arrays_datawrappers(y_transformed)
1892+
1893+
auto_test_vs_ref(cl_ctx, y, y_transformed)
1894+
assert are_all_indexees_materialized_nodes(y_transformed)
1895+
1896+
1897+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
1898+
def test_push_indirections_1(ctx_factory, fold_constant_idxs):
1899+
from testlib import (are_all_indexees_materialized_nodes,
1900+
are_all_indexer_arrays_datawrappers)
1901+
1902+
cl_ctx = cl.create_some_context()
1903+
rng = np.random.default_rng(0)
1904+
x_np = rng.random((100, 4))
1905+
map1_np = rng.integers(0, 20, size=17)
1906+
1907+
x = pt.make_data_wrapper(x_np)
1908+
map1 = pt.make_data_wrapper(map1_np)
1909+
1910+
y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1])
1911+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
1912+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
1913+
)
1914+
1915+
if fold_constant_idxs:
1916+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
1917+
y_transformed = pt.fold_constant_indirections(
1918+
y_transformed,
1919+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
1920+
doa)
1921+
)
1922+
assert are_all_indexer_arrays_datawrappers(y_transformed)
1923+
1924+
auto_test_vs_ref(cl_ctx, y, y_transformed)
1925+
assert are_all_indexees_materialized_nodes(y_transformed)
1926+
1927+
1928+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
1929+
def test_push_indirections_2(ctx_factory, fold_constant_idxs):
1930+
from testlib import (are_all_indexees_materialized_nodes,
1931+
are_all_indexer_arrays_datawrappers)
1932+
1933+
cl_ctx = cl.create_some_context()
1934+
rng = np.random.default_rng(0)
1935+
x_np = rng.random((100, 10))
1936+
map1_np = rng.integers(0, 20, size=17)
1937+
map2_np = rng.integers(0, 4, size=29)
1938+
1939+
x = pt.make_data_wrapper(x_np)
1940+
map1 = pt.make_data_wrapper(map1_np)
1941+
map2 = pt.make_data_wrapper(map2_np)
1942+
1943+
y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7]
1944+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
1945+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
1946+
)
1947+
1948+
if fold_constant_idxs:
1949+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
1950+
y_transformed = pt.fold_constant_indirections(
1951+
y_transformed,
1952+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
1953+
doa)
1954+
)
1955+
assert are_all_indexer_arrays_datawrappers(y_transformed)
1956+
1957+
auto_test_vs_ref(cl_ctx, y, y_transformed)
1958+
assert are_all_indexees_materialized_nodes(y_transformed)
1959+
1960+
1961+
@pytest.mark.parametrize("fold_constant_idxs", (False, True))
1962+
def test_push_indirections_3(ctx_factory, fold_constant_idxs):
1963+
from testlib import (are_all_indexees_materialized_nodes,
1964+
are_all_indexer_arrays_datawrappers)
1965+
1966+
cl_ctx = cl.create_some_context()
1967+
rng = np.random.default_rng(0)
1968+
x_np = rng.random((10, 4))
1969+
map1_np = rng.integers(0, 10, size=17)
1970+
map2_np = rng.integers(0, 17, size=29)
1971+
map3_np = rng.integers(0, 4, size=60)
1972+
map4_np = rng.integers(0, 60, size=22)
1973+
1974+
x = pt.make_data_wrapper(x_np)
1975+
map1 = pt.make_data_wrapper(map1_np)
1976+
map2 = pt.make_data_wrapper(map2_np)
1977+
map3 = pt.make_data_wrapper(map3_np)
1978+
map4 = pt.make_data_wrapper(map4_np)
1979+
1980+
y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4])
1981+
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
1982+
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
1983+
)
1984+
1985+
if fold_constant_idxs:
1986+
assert not are_all_indexer_arrays_datawrappers(y_transformed)
1987+
y_transformed = pt.fold_constant_indirections(
1988+
y_transformed,
1989+
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
1990+
doa)
1991+
)
1992+
assert are_all_indexer_arrays_datawrappers(y_transformed)
1993+
1994+
auto_test_vs_ref(cl_ctx, y, y_transformed)
1995+
assert are_all_indexees_materialized_nodes(y_transformed)
1996+
1997+
18571998
if __name__ == "__main__":
18581999
if len(sys.argv) > 1:
18592000
exec(sys.argv[1])

0 commit comments

Comments
 (0)