From 0e44d625251d4bfe9350513130306f200d142293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 09:52:53 -0500 Subject: [PATCH 01/23] add support for complex and complex tensor types to `getConstantAttr` --- .../CoreDialectsAutoDiffImplementations.cpp | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index aceebb95117..9e9370959f5 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -31,18 +31,27 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return cast(ATI.createNullAttr()); } if (auto T = dyn_cast(type)) { - auto ET = dyn_cast(T.getElementType()); - if (!ET) { - llvm::errs() << " unsupported eltype: " << ET << " of type " << type - << "\n"; + if (auto ET = dyn_cast(T.getElementType())) { + APFloat values[] = {APFloat(ET.getFloatSemantics(), value)}; + return DenseElementsAttr::get(cast(type), + ArrayRef(values)); + } else if (auto ET = dyn_cast(T.getElementType())) { + std::complex values[] = {std::complex(std::stod(value), 0)}; + return DenseElementsAttr::get(cast(type), + ArrayRef>(values)); + } else { + llvm::errs() << " unsupported eltype: " << T.getElementType() + << " of type " << type << "\n"; } - APFloat values[] = {APFloat(ET.getFloatSemantics(), value)}; - return DenseElementsAttr::get(cast(type), - ArrayRef(values)); + } else if (auto T = cast(type)) { + APFloat apvalue(T.getFloatSemantics(), value); + return FloatAttr::get(T, apvalue); + } else if (auto T = cast(type)) { + std::complex cvalue(std::stod(value), 0); + return ComplexAttr::get(T, cvalue); + } else { + llvm::errs() << " unsupported type: " << type << "\n"; } - auto T = cast(type); - APFloat apvalue(T.getFloatSemantics(), value); - return FloatAttr::get(T, apvalue); } void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, From b7e2aad3617cdb890c435de4f2d2cb147fab7543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 10:27:51 -0500 Subject: [PATCH 02/23] use `std::complex` --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 9e9370959f5..1d30596c86e 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -36,9 +36,11 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return DenseElementsAttr::get(cast(type), ArrayRef(values)); } else if (auto ET = dyn_cast(T.getElementType())) { - std::complex values[] = {std::complex(std::stod(value), 0)}; + std::complex values[] = { + std::complex(APFloat(ET.getElementType().getFloatSemantics(), value), + APFloat(ET.getElementType().getFloatSemantics(), "0"))}; return DenseElementsAttr::get(cast(type), - ArrayRef>(values)); + ArrayRef>(values)); } else { llvm::errs() << " unsupported eltype: " << T.getElementType() << " of type " << type << "\n"; From 6e5c7616f1b917cd44adad65b72cd0d89a6d1566 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 10:27:58 -0500 Subject: [PATCH 03/23] format --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 1d30596c86e..f4deddd507f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -42,7 +42,7 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return DenseElementsAttr::get(cast(type), ArrayRef>(values)); } else { - llvm::errs() << " unsupported eltype: " << T.getElementType() + llvm::errs() << " unsupported eltype: " << T.getElementType() << " of type " << type << "\n"; } } else if (auto T = cast(type)) { From 2a5c96bb18b5d27c5dc15aef214e7597fcd56f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 10:29:56 -0500 Subject: [PATCH 04/23] fix --- .../CoreDialectsAutoDiffImplementations.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index f4deddd507f..482a1178263 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -35,10 +35,11 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, APFloat values[] = {APFloat(ET.getFloatSemantics(), value)}; return DenseElementsAttr::get(cast(type), ArrayRef(values)); - } else if (auto ET = dyn_cast(T.getElementType())) { + } else if (auto CET = dyn_cast(T.getElementType())) { + auto ET = cast(CET.getElementType()); std::complex values[] = { - std::complex(APFloat(ET.getElementType().getFloatSemantics(), value), - APFloat(ET.getElementType().getFloatSemantics(), "0"))}; + std::complex(APFloat(ET.getFloatSemantics(), value), + APFloat(ET.getFloatSemantics(), "0"))}; return DenseElementsAttr::get(cast(type), ArrayRef>(values)); } else { From b18fd8bc20f90723f9029cdc6dcc4466d88f35ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 10:31:57 -0500 Subject: [PATCH 05/23] fix scalar case --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 482a1178263..0377c4851d4 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -50,7 +50,9 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, APFloat apvalue(T.getFloatSemantics(), value); return FloatAttr::get(T, apvalue); } else if (auto T = cast(type)) { - std::complex cvalue(std::stod(value), 0); + auto F = cast(T.getElementType()); + std::complex cvalue(APFloat(F.getFloatSemantics(), value), + APFloat(F.getFloatSemantics(), "0")); return ComplexAttr::get(T, cvalue); } else { llvm::errs() << " unsupported type: " << type << "\n"; From 269cf6343880fa1d80ea7583bfcc46ced0eba101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 10:37:51 -0500 Subject: [PATCH 06/23] fix scalar case --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 0377c4851d4..215d7910627 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -51,9 +51,8 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return FloatAttr::get(T, apvalue); } else if (auto T = cast(type)) { auto F = cast(T.getElementType()); - std::complex cvalue(APFloat(F.getFloatSemantics(), value), - APFloat(F.getFloatSemantics(), "0")); - return ComplexAttr::get(T, cvalue); + return complex::NumberAttr::get(T, APFloat(F.getFloatSemantics(), value), + APFloat(F.getFloatSemantics(), "0")); } else { llvm::errs() << " unsupported type: " << type << "\n"; } From 8ec7a8322c0e2e112d0ce866a8dd303847f81409 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 10:43:52 -0500 Subject: [PATCH 07/23] fix ref to complex namespace --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 215d7910627..33048decc8b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -18,6 +18,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "Passes/Utils.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/Matchers.h" using namespace mlir; @@ -51,8 +52,9 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return FloatAttr::get(T, apvalue); } else if (auto T = cast(type)) { auto F = cast(T.getElementType()); - return complex::NumberAttr::get(T, APFloat(F.getFloatSemantics(), value), - APFloat(F.getFloatSemantics(), "0")); + return mlir::complex::NumberAttr::get(T, + APFloat(F.getFloatSemantics(), value), + APFloat(F.getFloatSemantics(), "0")); } else { llvm::errs() << " unsupported type: " << type << "\n"; } From 717a4c3ea9d0a356216e861047747b58878e8fa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Apr 2026 10:46:52 -0500 Subject: [PATCH 08/23] fix --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 33048decc8b..3525214dc4e 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -52,9 +52,9 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return FloatAttr::get(T, apvalue); } else if (auto T = cast(type)) { auto F = cast(T.getElementType()); - return mlir::complex::NumberAttr::get(T, - APFloat(F.getFloatSemantics(), value), - APFloat(F.getFloatSemantics(), "0")); + return mlir::complex::NumberAttr::get( + T, APFloat(F.getFloatSemantics(), value).convertToDouble(), + APFloat(F.getFloatSemantics(), "0").convertToDouble()); } else { llvm::errs() << " unsupported type: " << type << "\n"; } From a50d275d225230966ab799f2d622dfcbc166b4b0 Mon Sep 17 00:00:00 2001 From: Sergio Sanchez Ramirez <15837247+mofeing@users.noreply.github.com> Date: Thu, 30 Apr 2026 20:06:38 +0200 Subject: [PATCH 09/23] try test on `complex::RsqrtOp` diff --- enzyme/Enzyme/MLIR/Implementations/Common.td | 3 +++ .../Implementations/ComplexDerivatives.td | 14 ++++++++++++++ .../test/MLIR/ForwardMode/complex_rsqrt.mlir | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+) create mode 100644 enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 01a53969510..70fdcecedf5 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -183,6 +183,9 @@ def LlvmExpF : LlvmInst<"ExpOp">; def ComplexCreate : ComplexInst<"CreateOp">; def ComplexRe : ComplexInst<"ReOp">; def ComplexIm : ComplexInst<"ImOp">; +def ComplexPow : ComplexInst<"PowOp">; +def ComplexSqrt : ComplexInst<"SqrtOp">; +def ComplexRsqrt : ComplexInst<"RsqrtOp">; def CosF : MathInst<"CosOp">; def SinF : MathInst<"SinOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td index 0e968aa8d32..62a1db20181 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td @@ -45,3 +45,17 @@ def : MLIRDerivative<"complex", "CreateOp", (Op $re, $im), ] >; +def : MLIRDerivative<"complex", "SqrtOp", (Op $x), + [ + (ComplexRsqrt (DiffeRet)) + ] + >; + +def : MLIRDerivative<"complex", "RsqrtOp", (Op $x), + [ + (ComplexPow + (DiffeRet), + (ConstantFP<"-1.5", "complex", "ConstantOp">) + ) + ] + >; diff --git a/enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir b/enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir new file mode 100644 index 00000000000..cb157d02119 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir @@ -0,0 +1,19 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s + +module { + func.func @rsqrt(%x: complex) -> complex { + %next = complex.rsqrt %x : complex + return %next : complex + } + + func.func @drsqrt(%x: complex, %dx: complex) -> complex { + %r = enzyme.fwddiff @rsqrt(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (complex, complex) -> complex + return %r : complex + } +} + +// CHECK: func.func private @fwddiffersqrt(%arg0: complex, %arg1: complex) +// CHECK-NEXT: %0 = complex.constant [-1.5, 0.0] : complex +// CHECK-NEXT: %1 = complex.rsqrt %arg1, %0 : complex +// CHECK-NEXT: return %1 : complex +// CHECK-NEXT: } From a234306499d4cdcfabe816faad6ac8c4d9a4e4cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 6 May 2026 07:52:16 -0500 Subject: [PATCH 10/23] fix `getConstantAttr` for `ComplexType` `mlir::complex::ConstantOp` doesn't accept a `NumberAttr`, but an `ArrayAttr` --- .../CoreDialectsAutoDiffImplementations.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 3525214dc4e..80d7e8bca6f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -52,9 +52,10 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return FloatAttr::get(T, apvalue); } else if (auto T = cast(type)) { auto F = cast(T.getElementType()); - return mlir::complex::NumberAttr::get( - T, APFloat(F.getFloatSemantics(), value).convertToDouble(), - APFloat(F.getFloatSemantics(), "0").convertToDouble()); + return mlir::ArrayAttr::get({ + FloatAttr::get(F, APFloat(F.getFloatSemantics(), value)), + FloatAttr::get(F, APFloat(F.getFloatSemantics(), "0")); + }); } else { llvm::errs() << " unsupported type: " << type << "\n"; } From bf8fdf1e4584665ae197534a40397a150635a0d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 6 May 2026 07:53:07 -0500 Subject: [PATCH 11/23] fix casting on `complex.rsqrt` derivative --- enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td index 62a1db20181..290f4442209 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td @@ -55,7 +55,7 @@ def : MLIRDerivative<"complex", "RsqrtOp", (Op $x), [ (ComplexPow (DiffeRet), - (ConstantFP<"-1.5", "complex", "ConstantOp">) + (ConstantFP<"-1.5", "complex", "ConstantOp", "mlir::ArrayAttr">) ) ] >; From 1810c60caed63299186088f8e668dfc985b08e6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 May 2026 05:37:31 -0500 Subject: [PATCH 12/23] comment `ComplexType` case in `getConstantAttr` --- .../CoreDialectsAutoDiffImplementations.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 80d7e8bca6f..54688b720f2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -50,12 +50,13 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, } else if (auto T = cast(type)) { APFloat apvalue(T.getFloatSemantics(), value); return FloatAttr::get(T, apvalue); - } else if (auto T = cast(type)) { - auto F = cast(T.getElementType()); - return mlir::ArrayAttr::get({ - FloatAttr::get(F, APFloat(F.getFloatSemantics(), value)), - FloatAttr::get(F, APFloat(F.getFloatSemantics(), "0")); - }); + // NOTE `complex::ConstantOp` doesn't accept `TypedAttr`, only `ArrayAttr` + // } else if (auto T = cast(type)) { + // auto F = cast(T.getElementType()); + // return mlir::ArrayAttr::get({ + // FloatAttr::get(F, APFloat(F.getFloatSemantics(), value)), + // FloatAttr::get(F, APFloat(F.getFloatSemantics(), "0")); + // }); } else { llvm::errs() << " unsupported type: " << type << "\n"; } From 971e1a051fb9f87563889ca39ea49a855f1e1e5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 May 2026 05:46:34 -0500 Subject: [PATCH 13/23] Revert "fix casting on `complex.rsqrt` derivative" This reverts commit 1e7a1a1b45e7a9647b42239de87e35fd2494254a. --- enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td index 290f4442209..62a1db20181 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td @@ -55,7 +55,7 @@ def : MLIRDerivative<"complex", "RsqrtOp", (Op $x), [ (ComplexPow (DiffeRet), - (ConstantFP<"-1.5", "complex", "ConstantOp", "mlir::ArrayAttr">) + (ConstantFP<"-1.5", "complex", "ConstantOp">) ) ] >; From 8628eee259719288ff6dfa183609e15f47aea5cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 May 2026 05:46:42 -0500 Subject: [PATCH 14/23] Revert "try test on `complex::RsqrtOp` diff" This reverts commit 75456e983253c212369fb533092b2f14f0fa4279. --- enzyme/Enzyme/MLIR/Implementations/Common.td | 3 --- .../Implementations/ComplexDerivatives.td | 14 -------------- .../test/MLIR/ForwardMode/complex_rsqrt.mlir | 19 ------------------- 3 files changed, 36 deletions(-) delete mode 100644 enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 70fdcecedf5..01a53969510 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -183,9 +183,6 @@ def LlvmExpF : LlvmInst<"ExpOp">; def ComplexCreate : ComplexInst<"CreateOp">; def ComplexRe : ComplexInst<"ReOp">; def ComplexIm : ComplexInst<"ImOp">; -def ComplexPow : ComplexInst<"PowOp">; -def ComplexSqrt : ComplexInst<"SqrtOp">; -def ComplexRsqrt : ComplexInst<"RsqrtOp">; def CosF : MathInst<"CosOp">; def SinF : MathInst<"SinOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td index 62a1db20181..0e968aa8d32 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td @@ -45,17 +45,3 @@ def : MLIRDerivative<"complex", "CreateOp", (Op $re, $im), ] >; -def : MLIRDerivative<"complex", "SqrtOp", (Op $x), - [ - (ComplexRsqrt (DiffeRet)) - ] - >; - -def : MLIRDerivative<"complex", "RsqrtOp", (Op $x), - [ - (ComplexPow - (DiffeRet), - (ConstantFP<"-1.5", "complex", "ConstantOp">) - ) - ] - >; diff --git a/enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir b/enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir deleted file mode 100644 index cb157d02119..00000000000 --- a/enzyme/test/MLIR/ForwardMode/complex_rsqrt.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s - -module { - func.func @rsqrt(%x: complex) -> complex { - %next = complex.rsqrt %x : complex - return %next : complex - } - - func.func @drsqrt(%x: complex, %dx: complex) -> complex { - %r = enzyme.fwddiff @rsqrt(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (complex, complex) -> complex - return %r : complex - } -} - -// CHECK: func.func private @fwddiffersqrt(%arg0: complex, %arg1: complex) -// CHECK-NEXT: %0 = complex.constant [-1.5, 0.0] : complex -// CHECK-NEXT: %1 = complex.rsqrt %arg1, %0 : complex -// CHECK-NEXT: return %1 : complex -// CHECK-NEXT: } From 776d206504dcd577110bbf4ce3b943032bfd38b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 May 2026 07:20:37 -0500 Subject: [PATCH 15/23] format --- .../CoreDialectsAutoDiffImplementations.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 54688b720f2..7b37a8c0528 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -50,13 +50,13 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, } else if (auto T = cast(type)) { APFloat apvalue(T.getFloatSemantics(), value); return FloatAttr::get(T, apvalue); - // NOTE `complex::ConstantOp` doesn't accept `TypedAttr`, only `ArrayAttr` - // } else if (auto T = cast(type)) { - // auto F = cast(T.getElementType()); - // return mlir::ArrayAttr::get({ - // FloatAttr::get(F, APFloat(F.getFloatSemantics(), value)), - // FloatAttr::get(F, APFloat(F.getFloatSemantics(), "0")); - // }); + // NOTE `complex::ConstantOp` doesn't accept `TypedAttr`, only `ArrayAttr` + // } else if (auto T = cast(type)) { + // auto F = cast(T.getElementType()); + // return mlir::ArrayAttr::get({ + // FloatAttr::get(F, APFloat(F.getFloatSemantics(), value)), + // FloatAttr::get(F, APFloat(F.getFloatSemantics(), "0")); + // }); } else { llvm::errs() << " unsupported type: " << type << "\n"; } From bd18ce5547dabb827223c1154005d90057b3b783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 May 2026 07:35:59 -0500 Subject: [PATCH 16/23] fix return type --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 7b37a8c0528..3ce07167c5b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -46,6 +46,7 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, } else { llvm::errs() << " unsupported eltype: " << T.getElementType() << " of type " << type << "\n"; + llvm_unreachable("unsupported eltype"); } } else if (auto T = cast(type)) { APFloat apvalue(T.getFloatSemantics(), value); @@ -59,6 +60,7 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, // }); } else { llvm::errs() << " unsupported type: " << type << "\n"; + llvm_unreachable("unsupported eltype"); } } From 5b9336b14ccda139eb791485382a22703f5ea44b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 May 2026 07:37:08 -0500 Subject: [PATCH 17/23] format --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 3ce07167c5b..85185c3a7c0 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -60,7 +60,7 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, // }); } else { llvm::errs() << " unsupported type: " << type << "\n"; - llvm_unreachable("unsupported eltype"); + llvm_unreachable("unsupported eltype"); } } From 385d16e44b4c82bd47353c8a3b5bdf104abb69c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 May 2026 08:54:20 -0500 Subject: [PATCH 18/23] try fix missing header for `PropertyRef` --- enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp | 1 + enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 728937413c3..baf4bbafa85 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -1,6 +1,7 @@ #include "llvm/ADT/APSInt.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" #include "CloneFunction.h" diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 7dd2613868d..8e268254071 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -6,6 +6,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/IR/OperationSupport.h" // TODO: this shouldn't depend on specific dialects except Enzyme. #include "mlir/Dialect/LLVMIR/LLVMDialect.h" From 0c0157b01b8ebb59ffeab28425803f66c7888bfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 8 May 2026 02:54:36 -0500 Subject: [PATCH 19/23] Revert "try fix missing header for `PropertyRef`" This reverts commit 385d16e44b4c82bd47353c8a3b5bdf104abb69c8. --- enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp | 1 - enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index baf4bbafa85..728937413c3 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -1,7 +1,6 @@ #include "llvm/ADT/APSInt.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OperationSupport.h" #include "CloneFunction.h" diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 8e268254071..7dd2613868d 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -6,7 +6,6 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/IR/OperationSupport.h" // TODO: this shouldn't depend on specific dialects except Enzyme. #include "mlir/Dialect/LLVMIR/LLVMDialect.h" From a6b2d520643afbcd53787328a9b00bff9f46e4f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 9 May 2026 10:23:53 +0200 Subject: [PATCH 20/23] replace std::complex with mlir::Complex --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 85185c3a7c0..0c43094cb2b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -38,11 +38,11 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, ArrayRef(values)); } else if (auto CET = dyn_cast(T.getElementType())) { auto ET = cast(CET.getElementType()); - std::complex values[] = { - std::complex(APFloat(ET.getFloatSemantics(), value), + mlir::Complex values[] = { + mlir::Complex(APFloat(ET.getFloatSemantics(), value), APFloat(ET.getFloatSemantics(), "0"))}; return DenseElementsAttr::get(cast(type), - ArrayRef>(values)); + ArrayRef>(values)); } else { llvm::errs() << " unsupported eltype: " << T.getElementType() << " of type " << type << "\n"; From 7090fdf69f2b9650e6e0fdd1ccdfe2049159200c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 9 May 2026 16:45:45 -0500 Subject: [PATCH 21/23] fmt --- .../Implementations/CoreDialectsAutoDiffImplementations.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 0c43094cb2b..999f03035d4 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -40,7 +40,7 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, auto ET = cast(CET.getElementType()); mlir::Complex values[] = { mlir::Complex(APFloat(ET.getFloatSemantics(), value), - APFloat(ET.getFloatSemantics(), "0"))}; + APFloat(ET.getFloatSemantics(), "0"))}; return DenseElementsAttr::get(cast(type), ArrayRef>(values)); } else { From 997cb47509483149ba38426754c62e5d165758b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 10 May 2026 11:12:54 -0500 Subject: [PATCH 22/23] test autodiff on complex tensor --- enzyme/test/MLIR/ReverseMode/arith.mlir | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/enzyme/test/MLIR/ReverseMode/arith.mlir b/enzyme/test/MLIR/ReverseMode/arith.mlir index 79951db94e2..414e8303005 100644 --- a/enzyme/test/MLIR/ReverseMode/arith.mlir +++ b/enzyme/test/MLIR/ReverseMode/arith.mlir @@ -20,3 +20,26 @@ func.func @dselect(%c: i1, %a: f64, %b: f64, %dr: f64) -> (f64, f64) { // CHECK-NEXT: %[[db:.+]] = arith.select %[[c]], %[[zero]], %[[dr]] : f64 // CHECK-NEXT: return %[[da]], %[[db]] : f64, f64 // CHECK-NEXT: } + +// ----- + +func.func @select_tensor_complex(%c: i1, %a: tensor>, %b: tensor>) -> tensor> { + %res = arith.select %c, %a, %b : tensor> + return %res : tensor> +} + +func.func @dselect_tensor_complex(%c: i1, %a: tensor>, %b: tensor>, %dr: tensor>) -> (tensor>, tensor>) { + %0:2 = enzyme.autodiff @select_tensor_complex(%c, %a, %b, %dr) + { + activity=[#enzyme, #enzyme, #enzyme], + ret_activity=[#enzyme] + } : (i1, tensor>, tensor>, tensor>) -> (tensor>, tensor>) + return %0#0, %0#1 : tensor>, tensor> +} + +// CHECK: func.func private @diffeselect_tensor_complex(%[[c:.+]]: i1, %[[a:.+]]: tensor>, %[[b:.+]]: tensor>, %[[dr:.+]]: tensor>) -> (tensor>, tensor>) { +// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<(0.000000e+00,0.000000e+00)> : tensor> +// CHECK-NEXT: %[[da:.+]] = arith.select %[[c]], %[[dr]], %[[zero]] : tensor> +// CHECK-NEXT: %[[db:.+]] = arith.select %[[c]], %[[zero]], %[[dr]] : tensor> +// CHECK-NEXT: return %[[da]], %[[db]] : tensor>, tensor> +// CHECK-NEXT: } From 95061bf7629c9f469e4e6f319909cf9cd9c86bb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 11 May 2026 15:10:25 -0500 Subject: [PATCH 23/23] Revert "test autodiff on complex tensor" This reverts commit 997cb47509483149ba38426754c62e5d165758b7. --- enzyme/test/MLIR/ReverseMode/arith.mlir | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/enzyme/test/MLIR/ReverseMode/arith.mlir b/enzyme/test/MLIR/ReverseMode/arith.mlir index 414e8303005..79951db94e2 100644 --- a/enzyme/test/MLIR/ReverseMode/arith.mlir +++ b/enzyme/test/MLIR/ReverseMode/arith.mlir @@ -20,26 +20,3 @@ func.func @dselect(%c: i1, %a: f64, %b: f64, %dr: f64) -> (f64, f64) { // CHECK-NEXT: %[[db:.+]] = arith.select %[[c]], %[[zero]], %[[dr]] : f64 // CHECK-NEXT: return %[[da]], %[[db]] : f64, f64 // CHECK-NEXT: } - -// ----- - -func.func @select_tensor_complex(%c: i1, %a: tensor>, %b: tensor>) -> tensor> { - %res = arith.select %c, %a, %b : tensor> - return %res : tensor> -} - -func.func @dselect_tensor_complex(%c: i1, %a: tensor>, %b: tensor>, %dr: tensor>) -> (tensor>, tensor>) { - %0:2 = enzyme.autodiff @select_tensor_complex(%c, %a, %b, %dr) - { - activity=[#enzyme, #enzyme, #enzyme], - ret_activity=[#enzyme] - } : (i1, tensor>, tensor>, tensor>) -> (tensor>, tensor>) - return %0#0, %0#1 : tensor>, tensor> -} - -// CHECK: func.func private @diffeselect_tensor_complex(%[[c:.+]]: i1, %[[a:.+]]: tensor>, %[[b:.+]]: tensor>, %[[dr:.+]]: tensor>) -> (tensor>, tensor>) { -// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<(0.000000e+00,0.000000e+00)> : tensor> -// CHECK-NEXT: %[[da:.+]] = arith.select %[[c]], %[[dr]], %[[zero]] : tensor> -// CHECK-NEXT: %[[db:.+]] = arith.select %[[c]], %[[zero]], %[[dr]] : tensor> -// CHECK-NEXT: return %[[da]], %[[db]] : tensor>, tensor> -// CHECK-NEXT: }