Skip to content

Commit 7f51390

Browse files
committed
Test for transforms.reduction
1 parent 2f695ba commit 7f51390

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

test/test_transform.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,48 @@ def test_prefetch_to_same_temp_var(ctx_factory):
15191519
lp.auto_test_vs_ref(ref_tunit, ctx, t_unit)
15201520

15211521

1522+
def test_sum_redn_algebraic_transforms(ctx_factory):
    """Compare op counts before and after algebraic rewrites of a sum reduction.

    A kernel with a single sum reduction over ``r`` and ``j`` is built, then
    its default entrypoint is passed, in order, through
    ``split_reduction_inward``,
    ``hoist_invariant_multiplicative_terms_in_sum_reduction`` and
    ``extract_multiplicative_terms_in_sum_reduction_as_subst``; the resulting
    substitution rule is finally precomputed into a private temporary.

    The total op count reported by ``lp.get_op_map`` (evaluated at
    ``N_e = 1``) is pinned for both the original and the transformed
    translation unit.

    :arg ctx_factory: pytest fixture; not referenced in the test body.
    """
    from pymbolic import variables
    from loopy.symbolic import Reduction

    def is_reduction(term):
        # Selects the terms that are themselves reductions.
        return isinstance(term, Reduction)

    def total_ops(program):
        # Sum of all counted operations with the symbolic size fixed at N_e=1.
        return lp.get_op_map(program, subgroup_size=1).eval_and_sum({"N_e": 1})

    domain = "{[e,i,j,x,r]: 0<=e<N_e and 0<=i,j<35 and 0<=x,r<3}"
    instructions = """
        y[i] = sum([r,j], J[x, r, e]*D[r,i,j]*u[e,j])
        """
    kernel_data = [
        lp.GlobalArg("J,D,u", dtype=np.float64, shape=lp.auto),
        ...,
    ]
    t_unit = lp.make_kernel(domain, instructions, kernel_data)

    # The algebraic transformations below operate on the kernel itself.
    kernel = t_unit.default_entrypoint
    kernel = lp.split_reduction_inward(kernel, "j")
    kernel = lp.hoist_invariant_multiplicative_terms_in_sum_reduction(
        kernel, reduction_inames="j")
    kernel = lp.extract_multiplicative_terms_in_sum_reduction_as_subst(
        kernel,
        within=None,
        subst_name="grad_without_jacobi_subst",
        arguments=variables("r i e"),
        terms_filter=is_reduction,
    )

    # Put the rewritten kernel back into a translation unit and precompute
    # the extracted substitution rule.
    transformed_t_unit = lp.precompute(
        t_unit.with_kernel(kernel),
        "grad_without_jacobi_subst",
        sweep_inames=["r", "i"],
        precompute_outer_inames=frozenset({"e"}),
        temporary_address_space=lp.AddressSpace.PRIVATE,
    )

    ops_before = total_ops(t_unit)
    ops_after = total_ops(transformed_t_unit)

    assert ops_before == 33075
    assert ops_after == 7980  # i.e. roughly 4.14x fewer flops than before
1562+
1563+
15221564
if __name__ == "__main__":
15231565
if len(sys.argv) > 1:
15241566
exec(sys.argv[1])

0 commit comments

Comments
 (0)