diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index aceebb95117..999f03035d4 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -18,6 +18,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "Passes/Utils.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/Matchers.h" using namespace mlir; @@ -31,18 +32,36 @@ mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, return cast(ATI.createNullAttr()); } if (auto T = dyn_cast(type)) { - auto ET = dyn_cast(T.getElementType()); - if (!ET) { - llvm::errs() << " unsupported eltype: " << ET << " of type " << type - << "\n"; + if (auto ET = dyn_cast(T.getElementType())) { + APFloat values[] = {APFloat(ET.getFloatSemantics(), value)}; + return DenseElementsAttr::get(cast(type), + ArrayRef(values)); + } else if (auto CET = dyn_cast(T.getElementType())) { + auto ET = cast(CET.getElementType()); + mlir::Complex values[] = { + mlir::Complex(APFloat(ET.getFloatSemantics(), value), + APFloat(ET.getFloatSemantics(), "0"))}; + return DenseElementsAttr::get(cast(type), + ArrayRef>(values)); + } else { + llvm::errs() << " unsupported eltype: " << T.getElementType() + << " of type " << type << "\n"; + llvm_unreachable("unsupported eltype"); } - APFloat values[] = {APFloat(ET.getFloatSemantics(), value)}; - return DenseElementsAttr::get(cast(type), - ArrayRef(values)); + } else if (auto T = cast(type)) { + APFloat apvalue(T.getFloatSemantics(), value); + return FloatAttr::get(T, apvalue); + // NOTE `complex::ConstantOp` doesn't accept `TypedAttr`, only `ArrayAttr` + // } else if (auto T = cast(type)) { + // auto F = cast(T.getElementType()); + // return mlir::ArrayAttr::get({ + // FloatAttr::get(F, APFloat(F.getFloatSemantics(), value)), + // FloatAttr::get(F, APFloat(F.getFloatSemantics(), "0")); + // }); + } else { + llvm::errs() << " unsupported type: " << type << "\n"; + llvm_unreachable("unsupported eltype"); } - auto T = cast(type); - APFloat apvalue(T.getFloatSemantics(), value); - return FloatAttr::get(T, apvalue); } void mlir::enzyme::detail::branchingForwardHandler(Operation *inst,