Skip to content

Commit 9136bdd

Browse files
authored
lmul!/rmul! for banded matrices (JuliaLang#55823)
This adds fast methods for `lmul!` and `rmul!` between banded matrices and numbers. Performance impact: ```julia julia> T = Tridiagonal(rand(999), rand(1000), rand(999)); julia> @Btime rmul!($T, 0.2); 4.686 ms (0 allocations: 0 bytes) # nightly v"1.12.0-DEV.1225" 669.355 ns (0 allocations: 0 bytes) # this PR ```
1 parent 4964c97 commit 9136bdd

File tree

6 files changed

+130
-0
lines changed

6 files changed

+130
-0
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

+26
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,32 @@ end
441441
-(A::Bidiagonal)=Bidiagonal(-A.dv,-A.ev,A.uplo)
442442
*(A::Bidiagonal, B::Number) = Bidiagonal(A.dv*B, A.ev*B, A.uplo)
443443
*(B::Number, A::Bidiagonal) = Bidiagonal(B*A.dv, B*A.ev, A.uplo)
444+
function rmul!(B::Bidiagonal, x::Number)
445+
if size(B,1) > 1
446+
isupper = B.uplo == 'U'
447+
row, col = 1 + isupper, 1 + !isupper
448+
# ensure that zeros are preserved on scaling
449+
y = B[row,col] * x
450+
iszero(y) || throw(ArgumentError(LazyString(lazy"cannot set index ($row, $col) off ",
451+
lazy"the tridiagonal band to a nonzero value ($y)")))
452+
end
453+
@. B.dv *= x
454+
@. B.ev *= x
455+
return B
456+
end
457+
function lmul!(x::Number, B::Bidiagonal)
458+
if size(B,1) > 1
459+
isupper = B.uplo == 'U'
460+
row, col = 1 + isupper, 1 + !isupper
461+
# ensure that zeros are preserved on scaling
462+
y = x * B[row,col]
463+
iszero(y) || throw(ArgumentError(LazyString(lazy"cannot set index ($row, $col) off ",
464+
lazy"the tridiagonal band to a nonzero value ($y)")))
465+
end
466+
@. B.dv = x * B.dv
467+
@. B.ev = x * B.ev
468+
return B
469+
end
444470
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.uplo)
445471
\(B::Number, A::Bidiagonal) = Bidiagonal(B\A.dv, B\A.ev, A.uplo)
446472

stdlib/LinearAlgebra/src/diagonal.jl

+20
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,26 @@ end
274274

275275
(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
276276
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
277+
function lmul!(x::Number, D::Diagonal)
278+
if size(D,1) > 1
279+
# ensure that zeros are preserved on scaling
280+
y = D[2,1] * x
281+
iszero(y) || throw(ArgumentError(LazyString("cannot set index (2, 1) off ",
282+
lazy"the tridiagonal band to a nonzero value ($y)")))
283+
end
284+
@. D.diag = x * D.diag
285+
return D
286+
end
287+
function rmul!(D::Diagonal, x::Number)
288+
if size(D,1) > 1
289+
# ensure that zeros are preserved on scaling
290+
y = x * D[2,1]
291+
iszero(y) || throw(ArgumentError(LazyString("cannot set index (2, 1) off ",
292+
lazy"the tridiagonal band to a nonzero value ($y)")))
293+
end
294+
@. D.diag *= x
295+
return D
296+
end
277297
(/)(D::Diagonal, x::Number) = Diagonal(D.diag / x)
278298
(\)(x::Number, D::Diagonal) = Diagonal(x \ D.diag)
279299
(^)(D::Diagonal, a::Number) = Diagonal(D.diag .^ a)

stdlib/LinearAlgebra/src/tridiag.jl

+47
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,29 @@ end
228228
-(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev)
229229
*(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B)
230230
*(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev)
231+
function rmul!(A::SymTridiagonal, x::Number)
232+
if size(A,1) > 2
233+
# ensure that zeros are preserved on scaling
234+
y = A[3,1] * x
235+
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
236+
lazy"the tridiagonal band to a nonzero value ($y)")))
237+
end
238+
A.dv .*= x
239+
_evview(A) .*= x
240+
return A
241+
end
242+
function lmul!(x::Number, B::SymTridiagonal)
243+
if size(B,1) > 2
244+
# ensure that zeros are preserved on scaling
245+
y = x * B[3,1]
246+
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
247+
lazy"the tridiagonal band to a nonzero value ($y)")))
248+
end
249+
@. B.dv = x * B.dv
250+
ev = _evview(B)
251+
@. ev = x * ev
252+
return B
253+
end
231254
/(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B)
232255
\(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev)
233256
==(A::SymTridiagonal{<:Number}, B::SymTridiagonal{<:Number}) =
@@ -836,6 +859,30 @@ tr(M::Tridiagonal) = sum(M.d)
836859
-(A::Tridiagonal) = Tridiagonal(-A.dl, -A.d, -A.du)
837860
*(A::Tridiagonal, B::Number) = Tridiagonal(A.dl*B, A.d*B, A.du*B)
838861
*(B::Number, A::Tridiagonal) = Tridiagonal(B*A.dl, B*A.d, B*A.du)
862+
function rmul!(T::Tridiagonal, x::Number)
863+
if size(T,1) > 2
864+
# ensure that zeros are preserved on scaling
865+
y = T[3,1] * x
866+
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
867+
lazy"the tridiagonal band to a nonzero value ($y)")))
868+
end
869+
T.dl .*= x
870+
T.d .*= x
871+
T.du .*= x
872+
return T
873+
end
874+
function lmul!(x::Number, T::Tridiagonal)
875+
if size(T,1) > 2
876+
# ensure that zeros are preserved on scaling
877+
y = x * T[3,1]
878+
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
879+
lazy"the tridiagonal band to a nonzero value ($y)")))
880+
end
881+
@. T.dl = x * T.dl
882+
@. T.d = x * T.d
883+
@. T.du = x * T.du
884+
return T
885+
end
839886
/(A::Tridiagonal, B::Number) = Tridiagonal(A.dl/B, A.d/B, A.du/B)
840887
\(B::Number, A::Tridiagonal) = Tridiagonal(B\A.dl, B\A.d, B\A.du)
841888

