Skip to content

Commit 88ff89d

Browse files
committed
Add diff rules for log
1 parent 7dc00ae commit 88ff89d

10 files changed

Lines changed: 112 additions & 5 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ standard notation.
5151

5252
### Supported functions and operators
5353

54-
`+`, `-`, `'`, `*`, `^`, `abs`, `sin`, `cos`
54+
`+`, `-`, `'`, `*`, `^`, `abs`, `sin`, `cos`, `log`
5555

56-
Element-wise operations `sin.`, `cos.`, `abs.`, `.*` and `.^` are supported.
56+
Element-wise operations `sin.`, `cos.`, `abs.`, `.*`, `.^` and `log.` are supported.
5757
Vector 1-norm and 2-norm can be computed with `LinearAlgebra.norm(..., 1)` and `LinearAlgebra.norm(..., 2)`.
5858
Sums of vectors can be computed with `sum`.
5959
Matrix traces can be computed with `LinearAlgebra.tr`.

docs/src/examples.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,21 @@ expr = sum(A * x)
118118
119119
A¹₄x⁴1₁
120120
```
121+
#### Log and element-wise log
122+
```jldoctest usage
123+
expr = log(x' * y)
124+
125+
# output
126+
127+
log(x₃y³)
128+
```
129+
```jldoctest usage
130+
expr = log.(x)' * y
131+
132+
# output
133+
134+
log(x₃)y³
135+
```
121136
#### Vector Norms
122137
```jldoctest usage
123138
expr = norm(A * x, 2)

docs/src/usage.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ c * A
6666
cA¹₂
6767
```
6868
Supported operators and functions when creating expressions:
69-
- Basic operators `+`, `-`, `'`, `*`, `^`, `abs`, `sin` and `cos`
70-
- Element-wise operators `sin.`, `cos.`, `abs.`, `.*` and `.^`
69+
- Basic operators `+`, `-`, `'`, `*`, `^`, `abs`, `sin`, `cos` and `log`
70+
- Element-wise operators `sin.`, `cos.`, `abs.`, `.*`, `.^` and `log.`
7171
- Vector 1-norm and 2-norm can be computed with `norm1` and `norm2`
7272
- Sums of vectors can be computed with `sum`.
7373
- Matrix traces can be computed with `tr`.

src/forward.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ function diff(arg::UnaryOperation{Cos}, wrt::Variable)
4949
return BinaryOperation{Mult}(-UnaryOperation{Sin}(arg.arg), diff(arg.arg, wrt))
5050
end
5151

52+
function diff(arg::Log, wrt::Variable)
53+
return BinaryOperation{Mult}(
54+
BinaryOperation{Div}(Literal(1, get_free_indices(arg)...), arg.arg),
55+
diff(arg.arg, wrt),
56+
)
57+
end
58+
5259
function diff(arg::Power, wrt::Variable)
5360
outer = replace_bound_letters(arg.base, wrt)
5461

@@ -287,6 +294,20 @@ function evaluate(::Mult, arg1::KrD, arg2::BinaryOperation{Mult})
287294
return evaluate(Mult(), arg2, arg1)
288295
end
289296

297+
function evaluate(::Mult, arg1::BinaryOperation{Div}, arg2::KrD)
298+
attempt = BinaryOperation{Div}(
299+
evaluate(Mult(), evaluate(arg1.arg1), evaluate(arg2)),
300+
evaluate(Mult(), evaluate(arg1.arg2), evaluate(arg2)),
301+
)
302+
303+
# This ensures that arg1.base and arg2 can contract and that the contraction is simple
304+
if length(get_free_indices(attempt)) == length(get_free_indices(arg1))
305+
return attempt
306+
end
307+
308+
return BinaryOperation{Mult}(arg1, arg2)
309+
end
310+
290311
function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
291312
ci = indices_in_common(arg1.arg1, arg1.arg2)
292313

@@ -377,6 +398,17 @@ function evaluate(::Mult, arg1::KrD, arg2::UnaryOp) where {UnaryOp<:UnaryOperati
377398
return BinaryOperation{Mult}(evaluate(arg2), evaluate(arg1))
378399
end
379400

401+
function evaluate(::Mult, arg1::Log, arg2::KrD)
402+
attempt = Log(evaluate(Mult(), evaluate(arg1.arg), evaluate(arg2)))
403+
404+
# This ensures that arg1.base and arg2 can contract and that the contraction is simple
405+
if length(get_free_indices(attempt)) == length(get_free_indices(arg1))
406+
return attempt
407+
end
408+
409+
return BinaryOperation{Mult}(evaluate(arg1), evaluate(arg2))
410+
end
411+
380412
function evaluate(::Mult, arg1::Power, arg2::KrD)
381413
attempt = Power(evaluate(Mult(), evaluate(arg1.base), evaluate(arg2)), arg1.exponent)
382414

@@ -874,6 +906,10 @@ function _sub_from_product(arg1::BinaryOperation{Mult}, arg2::Value)
874906
return BinaryOperation{Sub}(evaluate(arg1), evaluate(arg2))
875907
end
876908

909+
function evaluate(op::Log)
910+
return Log(evaluate(op.arg))
911+
end
912+
877913
function evaluate(op::Power)
878914
if op.exponent == 1
879915
return evaluate(op.base)

src/ir.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ struct HadamardProduct <: IR
6565
r::IR
6666
end
6767

68+
struct Log <: IR
69+
arg::IR
70+
end
71+
6872
struct Power <: IR
6973
base::IR
7074
exponent::Union{Int,Rational{Int}}
@@ -509,6 +513,10 @@ function to_ir(arg::BinaryOperation{Div})
509513
return ir.Quotient(to_ir(arg.arg1), to_ir(arg.arg2))
510514
end
511515

516+
function to_ir(arg::Log)
517+
return ir.Log(to_ir(arg.arg))
518+
end
519+
512520
function to_ir(arg::Power)
513521
return ir.Power(to_ir(arg.base), arg.exponent)
514522
end

src/ricci.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ end
118118

119119
Base.hash(op::Power, h::UInt) = hash(op.exponent, hash(op.base, hash(Power, h)))
120120

121+
struct Log <: Tensor
122+
arg::Value
123+
end
124+
125+
Base.hash(op::Log, h::UInt) = hash(op.arg, hash(Log, h))
126+
121127
struct UnaryOperation{Op} <: Tensor where {Op}
122128
arg::Value
123129
end
@@ -246,6 +252,10 @@ function get_indices(arg::Power)
246252
return get_indices(arg.base)
247253
end
248254

255+
function get_indices(arg::Log)
256+
return get_indices(arg.arg)
257+
end
258+
249259
function get_indices(arg::BinaryOperation{Op}) where {Op<:AdditiveOperation}
250260
arg1_free_ids, arg2_free_ids = get_free_indices.((arg.arg1, arg.arg2))
251261

@@ -432,13 +442,29 @@ function Base.literal_pow(f::typeof(^), base::Tensor, exponent::Val{E}) where {E
432442
return Power(base, E)
433443
end
434444

445+
function Base.log(arg::Tensor)
446+
if !isempty(get_free_indices(arg))
447+
throw(DomainError(arg, "Argument is not a scalar, use log. for element-wise log."))
448+
end
449+
450+
return Log(arg)
451+
end
452+
453+
function Base.broadcasted(::typeof(log), arg::Tensor)
454+
return Log(arg)
455+
end
456+
435457
function replace_letters(arg::BinaryOperation{Mult}, letter_map::Dict)
436458
return BinaryOperation{Mult}(
437459
replace_letters(arg.arg1, letter_map),
438460
replace_letters(arg.arg2, letter_map),
439461
)
440462
end
441463

464+
function replace_letters(arg::Log, letter_map::Dict)
465+
return Log(replace_letters(arg.arg, letter_map))
466+
end
467+
442468
function replace_letters(arg::Power, letter_map::Dict)
443469
return Power(replace_letters(arg.base, letter_map), arg.exponent)
444470
end
@@ -697,6 +723,10 @@ function Base.adjoint(arg::T) where {T<:UnaryOperation}
697723
return T(arg.arg')
698724
end
699725

726+
function Base.adjoint(arg::Log)
727+
return Log(adjoint(arg.arg))
728+
end
729+
700730
function Base.adjoint(arg::Power)
701731
return Power(adjoint(arg.base), arg.exponent)
702732
end
@@ -846,6 +876,10 @@ function to_string(arg::Power)
846876
return b * ".^" * parenthesize(arg.exponent)
847877
end
848878

879+
function to_string(arg::Log)
880+
return "log(" * to_string(arg.arg) * ")"
881+
end
882+
849883
function to_string(arg::BinaryOperation{Mult})
850884
if arg.arg1 == -1
851885
return "-" * parenthesize(arg.arg2)

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Literal)
111111
return BinaryOperation{Mult}(arg1, arg2)
112112
end
113113

114-
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Variable)
114+
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
115115
if is_diag(arg1) && !is_elementwise_multiplication(arg1, arg2)
116116
d = get_diag_delta(arg1)
117117

src/std.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ function to_standard(term::UnaryOperation{Op}) where {Op}
217217
return UnaryOperation{Op}(to_standard(term.arg))
218218
end
219219

220+
function to_standard(arg::Log)
221+
return Log(to_standard(arg.arg))
222+
end
223+
220224
function to_standard(term::Power)
221225
return Power(to_standard(term.base), term.exponent)
222226
end

src/stdstr.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ function to_std_str(arg::ir.HadamardProduct)
107107
return to_std_str(arg.l) * "" * to_std_str(arg.r)
108108
end
109109

110+
function to_std_str(arg::ir.Log)
111+
out = to_std_str(arg.arg)
112+
113+
return "log(" * out * ")"
114+
end
115+
110116
function to_std_str(arg::ir.Power)
111117
out = to_std_str(arg.base)
112118

test/StdStrTest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
@test to_std(gradient(sum((x + y) .^ 2), x)) == "2(x + y)"
3737
@test to_std(gradient(sum((x .* y) .^ 2), x)) == "2(x ⊙ y ⊙ y)"
3838
@test to_std(gradient(sum((A * x - y) .^ 2), x)) == "2Aᵀ(Ax - y)"
39+
@test to_std(gradient(log.(x)'*x, x)) == "(vec(1) ⊘ x ⊙ x) + log(x)" # TODO: Add simplification rule for quotients
40+
@test to_std(gradient(log.(x)'*log.(x), x)) ==
41+
"diag(vec(1)ᵀ ⊘ xᵀ)Iᵀlog(x) + (vec(1) ⊘ x ⊙ log(x))"
3942
@test to_std(gradient((x' * A * x) ^ (-2), x)) == "(-2)(xᵀAᵀx)^(-3)(Aᵀx + Ax)"
4043
@test to_std(gradient((x' * A * x) ^ 2, x)) == "2xᵀAᵀx(Aᵀx + Ax)"
4144
@test to_std(gradient(((A .* B) * C * x)' * x, x)) == "(A ⊙ B)Cx + Cᵀ(Aᵀ ⊙ Bᵀ)x"
@@ -52,6 +55,7 @@ end
5255
@test to_std(jacobian(A * x, x)) == "A"
5356
@test to_std(jacobian(A' * x, x)) == "Aᵀ"
5457
@test to_std(jacobian(abs.(x), x)) == "diag(sgn(x))I"
58+
@test to_std(jacobian(log.(x), x)) == "diag(vec(1) ⊘ x)I"
5559
@test to_std(jacobian(sin.(A * x + y), x)) == "diag(cos(Ax + y))A"
5660
@test to_std(jacobian(((A .* B) * C * x)' * x * x, x)) ==
5761
"xᵀCᵀ(Aᵀ ⊙ Bᵀ)xI + x(xᵀCᵀ(Aᵀ ⊙ Bᵀ) + xᵀ(A ⊙ B)C)"

0 commit comments

Comments
 (0)