Skip to content

Commit a3d1edb

Browse files
committed
Add scalar power operator (^)
1 parent f3671f1 commit a3d1edb

2 files changed

Lines changed: 23 additions & 7 deletions

File tree

src/ricci.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -286,14 +286,30 @@ end
286286
function Base.broadcasted(
287287
::typeof(Base.literal_pow),
288288
f::Function,
289-
arg1::Tensor,
290-
arg2::Val{P},
291-
) where {P}
292-
return Base.broadcasted(f, arg1, P)
289+
base::Tensor,
290+
exponent::Val{E},
291+
) where {E}
292+
return Base.broadcasted(f, base, E)
293+
end
294+
295+
function Base.broadcasted(::typeof(^), base::Tensor, exponent::Union{Int,Rational{Int}})
296+
return Power(base, exponent)
293297
end
294298

295-
function Base.broadcasted(::typeof(^), base::Tensor, power::Int)
296-
return Power(base, power)
299+
function Base.:(^)(base::Tensor, exponent::Union{Int,Rational{Int}})
300+
if !isempty(get_free_indices(base))
301+
throw(DomainError(base, " is not a scalar, use .^ for element-wise power"))
302+
end
303+
304+
return Power(base, exponent)
305+
end
306+
307+
function Base.literal_pow(f::typeof(^), base::Tensor, exponent::Val{E}) where {E}
308+
if !isempty(get_free_indices(base))
309+
throw(DomainError(base, " is not a scalar, use .^ for element-wise power"))
310+
end
311+
312+
return Power(base, E)
297313
end
298314

299315
function replace_letters(arg::BinaryOperation{Mult}, letter_map::Dict)

test/StdTest.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ end
355355
@test to_std_string(gradient(sum((x + y) .^ 2), x)) == "2(x + y)"
356356
@test to_std_string(gradient(sum((x .* y) .^ 2), x)) == "2x ⊙ y ⊙ y"
357357
@test to_std_string(gradient(sum((A * x - y) .^ 2), x)) == "2Aᵀ(Ax - y)"
358-
@test to_std_string(gradient((x' * A * x) .^ (-2), x)) == "(-2)(xᵀAᵀx)⁻³(Aᵀx + Ax)"
358+
@test to_std_string(gradient((x' * A * x) ^ (-2), x)) == "(-2)(xᵀAᵀx)⁻³(Aᵀx + Ax)"
359359
@test to_std_string(gradient(((A .* B) * C * x)' * x, x)) == "(A ⊙ B)Cx + Cᵀ(Aᵀ ⊙ Bᵀ)x"
360360
@test to_std_string(gradient(((A .* (B .* C)) * C * x)' * x, x)) ==
361361
"(B ⊙ C ⊙ A)Cx + Cᵀ(Bᵀ ⊙ Cᵀ ⊙ Aᵀ)x"

0 commit comments

Comments
 (0)