stdlib/LinearAlgebra/test/bidiag.jl

+13
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,19 @@ end
969969
end
970970
end
971971

972+
@testset "rmul!/lmul! with numbers" begin
973+
for T in (Bidiagonal(rand(4), rand(3), :U), Bidiagonal(rand(4), rand(3), :L))
974+
@test rmul!(copy(T), 0.2) rmul!(Array(T), 0.2)
975+
@test lmul!(0.2, copy(T)) lmul!(0.2, Array(T))
976+
@test_throws ArgumentError rmul!(T, NaN)
977+
@test_throws ArgumentError lmul!(NaN, T)
978+
end
979+
for T in (Bidiagonal(rand(1), rand(0), :U), Bidiagonal(rand(1), rand(0), :L))
980+
@test all(isnan, rmul!(copy(T), NaN))
981+
@test all(isnan, lmul!(NaN, copy(T)))
982+
end
983+
end
984+
972985
@testset "mul with Diagonal" begin
973986
for n in 0:4
974987
dv, ev = rand(n), rand(max(n-1,0))

stdlib/LinearAlgebra/test/diagonal.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,17 @@ end
13451345
end
13461346
end
13471347

1348+
@testset "rmul!/lmul! with numbers" begin
1349+
D = Diagonal(rand(4))
1350+
@test rmul!(copy(D), 0.2) rmul!(Array(D), 0.2)
1351+
@test lmul!(0.2, copy(D)) lmul!(0.2, Array(D))
1352+
@test_throws ArgumentError rmul!(D, NaN)
1353+
@test_throws ArgumentError lmul!(NaN, D)
1354+
D = Diagonal(rand(1))
1355+
@test all(isnan, rmul!(copy(D), NaN))
1356+
@test all(isnan, lmul!(NaN, copy(D)))
1357+
end
1358+
13481359
@testset "+/- with block Symmetric/Hermitian" begin
13491360
for p in ([1 2; 3 4], [1 2+im; 2-im 4+2im])
13501361
m = SizedArrays.SizedArray{(2,2)}(p)

stdlib/LinearAlgebra/test/tridiag.jl

+13
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,19 @@ end
935935
end
936936
end
937937

938+
@testset "rmul!/lmul! with numbers" begin
939+
for T in (SymTridiagonal(rand(4), rand(3)), Tridiagonal(rand(3), rand(4), rand(3)))
940+
@test rmul!(copy(T), 0.2) rmul!(Array(T), 0.2)
941+
@test lmul!(0.2, copy(T)) lmul!(0.2, Array(T))
942+
@test_throws ArgumentError rmul!(T, NaN)
943+
@test_throws ArgumentError lmul!(NaN, T)
944+
end
945+
for T in (SymTridiagonal(rand(2), rand(1)), Tridiagonal(rand(1), rand(2), rand(1)))
946+
@test all(isnan, rmul!(copy(T), NaN))
947+
@test all(isnan, lmul!(NaN, copy(T)))
948+
end
949+
end
950+
938951
@testset "mul with empty arrays" begin
939952
A = zeros(5,0)
940953
T = Tridiagonal(zeros(0), zeros(0), zeros(0))

0 commit comments

Comments
 (0)