diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 74e42dfea..6340f285d 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -303,6 +303,24 @@ function SciMLBase.init( ) end +_npartials(::Nothing) = nothing +_npartials(partials::Partials) = length(partials) +_npartials(partials::AbstractArray{<:Partials}) = isempty(partials) ? 0 : length(first(partials)) + +function _check_supported_forwarddiff_partials!(∂_A, ∂_b) + nA = _npartials(∂_A) + nb = _npartials(∂_b) + if nA === 0 || nb === 0 + throw( + ArgumentError( + "LinearSolve does not support ForwardDiff.Dual values with zero partials (N = 0). " * + "Use primal values (non-Dual numbers) or construct Dual numbers with at least one partial." + ) + ) + end + return nothing +end + function __dual_init( prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, args...; @@ -324,6 +342,7 @@ function __dual_init( ∂_A = partial_vals(A) ∂_b = partial_vals(b) + _check_supported_forwarddiff_partials!(∂_A, ∂_b) primal_prob = LinearProblem{SciMLBase.isinplace(prob)}(new_A, new_b; u0 = new_u0) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 7cb82b854..99480b44a 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -343,3 +343,9 @@ backslash_large = A_large_dual \ b_large_dual # Test partials match @test ForwardDiff.partials.(sol_large.u) ≈ ForwardDiff.partials.(backslash_large) + +# Dual numbers with no partials (N=0) are unsupported; ensure a clear error is thrown. +zero_partials_dual = ForwardDiff.Dual{Nothing}(1.0, ForwardDiff.Partials{0, Float64}(())) +prob_zero_partials = LinearProblem([zero_partials_dual], [zero_partials_dual]) + +@test_throws ArgumentError("LinearSolve does not support ForwardDiff.Dual values with zero partials (N = 0). Use primal values (non-Dual numbers) or construct Dual numbers with at least one partial.") init(prob_zero_partials, LUFactorization())