Skip to content

Commit c098bb8

Browse files
committed
Add transformations specific to sum-reduction
1 parent 7a72ba9 commit c098bb8

File tree

3 files changed

+287
-0
lines changed

3 files changed

+287
-0
lines changed

doc/ref_transform.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ Manipulating Instructions
8080

8181
.. autofunction:: add_barrier
8282

83+
Manipulating Reductions
84+
-----------------------
85+
86+
.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
87+
88+
.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
89+
8390
Registering Library Routines
8491
----------------------------
8592

loopy/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@
120120
from loopy.transform.parameter import assume, fix_parameters
121121
from loopy.transform.save import save_and_reload_temporaries
122122
from loopy.transform.add_barrier import add_barrier
123+
from loopy.transform.reduction import (
124+
hoist_invariant_multiplicative_terms_in_sum_reduction,
125+
extract_multiplicative_terms_in_sum_reduction_as_subst)
123126
from loopy.transform.callable import (register_callable,
124127
merge, inline_callable_kernel, rename_callable)
125128
from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call
@@ -247,6 +250,9 @@
247250

248251
"add_barrier",
249252

253+
"hoist_invariant_multiplicative_terms_in_sum_reduction",
254+
"extract_multiplicative_terms_in_sum_reduction_as_subst",
255+
250256
"register_callable",
251257
"merge",
252258

