Skip to content

Commit 127e998

Browse files
authored
mlir: implement more Complex derivatives (#2691)
* mlir: implement more Complex derivatives * mul per differet
1 parent 87cd10e commit 127e998

File tree

5 files changed

+69
-3
lines changed

5 files changed

+69
-3
lines changed

enzyme/Enzyme/MLIR/Implementations/Common.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,14 @@ def CheckedDivF : SubRoutine<(Op $diffret, $x),
180180
def LlvmCheckedMulF : LlvmInst<"FMulOp">;
181181
def LlvmExpF : LlvmInst<"ExpOp">;
182182

183+
def ComplexCreate : ComplexInst<"CreateOp">;
184+
def ComplexRe : ComplexInst<"ReOp">;
185+
def ComplexIm : ComplexInst<"ImOp">;
186+
183187
def CosF : MathInst<"CosOp">;
184188
def SinF : MathInst<"SinOp">;
185189
def ExpF : MathInst<"ExpOp">;
186190
def SqrtF : MathInst<"SqrtOp">;
191+
def AbsF : MathInst<"AbsFOp">;
187192

188193
#endif // ENZYME_MLIR_IMPLEMENTATIONS_COMMON

enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,33 @@ def : MLIRDerivative<"complex", "MulOp", (Op $x, $y),
1515
(CMul (DiffeRet), $x)
1616
]
1717
>;
18+
19+
def : MLIRDerivative<"complex", "ImOp", (Op $x),
20+
[
21+
(ComplexCreate
22+
(TypeOf $x),
23+
(ConstantFP<"0", "arith", "ConstantOp">),
24+
(NegF (DiffeRet))
25+
)
26+
],
27+
(ComplexIm (Shadow $x))
28+
>;
29+
30+
def : MLIRDerivative<"complex", "ReOp", (Op $x),
31+
[
32+
(ComplexCreate
33+
(TypeOf $x),
34+
(DiffeRet),
35+
(ConstantFP<"0", "arith", "ConstantOp">)
36+
)
37+
],
38+
(ComplexRe (Shadow $x))
39+
>;
40+
41+
def : MLIRDerivative<"complex", "CreateOp", (Op $re, $im),
42+
[
43+
(ComplexRe (DiffeRet)),
44+
(ComplexIm (DiffeRet))
45+
]
46+
>;
47+

enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "Interfaces/GradientUtils.h"
1717
#include "Interfaces/GradientUtilsReverse.h"
1818
#include "mlir/Dialect/Arith/IR/Arith.h"
19+
#include "mlir/Dialect/Complex/IR/Complex.h"
1920
#include "mlir/Dialect/Math/IR/Math.h"
2021
#include "mlir/IR/DialectRegistry.h"
2122
#include "mlir/Support/LogicalResult.h"

enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def : MLIRDerivative<"math", "ExpOp", (Op $x),
1010
(CheckedMulF (DiffeRet), (ExpF $x))
1111
]
1212
>;
13-
def : MLIRDerivative<"math", "SinOp", (Op $x),
13+
def : MLIRDerivative<"math", "SinOp", (Op $x),
1414
[
1515
(CheckedMulF (DiffeRet), (CosF $x))
1616
]
@@ -27,7 +27,13 @@ def : MLIRDerivative<"math", "AtanOp", (Op $x),
2727
>;
2828
def : MLIRDerivative<"math", "AbsFOp", (Op $x),
2929
[
30-
// TODO: handle complex
31-
(Arith_Select (CmpF (Arith_OGE), $x, (ConstantFP<"0","arith","ConstantOp"> $x)), (DiffeRet), (NegF (DiffeRet)))
30+
(SelectIfComplex $x,
31+
(ComplexCreate
32+
(TypeOf $x),
33+
(CheckedMulF (DiffeRet), (DivF (ComplexRe $x), (AbsF $x))),
34+
(CheckedMulF (DiffeRet), (NegF (DivF (ComplexIm $x), (AbsF $x))))
35+
),
36+
(Arith_Select (CmpF (Arith_OGE), $x, (ConstantFP<"0","arith","ConstantOp"> $x)), (DiffeRet), (NegF (DiffeRet)))
37+
)
3238
]
3339
>;
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %eopt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active retTys=enzyme_active,enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops | FileCheck %s
2+
3+
module {
4+
func.func @main(%x: complex<f32>) -> (f32, f32) {
5+
%0 = complex.re %x : complex<f32>
6+
%1 = complex.im %x : complex<f32>
7+
return %0, %1 : f32, f32
8+
}
9+
}
10+
11+
// CHECK: func.func @main(%arg0: complex<f32>, %arg1: f32, %arg2: f32) -> complex<f32> {
12+
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
13+
// CHECK-NEXT: %cst_0 = complex.constant [0.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
14+
// CHECK-NEXT: %0 = arith.addf %arg1, %cst : f32
15+
// CHECK-NEXT: %1 = arith.addf %arg2, %cst : f32
16+
// CHECK-NEXT: %2 = arith.negf %1 : f32
17+
// CHECK-NEXT: %3 = complex.create %cst, %2 : complex<f32>
18+
// CHECK-NEXT: %4 = complex.conj %3 : complex<f32>
19+
// CHECK-NEXT: %5 = complex.add %cst_0, %4 : complex<f32>
20+
// CHECK-NEXT: %6 = complex.create %0, %cst : complex<f32>
21+
// CHECK-NEXT: %7 = complex.conj %6 : complex<f32>
22+
// CHECK-NEXT: %8 = complex.add %5, %7 : complex<f32>
23+
// CHECK-NEXT: return %8 : complex<f32>
24+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)