Description
Forward-mode `Enzyme.gradient` through a `@trace` loop fails in the MLIR `DifferentiatePass`. The tangent of the loop-carried accumulator (`tensor<10xf32>`) is expanded to a Jacobian (`tensor<10x10xf32>`), but the `stablehlo.while` condition block still declares the primal shape for that argument, so verification fails with:
```
expect operands to be compatible with condition block arguments but got
'tensor<i64>', 'tensor<10xf32>', 'tensor<10x10xf32>'
vs
'tensor<i64>', 'tensor<10xf32>', 'tensor<10xf32>'
```
```julia
# Reactant CPU forward-mode AD through @trace loop fails in DifferentiatePass
using Reactant
using Reactant: @jit
using ReactantCore: @trace
using Enzyme

Reactant.set_default_backend("cpu")
Reactant.allowscalar(true)

# Extract the compile-time integer from a Val so the loop bound is a plain Int
@inline valof(::Val{N}) where N = N

# Works: no @trace
function loss_no_trace(x)
    return sum(x .* x)
end

# Crashes: @trace loop accumulating over timesteps
function loss_with_trace(x, coeffs, n_val)
    N = valof(n_val)
    acc = zero(x)
    @trace for i in 1:N
        c = coeffs[i]
        acc = acc .+ c .* x
    end
    return sum(acc .* acc)
end

N_elem = 10
N_steps = 4
x_ra = Reactant.to_rarray(ones(Float32, N_elem))
c_ra = Reactant.to_rarray(Float32.(1:N_steps))

# Test 1: Forward-mode without @trace
println("Test 1: Forward-mode, no @trace")
try
    grad = @jit Enzyme.gradient(Forward, loss_no_trace, x_ra)
    println("grad = $(Array(grad[1]))")
catch e
    print(e)
end

# Test 2: Forward-mode with @trace
println("\nTest 2: Forward-mode, with @trace")
try
    grad = @jit Enzyme.gradient(Forward, loss_with_trace, x_ra, Const(c_ra), Const(Val(N_steps)))
    println("grad = $(Array(grad[1]))")
catch e
    print(e)
end
```
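To make the expected shapes concrete, here is a plain-Julia sketch (no Reactant or Enzyme) that propagates the tangents of `loss_with_trace` by hand. It assumes the forward-mode gradient carries one tangent per element of `x`, so the loop carry holds an `n`-element primal `acc` next to an `n×n` tangent `dacc`. `manual_forward_gradient` is a hypothetical helper for illustration only, not part of either library, and it is not a claim about how Enzyme implements the transformation.

```julia
using LinearAlgebra  # for the identity `I`

# Hypothetical helper (not a Reactant/Enzyme API): forward-mode gradient of
# loss_with_trace with respect to x, with tangents propagated by hand.
function manual_forward_gradient(x::Vector{Float32}, coeffs::Vector{Float32})
    n = length(x)
    dx = Matrix{Float32}(I, n, n)   # seed: column j is the one-hot tangent e_j
    acc = zero(x)                   # primal loop carry, shape (n,)
    dacc = zeros(Float32, n, n)     # tangent loop carry, shape (n, n)
    for c in coeffs
        acc = acc .+ c .* x         # primal update
        dacc = dacc .+ c .* dx      # tangent update: d(c * x) = c * dx
    end
    # loss = sum(acc .* acc), so d(loss) along e_j is 2 * dot(acc, dacc[:, j])
    return 2 .* (dacc' * acc)
end

grad = manual_forward_gradient(ones(Float32, 10), Float32.(1:4))
println("grad = $grad")
```

With `x = ones(Float32, 10)` and coefficients `1:4`, `acc` ends at 10 in every entry and `dacc` at `10I`, so each gradient entry is `2 * 10 * 10 = 200`, matching the analytical gradient `200x` of `sum((10x).^2)`. The error message above shows the `stablehlo.while` operands with exactly these two accumulator shapes (`tensor<10xf32>` and `tensor<10x10xf32>`), while the condition block still declares `tensor<10xf32>` for the tangent.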