Skip to content

Commit 5928a12

Browse files
committed
Add abs and sign
1 parent 073670a commit 5928a12

9 files changed

Lines changed: 108 additions & 1 deletion

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ standard notation.
5050

5151
### Supported operators
5252

53-
`tr`, `sum`, `sin`, `cos`, `+`, `-`, `'`, `*`, `.*`, `.^`, `^`
53+
`tr`, `sum`, `sin`, `cos`, `+`, `-`, `'`, `*`, `.*`, `.^`, `^`, `abs`
5454

5555
### Installation
5656

src/forward.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ function diff(arg::Real, wrt::Variable)
3131
return Zero([flip(i) for i wrt.indices]...)
3232
end
3333

34+
function diff(arg::UnaryOperation{Abs}, wrt::Variable)
35+
return BinaryOperation{Mult}(UnaryOperation{Sgn}(arg.arg), diff(arg.arg, wrt))
36+
end
37+
3438
function diff(arg::UnaryOperation{Sin}, wrt::Variable)
3539
return BinaryOperation{Mult}(UnaryOperation{Cos}(arg.arg), diff(arg.arg, wrt))
3640
end

src/ir.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ end
2424

2525
struct Identity <: IR end
2626

27+
struct Abs <: IR
28+
arg::IR
29+
end
30+
31+
struct Sgn <: IR
32+
arg::IR
33+
end
34+
2735
struct Sin <: IR
2836
arg::IR
2937
end
@@ -242,6 +250,14 @@ function to_ir(arg::Real)
242250
return ir.Scal(ir.Const(arg))
243251
end
244252

253+
function to_ir(arg::UnaryOperation{Abs})
254+
return ir.Abs(to_ir(arg.arg))
255+
end
256+
257+
function to_ir(arg::UnaryOperation{Sgn})
258+
return ir.Sgn(to_ir(arg.arg))
259+
end
260+
245261
function to_ir(arg::UnaryOperation{Sin})
246262
return ir.Sin(to_ir(arg.arg))
247263
end

src/ricci.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ end
103103

104104
Base.hash(op::UnaryOperation{Op}, h::UInt) where {Op} = hash(op.arg, hash(Op, h))
105105

106+
struct Abs end
107+
struct Sgn end
106108
struct Sin end
107109
struct Cos end
108110

@@ -265,6 +267,14 @@ function Base.sum(arg::Tensor)
265267
return BinaryOperation{Mult}(arg, KrD(first(free_ids), flip(first(free_ids))))
266268
end
267269

270+
function Base.abs(arg::Tensor)
271+
return UnaryOperation{Abs}(arg)
272+
end
273+
274+
function Base.sign(arg::Tensor)
275+
return UnaryOperation{Sgn}(arg)
276+
end
277+
268278
function Base.broadcasted(::typeof(*), arg1::Tensor, arg2::Tensor)
269279
arg1_free_indices = get_free_indices(arg1)
270280
arg2_free_indices = get_free_indices(arg2)
@@ -687,6 +697,14 @@ function to_string(arg::Zero)
687697
return "0" * join(scripts)
688698
end
689699

700+
function to_string(arg::UnaryOperation{Abs})
701+
return "|$(arg.arg)|"
702+
end
703+
704+
function to_string(arg::UnaryOperation{Sgn})
705+
return "sgn($(arg.arg))"
706+
end
707+
690708
function to_string(arg::UnaryOperation{Sin})
691709
return "sin($(arg.arg))"
692710
end

src/std.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,14 @@ function to_std_str(arg::ir.Identity)
354354
return "I"
355355
end
356356

357+
function to_std_str(arg::ir.Abs)
358+
return "abs(" * to_std_str(arg.arg) * ")"
359+
end
360+
361+
function to_std_str(arg::ir.Sgn)
362+
return "sgn(" * to_std_str(arg.arg) * ")"
363+
end
364+
357365
function to_std_str(arg::ir.Sin)
358366
return "sin(" * to_std_str(arg.arg) * ")"
359367
end

test/ForwardTest.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,19 @@ end
626626
@test equivalent(evaluate(D), KrD(Upper(1), Lower(2)))
627627
end
628628

