Skip to content

Commit 0431cf9

Browse files
committed
Add straightforward simplifications for products involving quotients
1 parent 48572ae commit 0431cf9

3 files changed

Lines changed: 31 additions & 4 deletions

File tree

src/forward.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,26 @@ function evaluate(::Mult, arg1::BinaryOperation{Div}, arg2::KrD)
308308
return BinaryOperation{Mult}(arg1, arg2)
309309
end
310310

311+
function evaluate(::Mult, arg1::BinaryOperation{Div}, arg2::Tensor)
312+
if arg1.arg2 == arg2
313+
return simplify(arg1.arg1)
314+
end
315+
316+
if arg1.arg1 isa Literal
317+
if arg1.arg1.value == 1 && get_free_indices(arg1.arg1) == get_free_indices(arg2)
318+
return BinaryOperation{Div}(arg2, arg1.arg2)
319+
end
320+
end
321+
322+
return BinaryOperation{Mult}(arg1, arg2)
323+
end
324+
325+
function evaluate(::Mult, arg1::BinaryOperation{Div}, arg2::Zero)
326+
free_indices = unique(eliminate_indices([get_indices(arg1); get_indices(arg2)]))
327+
328+
return Zero(free_indices...)
329+
end
330+
311331
function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
312332
ci = indices_in_common(arg1.arg1, arg1.arg2)
313333

src/simplify.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Literal)
111111
return BinaryOperation{Mult}(arg1, arg2)
112112
end
113113

114+
function simplify(::Mult, arg1::Tensor, arg2::BinaryOperation{Mult})
115+
return simplify(Mult(), arg2, arg1)
116+
end
117+
114118
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
115119
if is_diag(arg1) && !is_elementwise_multiplication(arg1, arg2)
116120
d = get_diag_delta(arg1)
@@ -165,6 +169,10 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
165169
return BinaryOperation{Mult}(arg1, arg2)
166170
end
167171

172+
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::BinaryOperation{Mult})
173+
return BinaryOperation{Mult}(arg1, arg2)
174+
end
175+
168176
function simplify(::Div, arg1::Value, arg2::Value)
169177
return evaluate(
170178
BinaryOperation{Div}(simplify(evaluate(arg1)), simplify(evaluate(arg2))),

test/StdStrTest.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
@test to_std(gradient(sin.(x)' * y * a, x)) == "a(cos(x) ⊙ y)"
2222
@test to_std(gradient(x' * sin.(y) * a, x)) == "asin(y)"
2323
@test to_std(gradient(y' * sin.(x) * a, x)) == "a(cos(x) ⊙ y)"
24-
@test to_std(gradient(sin.(x .* y)' * x, x)) == "(cos(x ⊙ y) ⊙ y ⊙ x) + sin(x ⊙ y)"
24+
@test to_std(gradient(sin.(x .* y)' * x, x)) == "(y ⊙ cos(x ⊙ y) ⊙ x) + sin(x ⊙ y)"
2525
@test to_std(gradient(sum(x), x)) == "vec(1)"
2626
@test to_std(gradient(2 * sum(x), x)) == "2vec(1)"
2727
@test to_std(gradient(sum(2 * x), x)) == "2vec(1)"
@@ -36,9 +36,8 @@
3636
@test to_std(gradient(sum((x + y) .^ 2), x)) == "2(x + y)"
3737
@test to_std(gradient(sum((x .* y) .^ 2), x)) == "2(x ⊙ y ⊙ y)"
3838
@test to_std(gradient(sum((A * x - y) .^ 2), x)) == "2Aᵀ(Ax - y)"
39-
@test to_std(gradient(log.(x)'*x, x)) == "(vec(1) ⊘ x ⊙ x) + log(x)" # TODO: Add simplification rule for quotients
40-
@test to_std(gradient(log.(x)'*log.(x), x)) ==
41-
"diag(vec(1)ᵀ ⊘ xᵀ)Iᵀlog(x) + (vec(1) ⊘ x ⊙ log(x))"
39+
@test to_std(gradient(log.(x)'*x, x)) == "vec(1) + log(x)"
40+
@test to_std(gradient(log.(x)'*log.(x), x)) == "2(log(x) ⊘ x)"
4241
@test to_std(gradient((x' * A * x) ^ (-2), x)) == "(-2)(xᵀAᵀx)^(-3)(Aᵀx + Ax)"
4342
@test to_std(gradient((x' * A * x) ^ 2, x)) == "2xᵀAᵀx(Aᵀx + Ax)"
4443
@test to_std(gradient(((A .* B) * C * x)' * x, x)) == "(A ⊙ B)Cx + Cᵀ(Aᵀ ⊙ Bᵀ)x"

0 commit comments

Comments
 (0)