Skip to content

Commit aec862f

Browse files
authored
fix (#2693)
1 parent 08c8877 commit aec862f

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class AutoDiffCallRev
160160
}
161161

162162
std::vector<bool> volatile_args(narg, true);
163-
std::vector<bool> returnShadow(narg, false);
163+
std::vector<bool> returnShadow(nret, false);
164164
std::vector<bool> returnPrimal(nret, false);
165165

166166
auto type_args = gutils->TA.getAnalyzedTypeInfo(fn);
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %eopt --enzyme %s | FileCheck %s
2+
3+
module {
4+
func.func private @helper(%x: f64) -> f64 {
5+
%c = arith.constant 1.5 : f64
6+
%r = arith.mulf %x, %c : f64
7+
return %r : f64
8+
}
9+
10+
func.func private @inner_1arg_4ret(%arg0: f64) -> (f64, f64, f64, f64) {
11+
%cst2 = arith.constant 2.0 : f64
12+
%cst3 = arith.constant 3.0 : f64
13+
%cst4 = arith.constant 4.0 : f64
14+
%cst5 = arith.constant 5.0 : f64
15+
%h1 = func.call @helper(%arg0) : (f64) -> f64
16+
%a = arith.mulf %h1, %cst2 : f64
17+
%h2 = func.call @helper(%a) : (f64) -> f64
18+
%b = arith.mulf %h2, %cst3 : f64
19+
%h3 = func.call @helper(%b) : (f64) -> f64
20+
%c = arith.addf %h3, %cst4 : f64
21+
%h4 = func.call @helper(%c) : (f64) -> f64
22+
%d = arith.mulf %h4, %cst5 : f64
23+
return %a, %b, %c, %d : f64, f64, f64, f64
24+
}
25+
26+
func.func private @helper2(%x: f64, %y: f64) -> (f64, f64) {
27+
%sum = arith.addf %x, %y : f64
28+
%prod = arith.mulf %x, %y : f64
29+
return %sum, %prod : f64, f64
30+
}
31+
32+
func.func @outer_to_diff(%arg0: f64) -> f64 {
33+
%prep = func.call @helper(%arg0) : (f64) -> f64
34+
%results:4 = func.call @inner_1arg_4ret(%prep) : (f64) -> (f64, f64, f64, f64)
35+
%h:2 = func.call @helper2(%results#0, %results#1) : (f64, f64) -> (f64, f64)
36+
%sum1 = arith.addf %h#0, %h#1 : f64
37+
%sum2 = arith.addf %sum1, %results#2 : f64
38+
%sum3 = arith.addf %sum2, %results#3 : f64
39+
return %sum3 : f64
40+
}
41+
42+
func.func @test(%arg0: f64, %seed: f64) -> f64 {
43+
%r:2 = enzyme.autodiff @outer_to_diff(%arg0, %seed) {
44+
activity = [#enzyme<activity enzyme_active>],
45+
ret_activity = [#enzyme<activity enzyme_active>]
46+
} : (f64, f64) -> (f64, f64)
47+
return %r#1 : f64
48+
}
49+
}
50+
51+
// CHECK: func.func @test
52+
// CHECK: call @diffeouter_to_diff
53+
// CHECK: func.func private @diffeouter_to_diff
54+
// CHECK: func.func private @diffeinner_1arg_4ret({{.+}}: f64, {{.+}}: f64, {{.+}}: f64, {{.+}}: f64, {{.+}}: f64) -> f64

0 commit comments

Comments
 (0)