Skip to content

Commit 01801e2

Browse files
committed
introduce lp.decouple_domain
1 parent 9fc33ae commit 01801e2

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

loopy/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
from loopy.transform.parameter import assume, fix_parameters
123123
from loopy.transform.save import save_and_reload_temporaries
124124
from loopy.transform.add_barrier import add_barrier
125+
from loopy.transform.domain import decouple_domain
125126
from loopy.transform.callable import (register_callable,
126127
merge, inline_callable_kernel, rename_callable)
127128
from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call
@@ -251,6 +252,8 @@
251252

252253
"add_barrier",
253254

255+
"decouple_domain",
256+
254257
"register_callable",
255258
"merge",
256259

loopy/transform/domain.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni"
2+
3+
__license__ = """
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in
12+
all copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
THE SOFTWARE.
21+
"""
22+
23+
__doc__ = """
24+
.. currentmodule:: loopy
25+
26+
.. autofunction:: decouple_domain
27+
"""
28+
29+
import islpy as isl
30+
31+
from loopy.translation_unit import for_each_kernel
32+
from loopy.kernel import LoopKernel
33+
from loopy.diagnostic import LoopyError
34+
from collections.abc import Collection
35+
36+
37+
@for_each_kernel
38+
def decouple_domain(kernel: LoopKernel,
39+
inames: Collection[str],
40+
parent_inames: Collection[str]) -> LoopKernel:
41+
r"""
42+
Returns a copy of *kernel* with altered domains. The home domain of
43+
*inames* i.e. :math:`\mathcal{D}^{\text{home}}({\text{inames}})` is
44+
replaced with two domains :math:`\mathcal{D}_1` and :math:`\mathcal{D}_2`.
45+
:math:`\mathcal{D}_1` is the domain with dimensions corresponding to *inames*
46+
projected out and :math:`\mathcal{D}_2` is the domain with all the dimensions
47+
other than the ones corresponding to *inames* projected out.
48+
49+
.. note::
50+
51+
An error is raised if all the *inames* do not correspond to the same home
52+
domain of *kernel*.
53+
"""
54+
55+
if not inames:
56+
raise LoopyError("No inames were provided to decouple into"
57+
" a different domain.")
58+
59+
hdi = kernel.get_home_domain_index(next(iter(inames)))
60+
for iname in inames:
61+
if kernel.get_home_domain_index(iname) != hdi:
62+
raise LoopyError("inames are not a part of the same home domain.")
63+
64+
for parent_iname in parent_inames:
65+
if parent_iname not in set(kernel.domains[hdi].get_var_dict()):
66+
raise LoopyError(f"Parent iname '{parent_iname}' not a part of the"
67+
f" corresponding home domain '{kernel.domains[hdi]}'.")
68+
69+
all_dims = frozenset(kernel.domains[hdi].get_var_dict())
70+
D1 = kernel.domains[hdi]
71+
D2 = kernel.domains[hdi]
72+
73+
for iname in sorted(all_dims):
74+
if iname in inames:
75+
dt, pos = D1.get_var_dict()[iname]
76+
D1 = D1.project_out(dt, pos, 1)
77+
elif iname in parent_inames:
78+
dt, pos = D2.get_var_dict()[iname]
79+
if dt != isl.dim_type.param:
80+
n_params = D2.dim(isl.dim_type.param)
81+
D2 = D2.move_dims(isl.dim_type.param, n_params, dt, pos, 1)
82+
else:
83+
dt, pos = D2.get_var_dict()[iname]
84+
D2 = D2.project_out(dt, pos, 1)
85+
86+
new_domains = kernel.domains[:]
87+
new_domains[hdi] = D1
88+
new_domains.append(D2)
89+
kernel = kernel.copy(domains=new_domains)
90+
return kernel

0 commit comments

Comments
 (0)