add support for complex tensor types in getConstantAttr#2790
Conversation
c45359c to
8628eee
Compare
getConstantAttrgetConstantAttr
|
format fails, otherwise all is good |
wsmoses
left a comment
There was a problem hiding this comment.
though also what happened to the test?
|
I removed it because I'm gonna submit now a test for |
|
okay, just found out that the diff rule test of funny how |
|
|
I am aware. The problem here is that I've tried with // -----
../test/MLIR/ReverseMode/math.mlir:27:10: error: 'math.sqrt' op operand #0 must be floating-point-like, but got 'tensor<complex<f64>>'
%res = math.sqrt %x : tensor<complex<f64>>
^
../test/MLIR/ReverseMode/math.mlir:27:10: note: see current operation: %0 = "math.sqrt"(%arg0) <{fastmath = #arith.fastmath<none>}> : (tensor<complex<f64>>) -> tensor<complex<f64>>
// -----Not even with It's weird that you can instantiate a complex tensor like this %0 = arith.constant dense<(0.0,0.0)> : tensor<complex<f64>>but you cannot use any I think it's untestable right now: it requires a discussion and fix upstream. But at least it seems to fix the problems in stablehlo as tests in EnzymeAD/Enzyme-JAX#2426 pass now. |
This reverts commit 997cb47.
diff rules that use
HLOConstantFPin Enzyme-JAX are broken for complex numbers, due tomlir::enzyme::getConstantAttrnot being able to handle them.for example, the diff rule of
stablehlo.rsqrtis defined as the following:which works for real numbers. but for complex numbers, like the following example,
it errors:
we haven't checked this because we don't test complex numbers on diff rules that do not have a specific behaviour for them (i.e.
SelectIfComplex).this pr adds support for complex numbers in
getConstantAttr. the current semantics is to create aComplexAttror aDenseElementsAttrwithComplexTypeeltype with the givenvalueas the real part. not sure if we want to write a way to set the imaginary part.