Skip to content

Fix \ for Dual #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 11, 2025
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
style = "blue"
always_use_return = true
align_assignment = true
align_struct_field = true
align_conditional = true
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
version:
- 'lts'
- '1'
- 'pre'
# - 'pre'
group:
- Core
- Benchmarks
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SparseConnectivityTracer.jl

## Version `v0.6.16`
* ![Feature][badge-feature] Add more matrix division methods ([#236])

## Version `v0.6.15`
* ![Feature][badge-feature] Add stable API for tracer type via `jacobian_eltype` and `hessian_eltype` ([#233])

Expand Down Expand Up @@ -109,6 +112,7 @@
[badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg
[badge-docs]: https://img.shields.io/badge/docs-orange.svg

[#236]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/236
[#233]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/233
[#232]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/232
[#231]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/231
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseConnectivityTracer"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "0.6.15"
version = "0.6.16-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
73 changes: 70 additions & 3 deletions src/overloads/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,57 @@ for (Tx, TA, Ty) in Iterators.filter(
end

## Division
function LinearAlgebra.:\(A::AbstractMatrix{T}, B::AbstractMatrix) where {T<:AbstractTracer}
if size(A, 1) != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end
t = second_order_or(A)
return Fill(t, size(A, 2), size(B, 2))
end
function LinearAlgebra.:\(A::AbstractMatrix{T}, B::AbstractVector) where {T<:AbstractTracer}
if size(A, 1) != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end
t = second_order_or(A)
return Fill(t, size(A, 2))
end

function LinearAlgebra.:\(A::AbstractMatrix, B::AbstractMatrix{T}) where {T<:AbstractTracer}
if size(A, 1) != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end
t = second_order_or(B)
return Fill(t, size(A, 2), size(B, 2))
end
function LinearAlgebra.:\(A::AbstractMatrix, B::AbstractVector{T}) where {T<:AbstractTracer}
if size(A, 1) != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end
t = second_order_or(B)
return Fill(t, size(A, 2))
end

function LinearAlgebra.:\(
A::AbstractMatrix{T}, B::AbstractMatrix{T}
) where {T<:AbstractTracer}
if size(A, 1) != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end
tA = second_order_or(A)
tB = second_order_or(B)
t = second_order_or(tA, tB)
return Fill(t, size(A, 2), size(B, 2))
end
function LinearAlgebra.:\(
A::AbstractMatrix{T}, B::AbstractVecOrMat
A::AbstractMatrix{T}, B::AbstractVector{T}
) where {T<:AbstractTracer}
Ainv = LinearAlgebra.pinv(A)
return Ainv * B
if size(A, 1) != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end
tA = second_order_or(A)
tB = second_order_or(B)
t = second_order_or(tA, tB)
return Fill(t, size(A, 2))
end

## Exponential
Expand Down Expand Up @@ -187,6 +233,27 @@ function LinearAlgebra.logabsdet(A::AbstractMatrix{D}) where {D<:Dual}
t1, t2 = LinearAlgebra.logabsdet(tracers)
return (D(p1, t1), D(p2, t2))
end
function LinearAlgebra.:\(A::AbstractMatrix{<:Dual}, B::AbstractVector)
primals, tracers = split_dual_array(A)
p = primals \ B
t = tracers \ B
return Dual.(p, t)
end
function LinearAlgebra.:\(A::AbstractMatrix, B::AbstractVector{D}) where {D<:Dual}
primals, tracers = split_dual_array(B)
p = A \ primals
t = A \ tracers
return Dual.(p, t)
end
function LinearAlgebra.:\(
A::AbstractMatrix{D1}, B::AbstractVector{D2}
) where {D1<:Dual,D2<:Dual}
A_primals, A_tracers = split_dual_array(A)
B_primals, B_tracers = split_dual_array(B)
p = A_primals \ B_primals
t = A_tracers \ B_tracers
return Dual.(p, t)
end

#==============#
# SparseArrays #
Expand Down
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@ SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Aqua = "0.8"
JuliaFormatter = "1"
JET = "0.9"
13 changes: 10 additions & 3 deletions test/linting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using NNlib: NNlib
using SpecialFunctions: SpecialFunctions

@testset "Code formatting" begin
# Using JuliaFormatter v1 (`add JuliaFormatter@1`)
# https://github.com/domluna/JuliaFormatter.jl/issues/909
@info "...with JuliaFormatter.jl"
@test JuliaFormatter.format(SparseConnectivityTracer; verbose=false, overwrite=false)
end
Expand All @@ -27,9 +29,14 @@ end
)
end

@testset "JET tests" begin
@info "...with JET.jl"
JET.test_package(SparseConnectivityTracer; target_defined_modules=true)
if VERSION < v"1.12"
# JET v0.9 is compatible with Julia <1.12
# JET v0.10 is compatible with Julia ≥1.12
# TODO: Update when 1.12 releases
@testset "JET tests" begin
@info "...with JET.jl"
JET.test_package(SparseConnectivityTracer; target_defined_modules=true)
end
end

@testset "ExplicitImports tests" begin
Expand Down
81 changes: 74 additions & 7 deletions test/test_arrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import SparseConnectivityTracer as SCT
using SparseConnectivityTracer
using SparseConnectivityTracer:
GradientTracer, IndexSetGradientPattern, isemptytracer, MissingPrimalError
GradientTracer,
Dual,
IndexSetGradientPattern,
isemptytracer,
MissingPrimalError,
split_dual_array,
tracer
using Test

using LinearAlgebra: Symmetric, Diagonal, diagind
Expand Down Expand Up @@ -394,12 +400,73 @@ end
t2 = idx2tracer([2, 4])
t3 = idx2tracer([8, 9])
t4 = idx2tracer([8, 9])
A = [t1 t2; t3 t4]
s_out = idx2set([1, 2, 3, 4, 8, 9])

x = rand(2)
b = A \ x
@test all(t -> sameidx(t, s_out), b)
A_t = [t1 t2; t3 t4]
A_p = rand(2, 2)

t5 = idx2tracer([6])
t6 = idx2tracer([5, 7])
x_t = [t5; t6]
x_p = rand(2)

set_tp = set_A = idx2set([1, 2, 3, 4, 8, 9])
set_pt = set_x = idx2set([5, 6, 7])
set_tt = union(set_tp, set_pt)

@testset "Global" begin
b_pp = A_p \ x_p
@testset "Tracer-Primal" begin
b_tp = A_t \ x_p
@test size(b_tp) == size(b_pp)
@test all(t -> sameidx(t, set_tp), b_tp)
end
@testset "Primal-Tracer" begin
b_pt = A_p \ x_t
@test size(b_pt) == size(b_pp)
@test all(t -> sameidx(t, set_pt), b_pt)
end
@testset "Tracer-Tracer" begin
b_tt = A_t \ x_t
@test size(b_tt) == size(b_pp)
@test all(t -> sameidx(t, set_tt), b_tt)
end
end
@testset "Local" begin
@testset "$P" for P in (Float32, BigFloat)
# https://github.com/adrhill/SparseConnectivityTracer.jl/issues/235

A_p = rand(P, 2, 2)
A_d = Dual.(A_p, A_t)
@test size(A_d) == size(A_p)

x_p = rand(P, 2)
x_d = Dual.(x_p, x_t)
@test size(x_d) == size(x_p)

b_pp = A_p \ x_p

@testset "Dual-Primal" begin
b_dp = A_d \ x_p
primals_dp, _ = split_dual_array(b_dp)
@test size(b_dp) == size(b_pp)
@test primals_dp == b_pp
@test all(d -> sameidx(tracer(d), set_tp), b_dp)
end
@testset "Primal-Dual" begin
b_pd = A_p \ x_d
primals_pd, _ = split_dual_array(b_pd)
@test size(b_pd) == size(b_pp)
@test primals_pd == b_pp
@test all(d -> sameidx(tracer(d), set_pt), b_pd)
end
@testset "Dual-Dual" begin
b_dd = A_d \ x_d
primals_dd, _ = split_dual_array(b_dd)
@test size(b_dd) == size(b_pp)
@test primals_dd == b_pp
@test all(d -> sameidx(tracer(d), set_tt), b_dd)
end
end
end
end

@testset "Dot" begin
Expand Down
Loading