loopy/transform/reduction.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
"""
2+
.. currentmodule:: loopy
3+
4+
.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
5+
6+
.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
7+
"""
8+
9+
__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
10+
11+
__license__ = """
12+
Permission is hereby granted, free of charge, to any person obtaining a copy
13+
of this software and associated documentation files (the "Software"), to deal
14+
in the Software without restriction, including without limitation the rights
15+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16+
copies of the Software, and to permit persons to whom the Software is
17+
furnished to do so, subject to the following conditions:
18+
19+
The above copyright notice and this permission notice shall be included in
20+
all copies or substantial portions of the Software.
21+
22+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
28+
THE SOFTWARE.
29+
"""
30+
31+
import pymbolic.primitives as p
32+
33+
from typing import (FrozenSet, TypeVar, Callable, List, Tuple, Iterable, Union, Any,
34+
Optional, Sequence)
35+
from loopy.symbolic import IdentityMapper, Reduction, CombineMapper
36+
from loopy.kernel import LoopKernel
37+
from loopy.kernel.data import SubstitutionRule
38+
from loopy.diagnostic import LoopyError
39+
40+
41+
# {{{ partition (copied from more-itertools)
42+
43+
Tpart = TypeVar("Tpart")
44+
45+
46+
def partition(pred: Callable[[Tpart], bool],
47+
iterable: Iterable[Tpart]) -> Tuple[List[Tpart],
48+
List[Tpart]]:
49+
"""
50+
Use a predicate to partition entries into false entries and true
51+
entries
52+
"""
53+
# Inspired from https://docs.python.org/3/library/itertools.html
54+
# partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
55+
from itertools import tee, filterfalse
56+
t1, t2 = tee(iterable)
57+
return list(filterfalse(pred, t1)), list(filter(pred, t2))
58+
59+
# }}}
60+
61+
62+
# {{{ hoist_reduction_invariant_terms
63+
64+
class EinsumTermsHoister(IdentityMapper):
65+
"""
66+
Mapper to hoist products out of a sum-reduction.
67+
68+
.. attribute:: reduction_inames
69+
70+
Inames of the reduction expressions to perform the hoisting.
71+
"""
72+
def __init__(self, reduction_inames: FrozenSet[str]):
73+
super().__init__()
74+
self.reduction_inames = reduction_inames
75+
76+
# type-ignore-reason: super-class.map_reduction returns 'Any'
77+
def map_reduction(self, expr: Reduction # type: ignore[override]
78+
) -> p.Expression:
79+
if frozenset(expr.inames) != self.reduction_inames:
80+
return super().map_reduction(expr)
81+
82+
from loopy.library.reduction import SumReductionOperation
83+
from loopy.symbolic import get_dependencies
84+
if isinstance(expr.expr, p.Product) and isinstance(expr.operation,
85+
SumReductionOperation):
86+
from pymbolic.primitives import flattened_product
87+
multiplicative_terms = (flattened_product(self.rec(expr.expr).children)
88+
.children)
89+
invariants, variants = partition(lambda x: (get_dependencies(x)
90+
& self.reduction_inames),
91+
multiplicative_terms)
92+
93+
return p.Product(tuple(invariants)) * Reduction(
94+
expr.operation,
95+
inames=expr.inames,
96+
expr=p.Product(tuple(variants)),
97+
allow_simultaneous=expr.allow_simultaneous)
98+
else:
99+
raise NotImplementedError(expr.expr)
100+
101+
102+
def hoist_invariant_multiplicative_terms_in_sum_reduction(
103+
kernel: LoopKernel,
104+
reduction_inames: Union[str, FrozenSet[str]],
105+
within: Any = None
106+
) -> LoopKernel:
107+
"""
108+
Hoists loop-invariant multiplicative terms in a sum-reduction expression.
109+
110+
:arg reduction_inames: The inames over which reduction is performed that defines
111+
the reduction expression that is to be transformed.
112+
:arg within: A match expression understood by :func:`loopy.match.parse_match`
113+
that specifies the instructions over which the transformation is to be
114+
performed.
115+
"""
116+
from loopy.transform.instruction import map_instructions
117+
if isinstance(reduction_inames, str):
118+
reduction_inames = frozenset([reduction_inames])
119+
120+
if not (reduction_inames <= kernel.all_inames()):
121+
raise ValueError(f"Some inames in '{reduction_inames}' not a part of"
122+
" the kernel.")
123+
124+
term_hoister = EinsumTermsHoister(reduction_inames)
125+
126+
return map_instructions(kernel,
127+
insn_match=within,
128+
f=lambda x: x.with_transformed_expressions(term_hoister)
129+
)
130+
131+
# }}}
132+
133+
134+
# {{{ extract_multiplicative_terms_in_sum_reduction_as_subst
135+
136+
class ContainsSumReduction(CombineMapper):
137+
"""
138+
Returns *True* only if the mapper maps over an expression containing a
139+
SumReduction operation.
140+
"""
141+
def combine(self, values: Iterable[bool]) -> bool:
142+
return any(values)
143+
144+
# type-ignore-reason: super-class.map_reduction returns 'Any'
145+
def map_reduction(self, expr: Reduction) -> bool: # type: ignore[override]
146+
from loopy.library.reduction import SumReductionOperation
147+
return (isinstance(expr.operation, SumReductionOperation)
148+
or self.rec(expr.expr))
149+
150+
def map_variable(self, expr: p.Variable) -> bool:
151+
return False
152+
153+
def map_algebraic_leaf(self, expr: Any) -> bool:
154+
return False
155+
156+
157+
class MultiplicativeTermReplacer(IdentityMapper):
158+
"""
159+
Primary mapper of
160+
:func:`extract_multiplicative_terms_in_sum_reduction_as_subst`.
161+
"""
162+
def __init__(self,
163+
*,
164+
terms_filter: Callable[[p.Expression], bool],
165+
subst_name: str,
166+
subst_arguments: Tuple[str, ...]) -> None:
167+
self.subst_name = subst_name
168+
self.subst_arguments = subst_arguments
169+
self.terms_filter = terms_filter
170+
super().__init__()
171+
172+
# mutable state to record the expression collected by the terms_filter
173+
self.collected_subst_rule: Optional[SubstitutionRule] = None
174+
175+
# type-ignore-reason: super-class.map_reduction returns 'Any'
176+
def map_reduction(self, expr: Reduction) -> Reduction: # type: ignore[override]
177+
from loopy.library.reduction import SumReductionOperation
178+
from loopy.symbolic import SubstitutionMapper
179+
if isinstance(expr.operation, SumReductionOperation):
180+
if self.collected_subst_rule is not None:
181+
# => there was already a sum-reduction operation -> raise
182+
raise ValueError("Multiple sum reduction expressions found -> not"
183+
" allowed.")
184+
185+
if isinstance(expr.expr, p.Product):
186+
from pymbolic.primitives import flattened_product
187+
terms = flattened_product(expr.expr.children).children
188+
else:
189+
terms = expr.expression
190+
191+
unfiltered_terms, filtered_terms = partition(self.terms_filter, terms)
192+
submap = SubstitutionMapper({
193+
argument_expr: p.Variable(f"arg{i}")
194+
for i, argument_expr in enumerate(self.subst_arguments)}.get)
195+
self.collected_subst_rule = SubstitutionRule(
196+
name=self.subst_name,
197+
arguments=tuple(f"arg{i}" for i in range(len(self.subst_arguments))),
198+
expression=submap(p.Product(tuple(filtered_terms))
199+
if filtered_terms
200+
else 1)
201+
)
202+
return Reduction(
203+
expr.operation,
204+
expr.inames,
205+
p.Product((p.Variable(self.subst_name)(*self.subst_arguments),
206+
*unfiltered_terms)),
207+
expr.allow_simultaneous)
208+
else:
209+
return super().map_reduction(expr)
210+
211+
212+
def extract_multiplicative_terms_in_sum_reduction_as_subst(
213+
kernel: LoopKernel,
214+
within: Any,
215+
subst_name: str,
216+
arguments: Sequence[p.Expression],
217+
terms_filter: Callable[[p.Expression], bool],
218+
) -> LoopKernel:
219+
"""
220+
Returns a copy of *kernel* with a new substitution named *subst_name* and
221+
*arguments* as arguments for the aggregated multiplicative terms in a
222+
sum-reduction expression.
223+
224+
:arg within: A match expression understood by :func:`loopy.match.parse_match`
225+
to specify the instructions over which the transformation is to be
226+
performed.
227+
:arg terms_filter: A callable to filter which terms of the sum-reduction
228+
comprise the body of substitution rule.
229+
:arg arguments: The sub-expressions of the product of the filtered terms that
230+
form the arguments of the extract substitution rule in the same order.
231+
232+
.. note::
233+
234+
A ``LoopyError`` is raised if none or more than 1 sum-reduction expression
235+
appear in *within*.
236+
"""
237+
from loopy.match import parse_match
238+
within = parse_match(within)
239+
240+
matched_insns = [
241+
insn
242+
for insn in kernel.instructions
243+
if within(kernel, insn) and ContainsSumReduction()((insn.expression,
244+
tuple(insn.predicates)))
245+
]
246+
247+
if len(matched_insns) == 0:
248+
raise LoopyError(f"No instructions found matching '{within}'"
249+
" with sum-reductions found.")
250+
if len(matched_insns) > 1:
251+
raise LoopyError(f"More than one instruction found matching '{within}'"
252+
" with sum-reductions found -> not allowed.")
253+
254+
insn, = matched_insns
255+
replacer = MultiplicativeTermReplacer(subst_name=subst_name,
256+
subst_arguments=tuple(arguments),
257+
terms_filter=terms_filter)
258+
new_insn = insn.with_transformed_expressions(replacer)
259+
new_rule = replacer.collected_subst_rule
260+
new_substitutions = dict(kernel.substitutions).copy()
261+
if subst_name in new_substitutions:
262+
raise LoopyError(f"Kernel '{kernel.name}' already contains a substitution"
263+
" rule named '{subst_name}'.")
264+
assert new_rule is not None
265+
new_substitutions[subst_name] = new_rule
266+
267+
return kernel.copy(instructions=[new_insn if insn.id == new_insn.id else insn
268+
for insn in kernel.instructions],
269+
substitutions=new_substitutions)
270+
271+
# }}}
272+
273+
274+
# vim: foldmethod=marker

0 commit comments

Comments
 (0)