Skip to content

Commit 877fe01

Browse files
committed
Fix simplification of some multiplications inside unary operations
1 parent a19a08a commit 877fe01

3 files changed

Lines changed: 40 additions & 26 deletions

File tree

src/forward.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,17 @@ function evaluate(::Mult, arg1::KrD, arg2::UnaryOp) where {UnaryOp<:UnaryOperati
340340
return BinaryOperation{Mult}(evaluate(arg2), evaluate(arg1))
341341
end
342342

343+
function evaluate(::Mult, arg1::Power, arg2::KrD)
344+
attempt = Power(evaluate(Mult(), evaluate(arg1.base), evaluate(arg2)), arg1.exponent)
345+
346+
# This ensures that arg1.base and arg2 can contract and that the contraction is simple
347+
if length(get_free_indices(attempt)) == length(get_free_indices(arg1))
348+
return attempt
349+
end
350+
351+
return BinaryOperation{Mult}(evaluate(arg1), evaluate(arg2))
352+
end
353+
343354
function evaluate(::Mult, arg1::Variable, arg2::KrD)
344355
return _multiply_with_krd(arg1, arg2)
345356
end

src/simplify.jl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,6 @@ function get_diag_delta(arg)
5050
return nothing
5151
end
5252

53-
function reshape(term::Variable, indices::LowerOrUpperIndex...)
54-
return Variable(term.id, indices...)
55-
end
56-
57-
function reshape(term::UnaryOperation{Op}, indices::LowerOrUpperIndex...) where {Op}
58-
return UnaryOperation{Op}(reshape(term.arg, indices...))
59-
end
60-
61-
function reshape(
62-
term::BinaryOperation{Op},
63-
indices::LowerOrUpperIndex...,
64-
) where {Op<:AdditiveOperation}
65-
return BinaryOperation{Op}(
66-
reshape(term.arg1, indices...),
67-
reshape(term.arg2, indices...),
68-
)
69-
end
70-
71-
function reshape(arg::Power, indices::LowerOrUpperIndex...)
72-
return Power(reshape(arg.base, indices...), arg.exponent)
73-
end
74-
7553
function get_last_letter(indices::IndexList)
7654
current_last = Upper(0)
7755

@@ -110,7 +88,15 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
11088
target_indices =
11189
eliminate_indices(vcat(get_free_indices(arg1), get_indices(arg2)))
11290
@assert length(target_indices) == 1
113-
push!(reshaped, reshape(f, target_indices...))
91+
92+
current_idx = intersect(free_ids, get_free_indices(d))
93+
f = update_index(
94+
f,
95+
only(current_idx),
96+
only(target_indices);
97+
allow_shape_change = true,
98+
)
99+
push!(reshaped, f)
114100
end
115101
end
116102

@@ -159,14 +145,30 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Variable)
159145
push!(reshaped, f)
160146
elseif length(free_ids) == 1
161147
@assert length(target_indices) == 1
162-
push!(reshaped, reshape(f, target_indices...))
148+
149+
current_idx = intersect(free_ids, get_free_indices(d))
150+
f = update_index(
151+
f,
152+
only(current_idx),
153+
only(target_indices);
154+
allow_shape_change = true,
155+
)
156+
push!(reshaped, f)
163157
else
164158
@assert false "Not implemented, please open an issue with your input"
165159
end
166160
end
167161

168-
if length(get_free_indices(arg2)) == 1
169-
push!(reshaped, reshape(arg2, target_indices...))
162+
arg2_ids = get_free_indices(arg2)
163+
164+
if length(arg2_ids) == 1
165+
arg2 = update_index(
166+
arg2,
167+
only(arg2_ids),
168+
only(target_indices);
169+
allow_shape_change = true,
170+
)
171+
push!(reshaped, arg2)
170172
else
171173
@assert false "Not implemented, please open an issue with your input"
172174
end

test/StdStrTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
@test to_std(gradient(sin.(x)' * y * a, x)) == "a(cos(x) ⊙ y)"
2222
@test to_std(gradient(x' * sin.(y) * a, x)) == "asin(y)"
2323
@test to_std(gradient(y' * sin.(x) * a, x)) == "a(cos(x) ⊙ y)"
24+
@test to_std(gradient(sin.(x .* y)' * x, x)) == "(cos(x ⊙ y) ⊙ y ⊙ x) + sin(x ⊙ y)"
2425
@test to_std(gradient(sum(x), x)) == "vec(1)"
2526
@test to_std(gradient(2 * sum(x), x)) == "2vec(1)"
2627
@test to_std(gradient(sum(2 * x), x)) == "2vec(1)"

0 commit comments

Comments
 (0)