Skip to content

Commit f3671f1

Browse files
committed
Replace bound letters when diff:ing Powers
1 parent 4c7ec86 commit f3671f1

2 files changed

Lines changed: 16 additions & 1 deletion

File tree

src/forward.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ function diff(arg::UnaryOperation{Cos}, wrt::Monomial)
4040
end
4141

4242
function diff(arg::Power, wrt::Monomial)
43+
inner = replace_bound_letters(arg.base)
44+
4345
return BinaryOperation{Mult}(
4446
BinaryOperation{Mult}(arg.exponent, Power(arg.base, arg.exponent - 1)),
45-
diff(arg.base, wrt),
47+
diff(inner, wrt),
4648
)
4749
end
4850

@@ -57,6 +59,18 @@ function diff(arg::BinaryOperation{Op}, wrt::Monomial) where {Op<:AdditiveOperat
5759
return BinaryOperation{Op}(diff(arg.arg1, wrt), diff(arg.arg2, wrt))
5860
end
5961

62+
function replace_bound_letters(arg::Tensor)
63+
letters = unique(get_letters(get_indices(arg)))
64+
free_letters = unique(get_letters(get_free_indices(arg)))
65+
bound_letters = setdiff(letters, free_letters)
66+
next_letter = get_next_letter(arg)
67+
68+
letter_map =
69+
Dict(bound_letters[li] => next_letter + li for li eachindex(bound_letters))
70+
71+
return replace_letters(arg, letter_map)
72+
end
73+
6074
function collect_factors(arg::BinaryOperation{Mult})
6175
return [collect_factors(arg.arg1); collect_factors(arg.arg2)]
6276
end

test/StdTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ end
351351
@test to_std_string(gradient(sum(2 * cos(A * x + y)), x)) == "(-2)Aᵀsin(Ax + y)"
352352
@test to_std_string(gradient(sum(x .^ 2), x)) == "2x"
353353
@test to_std_string(gradient(sum(x .^ 3), x)) == "3x²"
354+
@test to_std_string(gradient(sum(x)^2, x)) == "2sum(xᵀ)vec(1)"
354355
@test to_std_string(gradient(sum((x + y) .^ 2), x)) == "2(x + y)"
355356
@test to_std_string(gradient(sum((x .* y) .^ 2), x)) == "2x ⊙ y ⊙ y"
356357
@test to_std_string(gradient(sum((A * x - y) .^ 2), x)) == "2Aᵀ(Ax - y)"

0 commit comments

Comments
 (0)