Skip to content

Commit 71e16b2

Browse files
committed
Implement 'LinearAlgebra.diag' and one necessary simplification
1 parent b001e2e commit 71e16b2

4 files changed

Lines changed: 80 additions & 2 deletions

File tree

src/ricci.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,20 @@ function LinearAlgebra.diagm(v::Tensor)
845845
return BinaryOperation{Mult}(id, v)
846846
end
847847

848+
"""
    LinearAlgebra.diag(m::Tensor, k::Integer = 0)

Return the expression `(m .* I) * vector(1)`, used to represent the main
diagonal of `m`.

# Throws
- `DomainError`: if `m` does not have exactly two free indices.
- `DomainError`: if `k` is non-zero; only `k == 0` is supported.
"""
function LinearAlgebra.diag(m::Tensor, k::Integer = 0)
    free = get_free_indices(m)

    # The matrix check comes first so that it takes precedence over the `k` check.
    length(free) == 2 ||
        throw(DomainError(free, "Input is not a matrix, cannot extract the diagonal"))

    k == 0 || throw(DomainError(k, "Cannot extract the k-th diagonal when k!=0"))

    # Element-wise product with the identity, then multiplication by `vector(1)`
    # (presumably an all-ones vector; confirm against the definition of `vector`).
    masked = m .* LinearAlgebra.I
    return masked * vector(1)
end
861+
848862
function script(index::Lower)
849863
@assert index.letter >= 0
850864
text = []

src/simplify.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,45 @@ function simplify(::Mult, arg1::Tensor, arg2::BinaryOperation{Mult})
125125
return simplify(Mult(), arg2, arg1)
126126
end
127127

128+
"""
    can_apply(l::BinaryOperation{Mult}, r::KrD)

Return whether `can_contract` holds between `r` and either `l.arg1` or `l.arg2`.
"""
can_apply(l::BinaryOperation{Mult}, r::KrD) =
    can_contract(l.arg1, r) || can_contract(l.arg2, r)
131+
132+
"""
    can_apply(l::KrD, r::BinaryOperation{Mult})

Return whether `can_contract` holds between `l` and either `r.arg1` or `r.arg2`.
"""
can_apply(l::KrD, r::BinaryOperation{Mult}) =
    can_contract(l, r.arg1) || can_contract(l, r.arg2)
135+
136+
"""
    can_apply(l::BinaryOperation{Mult}, r::BinaryOperation{Mult})

Return whether `can_apply(l, ·)` holds for either operand of `r`. Each operand is
re-dispatched on its own type, so a nested product on the right recurses here.
"""
can_apply(l::BinaryOperation{Mult}, r::BinaryOperation{Mult}) =
    can_apply(l, r.arg1) || can_apply(l, r.arg2)
139+
140+
"""
    can_apply(l::BinaryOperation{Mult}, r::Tensor)

Return whether `can_contract` holds between `r` and either `l.arg1` or `l.arg2`.
"""
can_apply(l::BinaryOperation{Mult}, r::Tensor) =
    can_contract(l.arg1, r) || can_contract(l.arg2, r)
143+
144+
"""
    can_apply(l::Tensor, r::BinaryOperation{Mult})

Return whether `can_contract` holds between `l` and either `r.arg1` or `r.arg2`.
"""
can_apply(l::Tensor, r::BinaryOperation{Mult}) =
    can_contract(l, r.arg1) || can_contract(l, r.arg2)
147+
148+
"""
    can_apply(l::Tensor, r::KrD)

Return `can_contract(l, r)`.
"""
can_apply(l::Tensor, r::KrD) = can_contract(l, r)
151+
152+
"""
    can_apply(l::KrD, r::Tensor)

Return `can_contract(l, r)`.
"""
can_apply(l::KrD, r::Tensor) = can_contract(l, r)
155+
156+
"""
    can_apply(l::KrD, r::KrD)

Return `can_contract(l, r)`.
"""
can_apply(l::KrD, r::KrD) = can_contract(l, r)
159+
160+
"""
    can_apply(l, r)

Fallback method: return `false` for any argument pair not handled by a more
specific `can_apply` method.
"""
can_apply(l, r) = false
163+
128164
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
129165
op = BinaryOperation{Mult}(arg1, arg2)
166+
target_indices = unique(get_free_indices(op))
130167

131168
if is_diagm(arg1) &&
132169
!is_elementwise_multiplication(arg1, arg2) &&
@@ -135,7 +172,6 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
135172

136173
@assert !isnothing(d)
137174

138-
target_indices = eliminate_indices(vcat(get_free_indices(arg1), get_indices(arg2)))
139175
factors = collect_factors(arg1)
140176
vector_factors = filter(f -> f != d, factors)
141177
reshaped = []
@@ -189,6 +225,32 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
189225
return to_binary_operation(Mult(), reshaped)
190226
end
191227

228+
if length(get_free_indices(arg1)) > 2 && length(get_free_indices(arg2)) == 1
229+
if can_apply(arg1.arg1, arg2) &&
230+
can_apply(arg1.arg2, arg2) &&
231+
arg1.arg1 isa KrD &&
232+
arg1.arg2 isa KrD
233+
r_free_indices = get_free_indices(arg2)
234+
235+
new_r = update_index(
236+
arg2,
237+
only(r_free_indices),
238+
first(target_indices);
239+
allow_shape_change = true,
240+
)
241+
242+
eliminated = eliminated_indices([get_free_indices(arg1); r_free_indices])
243+
new_l = update_index(
244+
arg1.arg2,
245+
first(eliminated),
246+
first(target_indices);
247+
allow_shape_change = true,
248+
)
249+
250+
return BinaryOperation{Mult}(new_l, new_r)
251+
end
252+
end
253+
192254
return op
193255
end
194256

test/RicciTest.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using Test
99
using DiffMatic: Variable, Literal, KrD, Zero
1010
using DiffMatic: Upper, Lower
1111

12-
using LinearAlgebra: norm, tr, I, diagm
12+
using LinearAlgebra: norm, tr, I, diagm, diag
1313

1414
dc = DiffMatic
1515

test/StdStrTest.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
@test to_std(gradient(sum(2 * x), x)) == "2vec(1)"
3333
@test to_std(gradient(2 * sum(sin.(x)), x)) == "2cos(x)"
3434
@test to_std(gradient(sum(2 * sin.(x)), x)) == "2cos(x)"
35+
@test to_std(gradient(diag(A)'*x, x)) == "(A ⊙ I)vec(1)" # TODO: Output 'diag(A)'
3536
@test to_std(gradient(2 * sum(cos.(A * x + y)), x)) == "(-2)Aᵀsin(Ax + y)"
3637
@test to_std(gradient(sum(2 * cos.(A * x + y)), x)) == "(-2)Aᵀsin(Ax + y)"
3738
@test to_std(gradient(sum(x .^ 2), x)) == "2x"
@@ -73,6 +74,7 @@ end
7374
@matrix A B C X
7475
@vector x y z
7576

77+
@test to_std(derivative(diag(A)'*x, A)) == "Iᵀdiagm(x)"
7678
@test to_std(derivative(sum(-y .* (X*z)), X)) == "(-1)zyᵀ"
7779
@test to_std(derivative(sum((A .* B) * C * x), x)) == "vec(1)ᵀ(A ⊙ B)C"
7880
end

0 commit comments

Comments
 (0)