Skip to content

Commit e2b86ef

Browse files
authored
Add more matrix division methods (#236)
Fixing `\` for `Dual`, `BigFloat`
1 parent dfb4c7c commit e2b86ef

File tree

8 files changed

+166
-15
lines changed

8 files changed

+166
-15
lines changed

.JuliaFormatter.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
style = "blue"
2+
always_use_return = true
23
align_assignment = true
34
align_struct_field = true
45
align_conditional = true

.github/workflows/CI.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
version:
2626
- 'lts'
2727
- '1'
28-
- 'pre'
28+
# - 'pre'
2929
group:
3030
- Core
3131
- Benchmarks

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# SparseConnectivityTracer.jl
22

3+
## Version `v0.6.16`
4+
* ![Feature][badge-feature] Add more matrix division methods ([#236])
5+
36
## Version `v0.6.15`
47
* ![Feature][badge-feature] Add stable API for tracer type via `jacobian_eltype` and `hessian_eltype` ([#233])
58

@@ -109,6 +112,7 @@
109112
[badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg
110113
[badge-docs]: https://img.shields.io/badge/docs-orange.svg
111114

115+
[#236]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/236
112116
[#233]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/233
113117
[#232]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/232
114118
[#231]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/231

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseConnectivityTracer"
22
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
33
authors = ["Adrian Hill <[email protected]>"]
4-
version = "0.6.15"
4+
version = "0.6.16-DEV"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/overloads/arrays.jl

+70-3
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,57 @@ for (Tx, TA, Ty) in Iterators.filter(
115115
end
116116

117117
## Division
118+
function LinearAlgebra.:\(A::AbstractMatrix{T}, B::AbstractMatrix) where {T<:AbstractTracer}
119+
if size(A, 1) != size(B, 1)
120+
throw(DimensionMismatch("arguments must have the same number of rows"))
121+
end
122+
t = second_order_or(A)
123+
return Fill(t, size(A, 2), size(B, 2))
124+
end
125+
function LinearAlgebra.:\(A::AbstractMatrix{T}, B::AbstractVector) where {T<:AbstractTracer}
126+
if size(A, 1) != size(B, 1)
127+
throw(DimensionMismatch("arguments must have the same number of rows"))
128+
end
129+
t = second_order_or(A)
130+
return Fill(t, size(A, 2))
131+
end
132+
133+
function LinearAlgebra.:\(A::AbstractMatrix, B::AbstractMatrix{T}) where {T<:AbstractTracer}
134+
if size(A, 1) != size(B, 1)
135+
throw(DimensionMismatch("arguments must have the same number of rows"))
136+
end
137+
t = second_order_or(B)
138+
return Fill(t, size(A, 2), size(B, 2))
139+
end
140+
function LinearAlgebra.:\(A::AbstractMatrix, B::AbstractVector{T}) where {T<:AbstractTracer}
141+
if size(A, 1) != size(B, 1)
142+
throw(DimensionMismatch("arguments must have the same number of rows"))
143+
end
144+
t = second_order_or(B)
145+
return Fill(t, size(A, 2))
146+
end
147+
148+
function LinearAlgebra.:\(
149+
A::AbstractMatrix{T}, B::AbstractMatrix{T}
150+
) where {T<:AbstractTracer}
151+
if size(A, 1) != size(B, 1)
152+
throw(DimensionMismatch("arguments must have the same number of rows"))
153+
end
154+
tA = second_order_or(A)
155+
tB = second_order_or(B)
156+
t = second_order_or(tA, tB)
157+
return Fill(t, size(A, 2), size(B, 2))
158+
end
118159
function LinearAlgebra.:\(
119-
A::AbstractMatrix{T}, B::AbstractVecOrMat
160+
A::AbstractMatrix{T}, B::AbstractVector{T}
120161
) where {T<:AbstractTracer}
121-
Ainv = LinearAlgebra.pinv(A)
122-
return Ainv * B
162+
if size(A, 1) != size(B, 1)
163+
throw(DimensionMismatch("arguments must have the same number of rows"))
164+
end
165+
tA = second_order_or(A)
166+
tB = second_order_or(B)
167+
t = second_order_or(tA, tB)
168+
return Fill(t, size(A, 2))
123169
end
124170

125171
## Exponential
@@ -187,6 +233,27 @@ function LinearAlgebra.logabsdet(A::AbstractMatrix{D}) where {D<:Dual}
187233
t1, t2 = LinearAlgebra.logabsdet(tracers)
188234
return (D(p1, t1), D(p2, t2))
189235
end
236+
function LinearAlgebra.:\(A::AbstractMatrix{<:Dual}, B::AbstractVector)
237+
primals, tracers = split_dual_array(A)
238+
p = primals \ B
239+
t = tracers \ B
240+
return Dual.(p, t)
241+
end
242+
function LinearAlgebra.:\(A::AbstractMatrix, B::AbstractVector{D}) where {D<:Dual}
243+
primals, tracers = split_dual_array(B)
244+
p = A \ primals
245+
t = A \ tracers
246+
return Dual.(p, t)
247+
end
248+
function LinearAlgebra.:\(
249+
A::AbstractMatrix{D1}, B::AbstractVector{D2}
250+
) where {D1<:Dual,D2<:Dual}
251+
A_primals, A_tracers = split_dual_array(A)
252+
B_primals, B_tracers = split_dual_array(B)
253+
p = A_primals \ B_primals
254+
t = A_tracers \ B_tracers
255+
return Dual.(p, t)
256+
end
190257

191258
#==============#
192259
# SparseArrays #

test/Project.toml

+5
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,8 @@ SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
2727
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2828
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2929
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
30+
31+
[compat]
32+
Aqua = "0.8"
33+
JuliaFormatter = "1"
34+
JET = "0.9"

test/linting.jl

+10-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ using NNlib: NNlib
1313
using SpecialFunctions: SpecialFunctions
1414

1515
@testset "Code formatting" begin
16+
# Using JuliaFormatter v1 (`add JuliaFormatter@1`)
17+
# https://github.com/domluna/JuliaFormatter.jl/issues/909
1618
@info "...with JuliaFormatter.jl"
1719
@test JuliaFormatter.format(SparseConnectivityTracer; verbose=false, overwrite=false)
1820
end
@@ -27,9 +29,14 @@ end
2729
)
2830
end
2931

30-
@testset "JET tests" begin
31-
@info "...with JET.jl"
32-
JET.test_package(SparseConnectivityTracer; target_defined_modules=true)
32+
if VERSION < v"1.12"
33+
# JET v0.9 is compatible with Julia <1.12
34+
# JET v0.10 is compatible with Julia ≥1.12
35+
# TODO: Update when 1.12 releases
36+
@testset "JET tests" begin
37+
@info "...with JET.jl"
38+
JET.test_package(SparseConnectivityTracer; target_defined_modules=true)
39+
end
3340
end
3441

3542
@testset "ExplicitImports tests" begin

test/test_arrays.jl

+74-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import SparseConnectivityTracer as SCT
22
using SparseConnectivityTracer
33
using SparseConnectivityTracer:
4-
GradientTracer, IndexSetGradientPattern, isemptytracer, MissingPrimalError
4+
GradientTracer,
5+
Dual,
6+
IndexSetGradientPattern,
7+
isemptytracer,
8+
MissingPrimalError,
9+
split_dual_array,
10+
tracer
511
using Test
612

713
using LinearAlgebra: Symmetric, Diagonal, diagind
@@ -394,12 +400,73 @@ end
394400
t2 = idx2tracer([2, 4])
395401
t3 = idx2tracer([8, 9])
396402
t4 = idx2tracer([8, 9])
397-
A = [t1 t2; t3 t4]
398-
s_out = idx2set([1, 2, 3, 4, 8, 9])
399-
400-
x = rand(2)
401-
b = A \ x
402-
@test all(t -> sameidx(t, s_out), b)
403+
A_t = [t1 t2; t3 t4]
404+
A_p = rand(2, 2)
405+
406+
t5 = idx2tracer([6])
407+
t6 = idx2tracer([5, 7])
408+
x_t = [t5; t6]
409+
x_p = rand(2)
410+
411+
set_tp = set_A = idx2set([1, 2, 3, 4, 8, 9])
412+
set_pt = set_x = idx2set([5, 6, 7])
413+
set_tt = union(set_tp, set_pt)
414+
415+
@testset "Global" begin
416+
b_pp = A_p \ x_p
417+
@testset "Tracer-Primal" begin
418+
b_tp = A_t \ x_p
419+
@test size(b_tp) == size(b_pp)
420+
@test all(t -> sameidx(t, set_tp), b_tp)
421+
end
422+
@testset "Primal-Tracer" begin
423+
b_pt = A_p \ x_t
424+
@test size(b_pt) == size(b_pp)
425+
@test all(t -> sameidx(t, set_pt), b_pt)
426+
end
427+
@testset "Tracer-Tracer" begin
428+
b_tt = A_t \ x_t
429+
@test size(b_tt) == size(b_pp)
430+
@test all(t -> sameidx(t, set_tt), b_tt)
431+
end
432+
end
433+
@testset "Local" begin
434+
@testset "$P" for P in (Float32, BigFloat)
435+
# https://github.com/adrhill/SparseConnectivityTracer.jl/issues/235
436+
437+
A_p = rand(P, 2, 2)
438+
A_d = Dual.(A_p, A_t)
439+
@test size(A_d) == size(A_p)
440+
441+
x_p = rand(P, 2)
442+
x_d = Dual.(x_p, x_t)
443+
@test size(x_d) == size(x_p)
444+
445+
b_pp = A_p \ x_p
446+
447+
@testset "Dual-Primal" begin
448+
b_dp = A_d \ x_p
449+
primals_dp, _ = split_dual_array(b_dp)
450+
@test size(b_dp) == size(b_pp)
451+
@test primals_dp == b_pp
452+
@test all(d -> sameidx(tracer(d), set_tp), b_dp)
453+
end
454+
@testset "Primal-Dual" begin
455+
b_pd = A_p \ x_d
456+
primals_pd, _ = split_dual_array(b_pd)
457+
@test size(b_pd) == size(b_pp)
458+
@test primals_pd == b_pp
459+
@test all(d -> sameidx(tracer(d), set_pt), b_pd)
460+
end
461+
@testset "Dual-Dual" begin
462+
b_dd = A_d \ x_d
463+
primals_dd, _ = split_dual_array(b_dd)
464+
@test size(b_dd) == size(b_pp)
465+
@test primals_dd == b_pp
466+
@test all(d -> sameidx(tracer(d), set_tt), b_dd)
467+
end
468+
end
469+
end
403470
end
404471

405472
@testset "Dot" begin

0 commit comments

Comments
 (0)