Forward-mode AD through @trace loop fails with shape mismatch in stablehlo.while #2361

@KookiesNKareem

Description

Forward-mode `Enzyme.gradient` through a `@trace` loop fails in the MLIR DifferentiatePass. The tangent of a loop-carried accumulator (`tensor<10xf32>`) is expanded to a Jacobian (`tensor<10x10xf32>`), but the `stablehlo.while` condition block still expects the primal shape:

```
expect operands to be compatible with condition block arguments but got
'tensor<i64>', 'tensor<10xf32>', 'tensor<10x10xf32>'
vs
'tensor<i64>', 'tensor<10xf32>', 'tensor<10xf32>'
```
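To see where the `tensor<10x10xf32>` comes from: forward mode over a length-10 input needs one tangent direction per input element, so when the tangents are batched the tangent of the loop carry has an extra leading dimension that the primal-shaped `stablehlo.while` signature does not account for. A minimal NumPy sketch of that bookkeeping (an illustrative analogue only, not Reactant/Enzyme code; all names here are made up for the example):

```python
import numpy as np

# Shapes from the issue: a length-10 input and 4 loop steps.
N_elem, N_steps = 10, 4
x = np.ones(N_elem, dtype=np.float32)
coeffs = np.arange(1, N_steps + 1, dtype=np.float32)

# Forward mode over a vector input seeds one tangent per input element,
# i.e. the identity matrix, so every tangent carry is batched.
acc = np.zeros(N_elem, dtype=np.float32)              # primal carry: (10,)
d_acc = np.zeros((N_elem, N_elem), dtype=np.float32)  # tangent carry: (10, 10)
d_x = np.eye(N_elem, dtype=np.float32)                # seed: d(x_j)/d(x_k)

for i in range(N_steps):
    c = coeffs[i]
    acc = acc + c * x          # primal update keeps shape (10,)
    d_acc = d_acc + c * d_x    # tangent update keeps shape (10, 10)

# loss = sum(acc .* acc), so d(loss)/dx = 2 * acc @ d_acc
grad = 2.0 * acc @ d_acc
```

With these inputs the body runs four times, so `acc` ends at `10 .* x` and every entry of `grad` is `200`. The loop carries `(acc, d_acc)` with shapes `(10,)` and `(10, 10)`; the reported verifier error is the `while` condition block still being typed for two `(10,)` carries.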
```julia
# Reactant CPU forward-mode AD through @trace loop fails in DifferentiatePass

using Reactant
using Reactant: @jit
using ReactantCore: @trace
using Enzyme

Reactant.set_default_backend("cpu")
Reactant.allowscalar(true)

@inline valof(::Val{N}) where N = N

# Works: no @trace
function loss_no_trace(x)
    return sum(x .* x)
end

# Crashes: @trace loop accumulating over timesteps
function loss_with_trace(x, coeffs, n_val)
    N = valof(n_val)
    acc = zero(x)
    @trace for i in 1:N
        c = coeffs[i]
        acc = acc .+ c .* x
    end
    return sum(acc .* acc)
end

N_elem = 10
N_steps = 4
x_ra = Reactant.to_rarray(ones(Float32, N_elem))
c_ra = Reactant.to_rarray(Float32.(1:N_steps))

# Test 1: Forward-mode without @trace
println("Test 1: Forward-mode, no @trace")
try
    grad = @jit Enzyme.gradient(Forward, loss_no_trace, x_ra)
    println("grad = $(Array(grad[1]))")
catch e
    print(e)
end

# Test 2: Forward-mode with @trace
println("\nTest 2: Forward-mode, with @trace")
try
    grad = @jit Enzyme.gradient(Forward, loss_with_trace, x_ra, Const(c_ra), Const(Val(N_steps)))
    println("grad = $(Array(grad[1]))")
catch e
    print(e)
end
```
