|
44 | 44 | from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa
|
45 | 45 |
|
46 | 46 | import pytato as pt
|
47 |
| -from testlib import assert_allclose_to_numpy, get_random_pt_dag |
| 47 | +from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref |
48 | 48 | import pymbolic.primitives as p
|
49 | 49 |
|
50 | 50 |
|
@@ -1854,6 +1854,147 @@ def test_pad(ctx_factory):
|
1854 | 1854 | np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array)
|
1855 | 1855 |
|
1856 | 1856 |
|
| 1857 | +def _evaluator_for_indirection_folding(cl_ctx, dictofarys): |
| 1858 | + from immutables import Map |
| 1859 | + cq = cl.CommandQueue(cl_ctx) |
| 1860 | + _, out_dict = pt.generate_loopy(dictofarys)(cq) |
| 1861 | + return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()}) |
| 1862 | + |
| 1863 | + |
| 1864 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 1865 | +def test_push_indirections_0(ctx_factory, fold_constant_idxs): |
| 1866 | + from testlib import (are_all_indexees_materialized_nodes, |
| 1867 | + are_all_indexer_arrays_datawrappers) |
| 1868 | + |
| 1869 | + cl_ctx = cl.create_some_context() |
| 1870 | + rng = np.random.default_rng(0) |
| 1871 | + x_np = rng.random((10, 4)) |
| 1872 | + map1_np = rng.integers(0, 10, size=17) |
| 1873 | + map2_np = rng.integers(0, 17, size=29) |
| 1874 | + |
| 1875 | + x = pt.make_data_wrapper(x_np) |
| 1876 | + map1 = pt.make_data_wrapper(map1_np) |
| 1877 | + map2 = pt.make_data_wrapper(map2_np) |
| 1878 | + |
| 1879 | + y = 3.14 * ((42*((2*x)[map1]))[map2]) |
| 1880 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 1881 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 1882 | + ) |
| 1883 | + |
| 1884 | + if fold_constant_idxs: |
| 1885 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 1886 | + y_transformed = pt.fold_constant_indirections( |
| 1887 | + y_transformed, |
| 1888 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 1889 | + doa) |
| 1890 | + ) |
| 1891 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 1892 | + |
| 1893 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 1894 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 1895 | + |
| 1896 | + |
| 1897 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 1898 | +def test_push_indirections_1(ctx_factory, fold_constant_idxs): |
| 1899 | + from testlib import (are_all_indexees_materialized_nodes, |
| 1900 | + are_all_indexer_arrays_datawrappers) |
| 1901 | + |
| 1902 | + cl_ctx = cl.create_some_context() |
| 1903 | + rng = np.random.default_rng(0) |
| 1904 | + x_np = rng.random((100, 4)) |
| 1905 | + map1_np = rng.integers(0, 20, size=17) |
| 1906 | + |
| 1907 | + x = pt.make_data_wrapper(x_np) |
| 1908 | + map1 = pt.make_data_wrapper(map1_np) |
| 1909 | + |
| 1910 | + y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1]) |
| 1911 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 1912 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 1913 | + ) |
| 1914 | + |
| 1915 | + if fold_constant_idxs: |
| 1916 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 1917 | + y_transformed = pt.fold_constant_indirections( |
| 1918 | + y_transformed, |
| 1919 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 1920 | + doa) |
| 1921 | + ) |
| 1922 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 1923 | + |
| 1924 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 1925 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 1926 | + |
| 1927 | + |
| 1928 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 1929 | +def test_push_indirections_2(ctx_factory, fold_constant_idxs): |
| 1930 | + from testlib import (are_all_indexees_materialized_nodes, |
| 1931 | + are_all_indexer_arrays_datawrappers) |
| 1932 | + |
| 1933 | + cl_ctx = cl.create_some_context() |
| 1934 | + rng = np.random.default_rng(0) |
| 1935 | + x_np = rng.random((100, 10)) |
| 1936 | + map1_np = rng.integers(0, 20, size=17) |
| 1937 | + map2_np = rng.integers(0, 4, size=29) |
| 1938 | + |
| 1939 | + x = pt.make_data_wrapper(x_np) |
| 1940 | + map1 = pt.make_data_wrapper(map1_np) |
| 1941 | + map2 = pt.make_data_wrapper(map2_np) |
| 1942 | + |
| 1943 | + y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7] |
| 1944 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 1945 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 1946 | + ) |
| 1947 | + |
| 1948 | + if fold_constant_idxs: |
| 1949 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 1950 | + y_transformed = pt.fold_constant_indirections( |
| 1951 | + y_transformed, |
| 1952 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 1953 | + doa) |
| 1954 | + ) |
| 1955 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 1956 | + |
| 1957 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 1958 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 1959 | + |
| 1960 | + |
| 1961 | +@pytest.mark.parametrize("fold_constant_idxs", (False, True)) |
| 1962 | +def test_push_indirections_3(ctx_factory, fold_constant_idxs): |
| 1963 | + from testlib import (are_all_indexees_materialized_nodes, |
| 1964 | + are_all_indexer_arrays_datawrappers) |
| 1965 | + |
| 1966 | + cl_ctx = cl.create_some_context() |
| 1967 | + rng = np.random.default_rng(0) |
| 1968 | + x_np = rng.random((10, 4)) |
| 1969 | + map1_np = rng.integers(0, 10, size=17) |
| 1970 | + map2_np = rng.integers(0, 17, size=29) |
| 1971 | + map3_np = rng.integers(0, 4, size=60) |
| 1972 | + map4_np = rng.integers(0, 60, size=22) |
| 1973 | + |
| 1974 | + x = pt.make_data_wrapper(x_np) |
| 1975 | + map1 = pt.make_data_wrapper(map1_np) |
| 1976 | + map2 = pt.make_data_wrapper(map2_np) |
| 1977 | + map3 = pt.make_data_wrapper(map3_np) |
| 1978 | + map4 = pt.make_data_wrapper(map4_np) |
| 1979 | + |
| 1980 | + y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4]) |
| 1981 | + y_transformed = pt.push_axis_indirections_towards_materialized_nodes( |
| 1982 | + pt.decouple_multi_axis_indirections_into_single_axis_indirections(y) |
| 1983 | + ) |
| 1984 | + |
| 1985 | + if fold_constant_idxs: |
| 1986 | + assert not are_all_indexer_arrays_datawrappers(y_transformed) |
| 1987 | + y_transformed = pt.fold_constant_indirections( |
| 1988 | + y_transformed, |
| 1989 | + lambda doa: _evaluator_for_indirection_folding(cl_ctx, |
| 1990 | + doa) |
| 1991 | + ) |
| 1992 | + assert are_all_indexer_arrays_datawrappers(y_transformed) |
| 1993 | + |
| 1994 | + auto_test_vs_ref(cl_ctx, y, y_transformed) |
| 1995 | + assert are_all_indexees_materialized_nodes(y_transformed) |
| 1996 | + |
| 1997 | + |
1857 | 1998 | if __name__ == "__main__":
|
1858 | 1999 | if len(sys.argv) > 1:
|
1859 | 2000 | exec(sys.argv[1])
|
|
0 commit comments