1- // RUN: %eopt %s --enzyme-wrap="infn=reduce outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --canonicalize -- enzyme-simplify-math --canonicalize | FileCheck %s
1+ // RUN: %eopt %s --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math --canonicalize --split-input-file | FileCheck %s
22
33func.func @reduce (%x: f32 , %ub: index ) -> (f32 ) {
44 %lb = arith.constant 0 : index
@@ -22,7 +22,15 @@ func.func @reduce(%x: f32, %ub: index) -> (f32) {
2222 return %sum2 : f32
2323}
2424
25- // CHECK: func.func @reduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 {
25+ func.func @dreduce (%x: f32 , %ub: index , %dres: f32 ) -> (f32 ) {
26+ %res = enzyme.autodiff @reduce (%x , %ub , %dres ) {
27+ activity = [#enzyme <activity enzyme_active >, #enzyme <activity enzyme_const >],
28+ ret_activity = [#enzyme <activity enzyme_activenoneed >]
29+ } : (f32 , index , f32 ) -> f32
30+ return %res : f32
31+ }
32+
33+ // CHECK: func.func private @differeduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 {
2634// CHECK-NEXT: %c3 = arith.constant 3 : index
2735// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f32
2836// CHECK-NEXT: %c4 = arith.constant 4 : index
@@ -61,3 +69,74 @@ func.func @reduce(%x: f32, %ub: index) -> (f32) {
6169// CHECK-NEXT: memref.dealloc %alloc : memref<?x4xf32>
6270// CHECK-NEXT: return %[[v1]]#1 : f32
6371// CHECK-NEXT: }
72+
73+ // -----
74+
75+ func.func private @reverse_index (%lb: index , %ub: index , %x: memref <?xf32 >) -> f32 {
76+ %c1 = arith.constant 1 : index
77+ %zero = arith.constant 0.0 : f32
78+ %poison = ub.poison : f32
79+ // Verify correct loop indices. e.g. if the loop iterates [3, 4, 5, 6],
80+ // the reversed loop indices are [6, 5, 4, 3] while the *canonical*
81+ // reversed indices are [3, 2, 1, 0]. Memref accesses should use the
82+ // reversed indices while caches should use the canonical reversed indices.
83+ %res:2 = scf.for %iv = %lb to %ub step %c1 iter_args (%acc = %zero , %acc2 = %poison ) -> (f32 , f32 ) {
84+ %ld = memref.load %x [%iv ] : memref <?xf32 >
85+ %mulf = arith.mulf %ld , %ld : f32
86+ memref.store %mulf , %x [%iv ] : memref <?xf32 >
87+ %addf = arith.addf %mulf , %acc : f32
88+ scf.yield %addf , %addf : f32 , f32
89+ }
90+ return %res#1 : f32
91+ }
92+
93+ func.func @dreverse_index (%lb: index , %ub: index , %x: memref <?xf32 >, %dx: memref <?xf32 >, %dr: f32 ) {
94+ enzyme.autodiff @reverse_index (%lb , %ub , %x , %dx , %dr ) {
95+ activity = [#enzyme <activity enzyme_const >, #enzyme <activity enzyme_const >, #enzyme <activity enzyme_dup >],
96+ ret_activity = [#enzyme <activity enzyme_activenoneed >]
97+ } : (index , index , memref <?xf32 >, memref <?xf32 >, f32 ) -> ()
98+ return
99+ }
100+
101+ // CHECK-LABEL: func.func private @differeverse_index(
102+ // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index,
103+ // CHECK-SAME: %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<?xf32>,
104+ // CHECK-SAME: %[[ARG4:.*]]: f32) {
105+ // CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
106+ // CHECK: %[[CONSTANT_1:.*]] = arith.constant 0.000000e+00 : f32
107+ // CHECK: %[[SUBI_0:.*]] = arith.subi %[[ARG1]], %[[ARG0]] : index
108+ // CHECK: %[[ALLOC_0:.*]] = memref.alloc(%[[SUBI_0]]) : memref<?xf32>
109+ // CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[CONSTANT_1]]) -> (f32) {
110+ // CHECK: %[[SUBI_1:.*]] = arith.subi %[[VAL_0]], %[[ARG0]] : index
111+ // CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
112+ // CHECK: memref.store %[[LOAD_0]], %[[ALLOC_0]]{{\[}}%[[SUBI_1]]] : memref<?xf32>
113+ // CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_0]], %[[LOAD_0]] : f32
114+ // CHECK: memref.store %[[MULF_0]], %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
115+ // CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_0]], %[[VAL_1]] : f32
116+ // CHECK: scf.yield %[[ADDF_0]] : f32
117+ // CHECK: }
118+ // CHECK: %[[SUBI_2:.*]] = arith.subi %[[ARG1]], %[[ARG0]] : index
119+ // CHECK: %[[FOR_1:.*]]:2 = scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_3:.*]] = %[[CONSTANT_1]], %[[VAL_4:.*]] = %[[ARG4]]) -> (f32, f32) {
120+ // CHECK: %[[SUBI_3:.*]] = arith.subi %[[VAL_2]], %[[ARG0]] : index
121+ // CHECK: %[[SUBI_4:.*]] = arith.subi %[[SUBI_2]], %[[CONSTANT_0]] : index
122+ // CHECK: %[[RCANON_IV:.*]] = arith.subi %[[SUBI_4]], %[[SUBI_3]] : index
123+ // CHECK: %[[ADDI_0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : index
124+ // CHECK: %[[SUBI_6:.*]] = arith.subi %[[ADDI_0]], %[[CONSTANT_0]] : index
125+ // CHECK: %[[R_IV:.*]] = arith.subi %[[SUBI_6]], %[[VAL_2]] : index
126+ // CHECK: %[[LOAD_1:.*]] = memref.load %[[ALLOC_0]]{{\[}}%[[RCANON_IV]]] : memref<?xf32>
127+ // The gradient signal should be added together
128+ // CHECK: %[[ADDF_1:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
129+ // CHECK: %[[LOAD_2:.*]] = memref.load %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
130+ // CHECK: %[[ADDF_2:.*]] = arith.addf %[[ADDF_1]], %[[LOAD_2]] : f32
131+ // CHECK: memref.store %[[CONSTANT_1]], %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
132+ // CHECK: %[[MULF_1:.*]] = arith.mulf %[[ADDF_2]], %[[LOAD_1]] : f32
133+ // CHECK: %[[MULF_2:.*]] = arith.mulf %[[ADDF_2]], %[[LOAD_1]] : f32
134+ // CHECK: %[[ADDF_3:.*]] = arith.addf %[[MULF_1]], %[[MULF_2]] : f32
135+ // CHECK: %[[LOAD_3:.*]] = memref.load %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
136+ // CHECK: %[[ADDF_4:.*]] = arith.addf %[[LOAD_3]], %[[ADDF_3]] : f32
137+ // CHECK: memref.store %[[ADDF_4]], %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
138+ // CHECK: scf.yield %[[ADDF_1]], %[[CONSTANT_1]] : f32, f32
139+ // CHECK: }
140+ // CHECK: memref.dealloc %[[ALLOC_0]] : memref<?xf32>
141+ // CHECK: return
142+ // CHECK: }
0 commit comments