|
44 | 44 | from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa
|
45 | 45 |
|
46 | 46 | import pytato as pt
|
47 |
| -from testlib import assert_allclose_to_numpy, get_random_pt_dag |
| 47 | +from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref |
48 | 48 | import pymbolic.primitives as p
|
49 | 49 |
|
50 | 50 |
|
@@ -2002,6 +2002,147 @@ def call_bar(tracer, x, y):
|
2002 | 2002 | np.testing.assert_allclose(result_out[k], expect_out[k])
|
2003 | 2003 |
|
2004 | 2004 |
|
| 2005 | +def _evaluator_for_indirection_folding(cl_ctx, dictofarys): |
| 2006 | + from immutables import Map |
| 2007 | + cq = cl.CommandQueue(cl_ctx) |
| 2008 | + _, out_dict = pt.generate_loopy(dictofarys)(cq) |
| 2009 | + return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()}) |
| 2010 | + |
| 2011 | + |
| 2012 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 2013 | +def test_push_indirections_0(ctx_factory, fold_constant_idxs): |
| 2014 | + from testlib import (are_all_indexees_materialized_nodes, |
| 2015 | + are_all_indexer_arrays_datawrappers) |
| 2016 | + |
| 2017 | + cl_ctx = cl.create_some_context() |
| 2018 | + rng = np.random.default_rng(0) |
| 2019 | + x_np = rng.random((10, 4)) |
| 2020 | + map1_np = rng.integers(0, 10, size=17) |
| 2021 | + map2_np = rng.integers(0, 17, size=29) |
| 2022 | + |
| 2023 | + x = pt.make_data_wrapper(x_np) |
| 2024 | + map1 = pt.make_data_wrapper(map1_np) |
| 2025 | + map2 = pt.make_data_wrapper(map2_np) |
| 2026 | + |
| 2027 | + y = 3.14 * ((42*((2*x)[map1]))[map2]) |
| 2028 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 2029 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 2030 | + ) |
| 2031 | + |
| 2032 | + if fold_constant_idxs: |
| 2033 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 2034 | + y_transformed = pt.fold_constant_indirections( |
| 2035 | + y_transformed, |
| 2036 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 2037 | + doa) |
| 2038 | + ) |
| 2039 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 2040 | + |
| 2041 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 2042 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 2043 | + |
| 2044 | + |
| 2045 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 2046 | +def test_push_indirections_1(ctx_factory, fold_constant_idxs): |
| 2047 | + from testlib import (are_all_indexees_materialized_nodes, |
| 2048 | + are_all_indexer_arrays_datawrappers) |
| 2049 | + |
| 2050 | + cl_ctx = cl.create_some_context() |
| 2051 | + rng = np.random.default_rng(0) |
| 2052 | + x_np = rng.random((100, 4)) |
| 2053 | + map1_np = rng.integers(0, 20, size=17) |
| 2054 | + |
| 2055 | + x = pt.make_data_wrapper(x_np) |
| 2056 | + map1 = pt.make_data_wrapper(map1_np) |
| 2057 | + |
| 2058 | + y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1]) |
| 2059 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 2060 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 2061 | + ) |
| 2062 | + |
| 2063 | + if fold_constant_idxs: |
| 2064 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 2065 | + y_transformed = pt.fold_constant_indirections( |
| 2066 | + y_transformed, |
| 2067 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 2068 | + doa) |
| 2069 | + ) |
| 2070 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 2071 | + |
| 2072 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 2073 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 2074 | + |
| 2075 | + |
| 2076 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 2077 | +def test_push_indirections_2(ctx_factory, fold_constant_idxs): |
| 2078 | + from testlib import (are_all_indexees_materialized_nodes, |
| 2079 | + are_all_indexer_arrays_datawrappers) |
| 2080 | + |
| 2081 | + cl_ctx = cl.create_some_context() |
| 2082 | + rng = np.random.default_rng(0) |
| 2083 | + x_np = rng.random((100, 10)) |
| 2084 | + map1_np = rng.integers(0, 20, size=17) |
| 2085 | + map2_np = rng.integers(0, 4, size=29) |
| 2086 | + |
| 2087 | + x = pt.make_data_wrapper(x_np) |
| 2088 | + map1 = pt.make_data_wrapper(map1_np) |
| 2089 | + map2 = pt.make_data_wrapper(map2_np) |
| 2090 | + |
| 2091 | + y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7] |
| 2092 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 2093 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 2094 | + ) |
| 2095 | + |
| 2096 | + if fold_constant_idxs: |
| 2097 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 2098 | + y_transformed = pt.fold_constant_indirections( |
| 2099 | + y_transformed, |
| 2100 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 2101 | + doa) |
| 2102 | + ) |
| 2103 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 2104 | + |
| 2105 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 2106 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 2107 | + |
| 2108 | + |
| 2109 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 2110 | +def test_push_indirections_3(ctx_factory, fold_constant_idxs): |
| 2111 | + from testlib import (are_all_indexees_materialized_nodes, |
| 2112 | + are_all_indexer_arrays_datawrappers) |
| 2113 | + |
| 2114 | + cl_ctx = cl.create_some_context() |
| 2115 | + rng = np.random.default_rng(0) |
| 2116 | + x_np = rng.random((10, 4)) |
| 2117 | + map1_np = rng.integers(0, 10, size=17) |
| 2118 | + map2_np = rng.integers(0, 17, size=29) |
| 2119 | + map3_np = rng.integers(0, 4, size=60) |
| 2120 | + map4_np = rng.integers(0, 60, size=22) |
| 2121 | + |
| 2122 | + x = pt.make_data_wrapper(x_np) |
| 2123 | + map1 = pt.make_data_wrapper(map1_np) |
| 2124 | + map2 = pt.make_data_wrapper(map2_np) |
| 2125 | + map3 = pt.make_data_wrapper(map3_np) |
| 2126 | + map4 = pt.make_data_wrapper(map4_np) |
| 2127 | + |
| 2128 | + y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4]) |
| 2129 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 2130 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 2131 | + ) |
| 2132 | + |
| 2133 | + if fold_constant_idxs: |
| 2134 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 2135 | + y_transformed = pt.fold_constant_indirections( |
| 2136 | + y_transformed, |
| 2137 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 2138 | + doa) |
| 2139 | + ) |
| 2140 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 2141 | + |
| 2142 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 2143 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 2144 | + |
| 2145 | + |
2005 | 2146 | if __name__ == "__main__":
|
2006 | 2147 | if len(sys.argv) > 1:
|
2007 | 2148 | exec(sys.argv[1])
|
|
0 commit comments