@@ -1519,6 +1519,48 @@ def test_prefetch_to_same_temp_var(ctx_factory):
1519
1519
lp .auto_test_vs_ref (ref_tunit , ctx , t_unit )
1520
1520
1521
1521
1522
def test_sum_redn_algebraic_transforms(ctx_factory):
    """Check the op-count effect of the sum-reduction algebraic transforms.

    Builds a kernel with a nested sum reduction over ``r`` and ``j``, then
    applies, in order: ``split_reduction_inward`` on ``j``, hoisting of the
    terms invariant in the ``j`` reduction, extraction of the remaining
    reduction term as a substitution rule, and a precompute of that rule.
    The op counts reported by ``lp.get_op_map`` for the untouched and the
    transformed translation units (evaluated at ``N_e = 1``) are compared
    against known values.

    ``ctx_factory`` is accepted as a parameter but is not used in the body.
    """
    from pymbolic import variables
    from loopy.symbolic import Reduction

    subst_name = "grad_without_jacobi_subst"
    eval_params = {"N_e": 1}

    def is_reduction(term):
        # Only terms that are themselves reductions get extracted.
        return isinstance(term, Reduction)

    t_unit = lp.make_kernel(
        "{[e,i,j,x,r]: 0<=e<N_e and 0<=i,j<35 and 0<=x,r<3}",
        """
        y[i] = sum([r,j], J[x, r, e]*D[r,i,j]*u[e,j])
        """,
        [lp.GlobalArg("J,D,u", dtype=np.float64, shape=lp.auto),
         ...],
    )

    # The reduction transforms below operate on the entrypoint kernel.
    kernel = lp.split_reduction_inward(t_unit.default_entrypoint, "j")
    kernel = lp.hoist_invariant_multiplicative_terms_in_sum_reduction(
        kernel,
        reduction_inames="j",
    )
    kernel = lp.extract_multiplicative_terms_in_sum_reduction_as_subst(
        kernel,
        within=None,
        subst_name=subst_name,
        arguments=variables("r i e"),
        terms_filter=is_reduction,
    )

    # Put the rewritten kernel back into a translation unit and precompute
    # the extracted substitution rule into a private temporary.
    transformed_t_unit = lp.precompute(
        t_unit.with_kernel(kernel),
        subst_name,
        sweep_inames=["r", "i"],
        precompute_outer_inames=frozenset({"e"}),
        temporary_address_space=lp.AddressSpace.PRIVATE,
    )

    ops_before = lp.get_op_map(
        t_unit, subgroup_size=1).eval_and_sum(eval_params)
    ops_after = lp.get_op_map(
        transformed_t_unit, subgroup_size=1).eval_and_sum(eval_params)

    assert ops_before == 33075
    assert ops_after == 7980  # i.e. roughly 4.14x fewer ops than before
1562
+
1563
+
1522
1564
if __name__ == "__main__" :
1523
1565
if len (sys .argv ) > 1 :
1524
1566
exec (sys .argv [1 ])
0 commit comments