Skip to content

Commit 9da0101

Browse files
committed
Overload LinearAlgebra.diagm for vectors
1 parent facf6d5 commit 9da0101

4 files changed

Lines changed: 67 additions & 12 deletions

File tree

src/ricci.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,22 @@ function Base.adjoint(arg::Union{Variable,Literal,KrD,Zero})
781781
return evaluate(e)
782782
end
783783

784+
function LinearAlgebra.diagm(v::Tensor)
785+
indices = get_free_indices(v)
786+
787+
if length(indices) != 1
788+
throw(
789+
DomainError(indices, "Input is not a vector, cannot create a diagonal matrix"),
790+
)
791+
end
792+
793+
vector_index = only(indices)
794+
next_letter = get_next_letter(v)
795+
id = KrD(vector_index, flip_to(vector_index, next_letter))
796+
797+
return BinaryOperation{Mult}(id, v)
798+
end
799+
784800
function script(index::Lower)
785801
@assert index.letter >= 0
786802
text = []

src/simplify.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,33 +135,46 @@ function simplify(::Mult, arg1::Tensor, arg2::BinaryOperation{Mult})
135135
end
136136

137137
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
138-
if is_diag(arg1) && !is_elementwise_multiplication(arg1, arg2)
138+
op = BinaryOperation{Mult}(arg1, arg2)
139+
140+
if is_diag(arg1) &&
141+
!is_elementwise_multiplication(arg1, arg2) &&
142+
length(get_free_indices(op)) == 1
139143
d = get_diag_delta(arg1)
140144

141145
@assert !isnothing(d)
142146

143147
target_indices = eliminate_indices(vcat(get_free_indices(arg1), get_indices(arg2)))
144148
factors = collect_factors(arg1)
149+
vector_factors = filter(f -> f != d, factors)
145150
reshaped = []
146151

147152
for f factors
148-
if f isa KrD
153+
if isequal(f, d)
149154
continue
150155
end
151156

152157
free_ids = get_free_indices(f)
158+
153159
if isempty(free_ids)
154160
push!(reshaped, f)
155-
elseif length(free_ids) == 1
161+
elseif length(free_ids) == 1 || length(free_ids) == 2
156162
@assert length(target_indices) == 1
157163

158-
current_idx = intersect(free_ids, get_free_indices(d))
159-
f = update_index(
160-
f,
161-
only(current_idx),
162-
only(target_indices);
163-
allow_shape_change = true,
164-
)
164+
vector_index =
165+
only(get_free_indices(to_binary_operation(Mult(), vector_factors)))
166+
167+
current_idx = intersect(free_ids, [vector_index])
168+
169+
if !isempty(current_idx)
170+
f = update_index(
171+
f,
172+
vector_index,
173+
only(target_indices);
174+
allow_shape_change = true,
175+
)
176+
end
177+
165178
push!(reshaped, f)
166179
else
167180
@assert false "Not implemented, please open an issue with your input"
@@ -185,7 +198,7 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
185198
return to_binary_operation(Mult(), reshaped)
186199
end
187200

188-
return BinaryOperation{Mult}(arg1, arg2)
201+
return op
189202
end
190203

191204
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::BinaryOperation{Mult})

test/RicciTest.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using Test
99
using DiffMatic: Variable, Literal, KrD, Zero
1010
using DiffMatic: Upper, Lower
1111

12-
using LinearAlgebra: norm, tr, I
12+
using LinearAlgebra: norm, tr, I, diagm
1313

1414
dc = DiffMatic
1515

@@ -48,6 +48,14 @@ end
4848
@test matrix(4.2) == Literal(4.2, Upper(1), Lower(2))
4949
end
5050

51+
@testset "Diagonal matrix constructor" begin
52+
x = Variable("x", Upper(1))
53+
xt = Variable("x", Lower(1))
54+
55+
@test diagm(x) == dc.BinaryOperation{dc.Mult}(KrD(Upper(1), Lower(2)), x)
56+
@test diagm(xt) == dc.BinaryOperation{dc.Mult}(KrD(Lower(1), Upper(2)), xt)
57+
end
58+
5159
@testset "index equality operator" begin
5260
left = Lower(3)
5361
right = Lower(3)

test/StdTest.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,24 @@ end
240240
@test to_std(mul(y, A)) == "Adiag(y)"
241241
end
242242

243+
@testset "to_std output is correct with diagonal matrix-matrix multiplication" begin
244+
A = Variable("A", Upper(1), Lower(2))
245+
x = Variable("x", Upper(3))
246+
y = Variable("y", Upper(2))
247+
248+
d = KrD(Upper(1), Lower(3))
249+
250+
function mul(l, r)
251+
return dc.BinaryOperation{dc.Mult}(l, r)
252+
end
253+
254+
@test to_std(mul(mul(A, y), KrD(Upper(1), Lower(3)))) == "diag(Ay)I"
255+
@test to_std(mul(KrD(Upper(1), Lower(3)), mul(A, y))) == "diag(Ay)I"
256+
@test to_std(mul(mul(mul(A, y), KrD(Upper(1), Lower(3))), x)) == "diag(x)Ay"
257+
@test to_std(mul(x, mul(KrD(Upper(1), Lower(3)), mul(A, y)))) == "diag(x)Ay"
258+
@test to_std(mul(x, mul(mul(A, y), KrD(Upper(1), Lower(3))))) == "diag(x)Ay"
259+
end
260+
243261
@testset "to_std output is correct with vector-vector element wise multiplication" begin
244262
x = Variable("x", Upper(1))
245263
y = Variable("y", Upper(1))

0 commit comments

Comments
 (0)