|
44 | 44 | from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa
|
45 | 45 |
|
46 | 46 | import pytato as pt
|
47 |
| -from testlib import assert_allclose_to_numpy, get_random_pt_dag |
| 47 | +from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref |
48 | 48 | import pymbolic.primitives as p
|
49 | 49 |
|
50 | 50 |
|
@@ -1921,6 +1921,147 @@ def test_pad(ctx_factory):
|
1921 | 1921 | np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array)
|
1922 | 1922 |
|
1923 | 1923 |
|
| 1924 | +def _evaluator_for_indirection_folding(cl_ctx, dictofarys): |
| 1925 | + from immutables import Map |
| 1926 | + cq = cl.CommandQueue(cl_ctx) |
| 1927 | + _, out_dict = pt.generate_loopy(dictofarys)(cq) |
| 1928 | + return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()}) |
| 1929 | + |
| 1930 | + |
| 1931 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 1932 | +def test_push_indirections_0(ctx_factory, fold_constant_idxs): |
| 1933 | + from testlib import (are_all_indexees_materialized_nodes, |
| 1934 | + are_all_indexer_arrays_datawrappers) |
| 1935 | + |
| 1936 | + cl_ctx = cl.create_some_context() |
| 1937 | + rng = np.random.default_rng(0) |
| 1938 | + x_np = rng.random((10, 4)) |
| 1939 | + map1_np = rng.integers(0, 10, size=17) |
| 1940 | + map2_np = rng.integers(0, 17, size=29) |
| 1941 | + |
| 1942 | + x = pt.make_data_wrapper(x_np) |
| 1943 | + map1 = pt.make_data_wrapper(map1_np) |
| 1944 | + map2 = pt.make_data_wrapper(map2_np) |
| 1945 | + |
| 1946 | + y = 3.14 * ((42*((2*x)[map1]))[map2]) |
| 1947 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 1948 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 1949 | + ) |
| 1950 | + |
| 1951 | + if fold_constant_idxs: |
| 1952 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 1953 | + y_transformed = pt.fold_constant_indirections( |
| 1954 | + y_transformed, |
| 1955 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 1956 | + doa) |
| 1957 | + ) |
| 1958 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 1959 | + |
| 1960 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 1961 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 1962 | + |
| 1963 | + |
| 1964 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 1965 | +def test_push_indirections_1(ctx_factory, fold_constant_idxs): |
| 1966 | + from testlib import (are_all_indexees_materialized_nodes, |
| 1967 | + are_all_indexer_arrays_datawrappers) |
| 1968 | + |
| 1969 | + cl_ctx = cl.create_some_context() |
| 1970 | + rng = np.random.default_rng(0) |
| 1971 | + x_np = rng.random((100, 4)) |
| 1972 | + map1_np = rng.integers(0, 20, size=17) |
| 1973 | + |
| 1974 | + x = pt.make_data_wrapper(x_np) |
| 1975 | + map1 = pt.make_data_wrapper(map1_np) |
| 1976 | + |
| 1977 | + y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1]) |
| 1978 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 1979 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 1980 | + ) |
| 1981 | + |
| 1982 | + if fold_constant_idxs: |
| 1983 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 1984 | + y_transformed = pt.fold_constant_indirections( |
| 1985 | + y_transformed, |
| 1986 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 1987 | + doa) |
| 1988 | + ) |
| 1989 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 1990 | + |
| 1991 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 1992 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 1993 | + |
| 1994 | + |
| 1995 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 1996 | +def test_push_indirections_2(ctx_factory, fold_constant_idxs): |
| 1997 | + from testlib import (are_all_indexees_materialized_nodes, |
| 1998 | + are_all_indexer_arrays_datawrappers) |
| 1999 | + |
| 2000 | + cl_ctx = cl.create_some_context() |
| 2001 | + rng = np.random.default_rng(0) |
| 2002 | + x_np = rng.random((100, 10)) |
| 2003 | + map1_np = rng.integers(0, 20, size=17) |
| 2004 | + map2_np = rng.integers(0, 4, size=29) |
| 2005 | + |
| 2006 | + x = pt.make_data_wrapper(x_np) |
| 2007 | + map1 = pt.make_data_wrapper(map1_np) |
| 2008 | + map2 = pt.make_data_wrapper(map2_np) |
| 2009 | + |
| 2010 | + y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7] |
| 2011 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 2012 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 2013 | + ) |
| 2014 | + |
| 2015 | + if fold_constant_idxs: |
| 2016 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 2017 | + y_transformed = pt.fold_constant_indirections( |
| 2018 | + y_transformed, |
| 2019 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 2020 | + doa) |
| 2021 | + ) |
| 2022 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 2023 | + |
| 2024 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 2025 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 2026 | + |
| 2027 | + |
| 2028 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 2029 | +def test_push_indirections_3(ctx_factory, fold_constant_idxs): |
| 2030 | + from testlib import (are_all_indexees_materialized_nodes, |
| 2031 | + are_all_indexer_arrays_datawrappers) |
| 2032 | + |
| 2033 | + cl_ctx = cl.create_some_context() |
| 2034 | + rng = np.random.default_rng(0) |
| 2035 | + x_np = rng.random((10, 4)) |
| 2036 | + map1_np = rng.integers(0, 10, size=17) |
| 2037 | + map2_np = rng.integers(0, 17, size=29) |
| 2038 | + map3_np = rng.integers(0, 4, size=60) |
| 2039 | + map4_np = rng.integers(0, 60, size=22) |
| 2040 | + |
| 2041 | + x = pt.make_data_wrapper(x_np) |
| 2042 | + map1 = pt.make_data_wrapper(map1_np) |
| 2043 | + map2 = pt.make_data_wrapper(map2_np) |
| 2044 | + map3 = pt.make_data_wrapper(map3_np) |
| 2045 | + map4 = pt.make_data_wrapper(map4_np) |
| 2046 | + |
| 2047 | + y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4]) |
| 2048 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 2049 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 2050 | + ) |
| 2051 | + |
| 2052 | + if fold_constant_idxs: |
| 2053 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 2054 | + y_transformed = pt.fold_constant_indirections( |
| 2055 | + y_transformed, |
| 2056 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 2057 | + doa) |
| 2058 | + ) |
| 2059 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 2060 | + |
| 2061 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 2062 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 2063 | + |
| 2064 | + |
1924 | 2065 | if __name__ == "__main__":
|
1925 | 2066 | if len(sys.argv) > 1:
|
1926 | 2067 | exec(sys.argv[1])
|
|
0 commit comments