@@ -133,3 +133,65 @@ func.func @dif_overwrite(%cond: i1, %x: memref<f32>, %dx: memref<f32>) {
133133// CHECK: }
134134// CHECK: return
135135// CHECK: }
136+
137+ // -----
138+
// Primal function under test. It returns sin of the f32 loaded through a
// pointer chosen by %cond:
//   - then-arm: overwrites *%x with cos(*%x) and yields the pointer obtained
//     from %x by a getelementptr with index 4 over f32 elements;
//   - else-arm: yields %x unchanged.
func.func private @active_ptr(%cond: i1, %x: !llvm.ptr) -> f32 {
  %idx = llvm.mlir.constant(4 : i64) : i64
  %selected = scf.if %cond -> !llvm.ptr {
    // Keep the getelementptr ahead of the load/store: the expected output
    // below matches the generated ops in this order.
    %elem_ptr = llvm.getelementptr %x[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %x_val = llvm.load %x : !llvm.ptr -> f32
    %cos_x = math.cos %x_val : f32
    // In-place overwrite of the value that was just read from %x.
    llvm.store %cos_x, %x : f32, !llvm.ptr
    scf.yield %elem_ptr : !llvm.ptr
  } else {
    scf.yield %x : !llvm.ptr
  }

  %loaded = llvm.load %selected : !llvm.ptr -> f32
  %sin = math.sin %loaded : f32
  return %sin : f32
}
155+
// Driver that requests the reverse-mode derivative of @active_ptr.
// Operands are (%cond, %x, %dx, %dres): %cond is marked enzyme_const, %x is
// marked enzyme_dup (presumably with %dx as its shadow pointer), and the f32
// return is marked enzyme_activenoneed, so the op yields no results here.
// Note: `#enzyme<...>` must be written with no space before `<`; otherwise
// the parser treats `#enzyme` as an (undefined) attribute alias.
func.func @dif_overwrite(%cond: i1, %x: !llvm.ptr, %dx: !llvm.ptr, %dres: f32) {
  enzyme.autodiff @active_ptr(%cond, %x, %dx, %dres) {
    activity=[#enzyme<activity enzyme_const>, #enzyme<activity enzyme_dup>],
    ret_activity=[#enzyme<activity enzyme_activenoneed>]
  } : (i1, !llvm.ptr, !llvm.ptr, f32) -> ()
  return
}
163+
// Expected reverse-mode function generated for @active_ptr.
// The forward scf.if yields three values that the rest of the function reads
// back through the IF_0 capture: #0 is loaded and fed to math.cos, #1 is the
// pointer the adjoint is accumulated into, and #2 is the f32 reused by the
// reverse scf.if (the value loaded in the then-arm, or the zero constant).
// CHECK-LABEL:   func.func private @diffeactive_ptr(
// CHECK-SAME:      %[[ARG0:.*]]: i1,
// CHECK-SAME:      %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr,
// CHECK-SAME:      %[[ARG3:.*]]: f32) {
// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[IF_0:.*]]:3 = scf.if %[[ARG0]] -> (!llvm.ptr, !llvm.ptr, f32) {
// CHECK:             %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[ARG2]][4] : (!llvm.ptr) -> !llvm.ptr, f32
// CHECK:             %[[GETELEMENTPTR_1:.*]] = llvm.getelementptr %[[ARG1]][4] : (!llvm.ptr) -> !llvm.ptr, f32
// CHECK:             %[[LOAD_0:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> f32
// CHECK:             %[[COS_0:.*]] = math.cos %[[LOAD_0]] : f32
// CHECK:             llvm.store %[[COS_0]], %[[ARG1]] : f32, !llvm.ptr
// CHECK:             scf.yield %[[GETELEMENTPTR_1]], %[[GETELEMENTPTR_0]], %[[LOAD_0]] : !llvm.ptr, !llvm.ptr, f32
// CHECK:           } else {
// CHECK:             scf.yield %[[ARG1]], %[[ARG2]], %[[CONSTANT_0]] : !llvm.ptr, !llvm.ptr, f32
// CHECK:           }
// CHECK:           %[[LOAD_1:.*]] = llvm.load %[[IF_0]]#0 : !llvm.ptr -> f32
// CHECK:           %[[COS_1:.*]] = math.cos %[[LOAD_1]] : f32
// CHECK:           %[[MULF_0:.*]] = arith.mulf %[[ARG3]], %[[COS_1]] : f32
// CHECK:           %[[LOAD_2:.*]] = llvm.load %[[IF_0]]#1 : !llvm.ptr -> f32
// CHECK:           %[[ADDF_0:.*]] = arith.addf %[[LOAD_2]], %[[MULF_0]] : f32
// CHECK:           llvm.store %[[ADDF_0]], %[[IF_0]]#1 : f32, !llvm.ptr
// CHECK:           scf.if %[[ARG0]] {
// CHECK:             %[[LOAD_3:.*]] = llvm.load %[[ARG2]] : !llvm.ptr -> f32
// CHECK:             llvm.store %[[CONSTANT_0]], %[[ARG2]] : f32, !llvm.ptr
// CHECK:             %[[SIN_0:.*]] = math.sin %[[IF_0]]#2 : f32
// CHECK:             %[[NEGF_0:.*]] = arith.negf %[[SIN_0]] : f32
// CHECK:             %[[MULF_1:.*]] = arith.mulf %[[LOAD_3]], %[[NEGF_0]] : f32
// CHECK:             %[[LOAD_4:.*]] = llvm.load %[[ARG2]] : !llvm.ptr -> f32
// CHECK:             %[[ADDF_1:.*]] = arith.addf %[[LOAD_4]], %[[MULF_1]] : f32
// CHECK:             llvm.store %[[ADDF_1]], %[[ARG2]] : f32, !llvm.ptr
// CHECK:           } else {
// CHECK:           }
// CHECK:           return
// CHECK:         }
0 commit comments