|
| 1 | +// RUN: %eopt --enzyme %s | FileCheck %s |
| 2 | + |
| 3 | +module { |
| 4 | + func.func private @helper(%x: f64) -> f64 { |
| 5 | + %c = arith.constant 1.5 : f64 |
| 6 | + %r = arith.mulf %x, %c : f64 |
| 7 | + return %r : f64 |
| 8 | + } |
| 9 | + |
| 10 | + func.func private @inner_1arg_4ret(%arg0: f64) -> (f64, f64, f64, f64) { |
| 11 | + %cst2 = arith.constant 2.0 : f64 |
| 12 | + %cst3 = arith.constant 3.0 : f64 |
| 13 | + %cst4 = arith.constant 4.0 : f64 |
| 14 | + %cst5 = arith.constant 5.0 : f64 |
| 15 | + %h1 = func.call @helper(%arg0) : (f64) -> f64 |
| 16 | + %a = arith.mulf %h1, %cst2 : f64 |
| 17 | + %h2 = func.call @helper(%a) : (f64) -> f64 |
| 18 | + %b = arith.mulf %h2, %cst3 : f64 |
| 19 | + %h3 = func.call @helper(%b) : (f64) -> f64 |
| 20 | + %c = arith.addf %h3, %cst4 : f64 |
| 21 | + %h4 = func.call @helper(%c) : (f64) -> f64 |
| 22 | + %d = arith.mulf %h4, %cst5 : f64 |
| 23 | + return %a, %b, %c, %d : f64, f64, f64, f64 |
| 24 | + } |
| 25 | + |
| 26 | + func.func private @helper2(%x: f64, %y: f64) -> (f64, f64) { |
| 27 | + %sum = arith.addf %x, %y : f64 |
| 28 | + %prod = arith.mulf %x, %y : f64 |
| 29 | + return %sum, %prod : f64, f64 |
| 30 | + } |
| 31 | + |
| 32 | + func.func @outer_to_diff(%arg0: f64) -> f64 { |
| 33 | + %prep = func.call @helper(%arg0) : (f64) -> f64 |
| 34 | + %results:4 = func.call @inner_1arg_4ret(%prep) : (f64) -> (f64, f64, f64, f64) |
| 35 | + %h:2 = func.call @helper2(%results#0, %results#1) : (f64, f64) -> (f64, f64) |
| 36 | + %sum1 = arith.addf %h#0, %h#1 : f64 |
| 37 | + %sum2 = arith.addf %sum1, %results#2 : f64 |
| 38 | + %sum3 = arith.addf %sum2, %results#3 : f64 |
| 39 | + return %sum3 : f64 |
| 40 | + } |
| 41 | + |
| 42 | + func.func @test(%arg0: f64, %seed: f64) -> f64 { |
| 43 | + %r:2 = enzyme.autodiff @outer_to_diff(%arg0, %seed) { |
| 44 | + activity = [#enzyme<activity enzyme_active>], |
| 45 | + ret_activity = [#enzyme<activity enzyme_active>] |
| 46 | + } : (f64, f64) -> (f64, f64) |
| 47 | + return %r#1 : f64 |
| 48 | + } |
| 49 | +} |
| 50 | + |
| 51 | +// CHECK: func.func @test |
| 52 | +// CHECK: call @diffeouter_to_diff |
| 53 | +// CHECK: func.func private @diffeouter_to_diff |
| 54 | +// CHECK: func.func private @diffeinner_1arg_4ret({{.+}}: f64, {{.+}}: f64, {{.+}}: f64, {{.+}}: f64, {{.+}}: f64) -> f64 |
0 commit comments