629+
@testset "diff abs" begin
630+
x = Variable("x", Upper(2))
631+
632+
op = abs(x)
633+
634+
D = dc.diff(op, Variable("x", Upper(3)))
635+
636+
@test equivalent(
637+
D,
638+
dc.BinaryOperation{dc.Mult}(dc.UnaryOperation{dc.Sgn}(x), KrD(Upper(2), Lower(3))),
639+
)
640+
end
641+
629642
@testset "diff sin" begin
630643
x = Variable("x", Upper(2))
631644

test/RicciTest.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,26 @@ end
7070
@test left != Zero()
7171
end
7272

73+
@testset "Abs constructor" begin
74+
a = KrD(Upper(1), Lower(2))
75+
b = Variable("b", Upper(2))
76+
77+
op = abs(a * b)
78+
79+
@test typeof(op) == dc.UnaryOperation{dc.Abs}
80+
@test typeof(op.arg) == dc.BinaryOperation{dc.Mult}
81+
end
82+
83+
@testset "Sgn constructor" begin
84+
a = KrD(Upper(1), Lower(2))
85+
b = Variable("b", Upper(2))
86+
87+
op = sign(a * b)
88+
89+
@test typeof(op) == dc.UnaryOperation{dc.Sgn}
90+
@test typeof(op.arg) == dc.BinaryOperation{dc.Mult}
91+
end
92+
7393
@testset "Sin constructor" begin
7494
a = KrD(Upper(1), Lower(2))
7595
b = Variable("b", Upper(2))

test/StdStrTest.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
@test to_std(gradient(sin(tr(x * x')), x)) == "cos(xᵀx)2x"
1313
@test to_std(gradient(cos(tr(x * x')), x)) == "(-1)sin(xᵀx)2x"
1414
@test to_std(gradient(tr(A), x)) == "vec(0)"
15+
@test to_std(gradient(abs(x' * x), x)) == "sgn(xᵀx)2x"
1516
@test to_std(gradient(x' * B' * A * A * x, x)) == "AᵀAᵀBx + BᵀAAx"
1617
@test to_std(gradient((A' * B * x)' * A * x, x)) == "AᵀAᵀBx + BᵀAAx"
1718
@test to_std(gradient(a * sin(y)' * x, x)) == "asin(y)"
@@ -48,6 +49,7 @@ end
4849

4950
@test to_std(jacobian(A * x, x)) == "A"
5051
@test to_std(jacobian(A' * x, x)) == "Aᵀ"
52+
@test to_std(jacobian(abs(x), x)) == "diag(sgn(x))I"
5153
@test to_std(jacobian(sin(A * x + y), x)) == "diag(cos(Ax + y))A"
5254
@test to_std(jacobian(((A .* B) * C * x)' * x * x, x)) ==
5355
"xᵀCᵀ(Aᵀ ⊙ Bᵀ)xI + x(xᵀCᵀ(Aᵀ ⊙ Bᵀ) + xᵀ(A ⊙ B)C)"

test/StdTest.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,32 @@ end
263263
@test to_std(sum(y)) == "sum(yᵀ)"
264264
end
265265

266+
@testset "to_std output is correct with abs" begin
267+
x = Variable("x", Upper(2))
268+
A = Variable("A", Upper(1), Lower(2))
269+
270+
function mul(l, r)
271+
return dc.BinaryOperation{dc.Mult}(l, r)
272+
end
273+
274+
@test to_std(dc.UnaryOperation{dc.Abs}(1)) == "abs(1)"
275+
@test to_std(dc.UnaryOperation{dc.Abs}(x)) == "abs(x)"
276+
@test to_std(dc.UnaryOperation{dc.Abs}(mul(A, x))) == "abs(Ax)"
277+
end
278+
279+
@testset "to_std output is correct with sgn" begin
280+
x = Variable("x", Upper(2))
281+
A = Variable("A", Upper(1), Lower(2))
282+
283+
function mul(l, r)
284+
return dc.BinaryOperation{dc.Mult}(l, r)
285+
end
286+
287+
@test to_std(dc.UnaryOperation{dc.Sgn}(1)) == "sgn(1)"
288+
@test to_std(dc.UnaryOperation{dc.Sgn}(x)) == "sgn(x)"
289+
@test to_std(dc.UnaryOperation{dc.Sgn}(mul(A, x))) == "sgn(Ax)"
290+
end
291+
266292
@testset "to_std output is correct with KrD-KrD and one free index" begin
267293
l = KrD(Upper(1), Lower(2))
268294
u = KrD(Upper(2), Lower(1))

0 commit comments

Comments
 (0)