Skip to content

Commit b17c6ed

Browse files
committed
Remove index replacement for powers
1 parent 021ec18 commit b17c6ed

2 files changed

Lines changed: 3 additions & 17 deletions

File tree

src/forward.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,8 @@ function diff(arg::UnaryOperation{Cos}, wrt::Variable)
4444
end
4545

4646
function diff(arg::Power, wrt::Variable)
47-
outer = replace_bound_letters(arg.base, wrt)
48-
4947
return BinaryOperation{Mult}(
50-
BinaryOperation{Mult}(arg.exponent, Power(outer, arg.exponent - 1)),
48+
BinaryOperation{Mult}(arg.exponent, Power(arg.base, arg.exponent - 1)),
5149
diff(arg.base, wrt),
5250
)
5351
end
@@ -63,18 +61,6 @@ function diff(arg::BinaryOperation{Op}, wrt::Variable) where {Op<:AdditiveOperat
6361
return BinaryOperation{Op}(diff(arg.arg1, wrt), diff(arg.arg2, wrt))
6462
end
6563

66-
function replace_bound_letters(arg::Tensor, letters_to_skip::Tensor...)
67-
letters = unique(get_letters(get_indices(arg)))
68-
free_letters = unique(get_letters(get_free_indices(arg)))
69-
bound_letters = setdiff(letters, free_letters)
70-
next_letter = get_next_letter(arg, letters_to_skip...)
71-
72-
letter_map =
73-
Dict(bound_letters[li] => next_letter + li - 1 for li eachindex(bound_letters))
74-
75-
return replace_letters(arg, letter_map)
76-
end
77-
7864
function collect_factors(arg::BinaryOperation{Mult})
7965
return Value[collect_factors(arg.arg1); collect_factors(arg.arg2)]
8066
end

test/StdStrTest.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
@test to_std(gradient(sum(2 * cos(A * x + y)), x)) == "(-2)Aᵀsin(Ax + y)"
3131
@test to_std(gradient(sum(x .^ 2), x)) == "2x"
3232
@test to_std(gradient(sum(x .^ 3), x)) == "3x^2"
33-
@test to_std(gradient(sum(x)^2, x)) == "2sum(xᵀ)vec(1)"
34-
@test to_std(gradient(sum(x .^ 2)^2, x)) == "2sum(xᵀ^2)2x"
33+
@test to_std(gradient(sum(x)^2, x)) == "2sum(xᵀ)Iᵀvec(1)" # TODO: Simplify and remove Iᵀ
34+
@test to_std(gradient(sum(x .^ 2)^2, x)) == "4sum(xᵀ^2)x"
3535
@test to_std(gradient(sum((x + y) .^ 2), x)) == "2(x + y)"
3636
@test to_std(gradient(sum((x .* y) .^ 2), x)) == "2(x ⊙ y ⊙ y)"
3737
@test to_std(gradient(sum((A * x - y) .^ 2), x)) == "2Aᵀ(Ax - y)"

0 commit comments

Comments
 (0)