@@ -10,29 +10,63 @@ func.func @foo(%x : f64) -> f64{
1010 return %out : f64
1111}
1212
13-
1413func.func @dfoo (%x: f64 , %dout : f64 ) -> f64 {
1514 %dx = enzyme.autodiff @foo (%x , %dout ) {
16- activity = [#enzyme <activity enzyme_active >],
17- ret_activity = [#enzyme <acitivity enzyme_activenoneed >]
18- } : (f64 , f64 ) -> (f64 )
19- return %dx : f64
20- }
21-
22- func.func @foo2 (%x : f64 ) -> f64 {
23- %cst = arith.constant 0.0000e+00 : f64
24- %buf = memref.alloca () : memref <f64 >
25- memref.store %x , %buf [] : memref <f64 >
26- %y = memref.load %buf [] : memref <f64 >
27- %out = arith.addf %cst , %y : f64
28- return %out : f64
29- }
30-
31-
32- func.func @dfoo2 (%x: f64 , %dout : f64 ) -> f64 {
33- %dx = enzyme.autodiff @foo2 (%x , %dout ) {
3415 activity = [#enzyme <activity enzyme_active >],
3516 ret_activity = [#enzyme <activity enzyme_activenoneed >]
3617 } : (f64 , f64 ) -> (f64 )
3718 return %dx : f64
3819}
20+
21+ // CHECK-LABEL: func.func private @diffefoo(
22+ // CHECK-SAME: %[[X:[^,]+]]: f64,
23+ // CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 {
24+ // CHECK: %[[C0:.*]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
25+ // CHECK: %[[C1:.*]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
26+ // CHECK: %[[G0:.*]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
27+ // CHECK: "enzyme.set"(%[[G0]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
28+ // CHECK: %[[G1:.*]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
29+ // CHECK: "enzyme.set"(%[[G1]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
30+ // CHECK: %[[G2:.*]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
31+ // CHECK: "enzyme.set"(%[[G2]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
32+ // CHECK: cf.br ^bb1
33+ // CHECK: ^bb1:
34+ // CHECK: %[[T5:.*]] = "enzyme.get"(%[[G2]]) : (!enzyme.Gradient<f64>) -> f64
35+ // CHECK: %[[T6:.*]] = arith.addf %[[T5]], %[[DOUT]] : f64
36+ // CHECK: "enzyme.set"(%[[G2]], %[[T6]]) : (!enzyme.Gradient<f64>, f64) -> ()
37+ // CHECK: %[[T7:.*]] = "enzyme.get"(%[[G2]]) : (!enzyme.Gradient<f64>) -> f64
38+ // CHECK: "enzyme.set"(%[[G2]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
39+ // CHECK: %[[T8:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
40+ // CHECK: "enzyme.set"(%[[G1]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
41+ // CHECK: %[[SCOPE:.*]]:2 = memref.alloca_scope -> (f64, f64) {
42+ // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<f64>
43+ // CHECK: memref.store %{{.*}}, %[[ALLOCA]][] : memref<f64>
44+ // CHECK: %[[ALLOCA5:.*]] = memref.alloca() : memref<f64>
45+ // CHECK: "enzyme.push"(%[[C0]], %[[ALLOCA]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
46+ // CHECK: memref.store %[[X]], %[[ALLOCA5]][] : memref<f64>
47+ // CHECK: "enzyme.push"(%[[C1]], %[[ALLOCA]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
48+ // CHECK: %[[LD:.*]] = memref.load %[[ALLOCA5]][] : memref<f64>
49+ // CHECK: "enzyme.set"(%[[G0]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
50+ // CHECK: %[[T14:.*]] = "enzyme.get"(%[[G0]]) : (!enzyme.Gradient<f64>) -> f64
51+ // CHECK: %[[T15:.*]] = arith.addf %[[T14]], %[[T7]] : f64
52+ // CHECK: "enzyme.set"(%[[G0]], %[[T15]]) : (!enzyme.Gradient<f64>, f64) -> ()
53+ // CHECK: %[[T16:.*]] = "enzyme.get"(%[[G0]]) : (!enzyme.Gradient<f64>) -> f64
54+ // CHECK: %[[POP1:.*]] = "enzyme.pop"(%[[C1]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
55+ // CHECK: %[[LD18:.*]] = memref.load %[[POP1]][] : memref<f64>
56+ // CHECK: %[[T19:.*]] = arith.addf %[[LD18]], %[[T16]] : f64
57+ // CHECK: memref.store %[[T19]], %[[POP1]][] : memref<f64>
58+ // CHECK: %[[POP0:.*]] = "enzyme.pop"(%[[C0]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
59+ // CHECK: %[[LD21:.*]] = memref.load %[[POP0]][] : memref<f64>
60+ // CHECK: %[[T22:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
61+ // CHECK: %[[T23:.*]] = arith.addf %[[T22]], %[[LD21]] : f64
62+ // CHECK: "enzyme.set"(%[[G1]], %[[T23]]) : (!enzyme.Gradient<f64>, f64) -> ()
63+ // CHECK: memref.store %{{.*}}, %[[POP0]][] : memref<f64>
64+ // CHECK: %[[T24:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
65+ // CHECK: memref.alloca_scope.return %[[LD]], %[[T24]] : f64, f64
66+ // CHECK: }
67+ // CHECK: "enzyme.set"(%[[G1]], %[[T8]]) : (!enzyme.Gradient<f64>, f64) -> ()
68+ // CHECK: %[[T10:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
69+ // CHECK: %[[T11:.*]] = arith.addf %[[T10]], %[[SCOPE]]#1 : f64
70+ // CHECK: "enzyme.set"(%[[G1]], %[[T11]]) : (!enzyme.Gradient<f64>, f64) -> ()
71+ // CHECK: %[[T12:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
72+ // CHECK: return %[[T12]] : f64
0 commit comments