Skip to content

Commit 1fe3ffa

Browse files
Fix default adjoint for AppleAccelerate and MKL
This can't be easily tested because it's very architecture-dependent, but it fixes #601
1 parent 28bcf51 commit 1fe3ffa

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/adjoint.jl

+8
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ specific structure distinct from ``A`` then passing in a `linsolve` will be more
2828
linsolve::L = missing
2929
end
3030

31+
function CRC.rrule(T::typeof(SciMLBase.solve), prob::LinearProblem, alg::Nothing, args...; kwargs...)
32+
@show "here?"
33+
assump = OperatorAssumptions(issquare(prob.A))
34+
alg = defaultalg(prob.A, prob.b, assump)
35+
@show alg
36+
CRC.rrule(T, prob, alg, args...; kwargs...)
37+
end
38+
3139
function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
3240
alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(
3341
alg, prob.A, prob.b), kwargs...)

src/default.jl

+11-3
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,20 @@ end
364364
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
365365
ex = :()
366366
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
367-
newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization,
368-
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
369-
DefaultAlgorithmChoice.RFLUFactorization))
367+
newex = if alg == Symbol(DefaultAlgorithmChoice.RFLUFactorization)
370368
quote
371369
getproperty(cache.cacheval, $(Meta.quot(alg)))[1]' \ dy
372370
end
371+
elseif alg == Symbol(DefaultAlgorithmChoice.MKLLUFactorization)
372+
quote
373+
A = getproperty(cache.cacheval, $(Meta.quot(alg)))[1]
374+
getrs!('T', A.factors, A.ipiv, dy)
375+
end
376+
elseif alg == Symbol(DefaultAlgorithmChoice.AppleAccelerateLUFactorization)
377+
quote
378+
A = getproperty(cache.cacheval, $(Meta.quot(alg)))[1]
379+
aa_getrs!('T', A.factors, A.ipiv, dy)
380+
end
373381
elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
374382
DefaultAlgorithmChoice.QRFactorization,
375383
DefaultAlgorithmChoice.KLUFactorization,

0 commit comments

Comments
 (0)