Skip to content

Commit 021ec18

Browse files
committed
Add norm2 and norm1 shortcuts
1 parent 7cfe752 commit 021ec18

3 files changed

Lines changed: 59 additions & 2 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ to_std(derivative(tr(A), A)) # "I"
4848
The function `to_std` will throw an exception when given an expression that that cannot be converted to
4949
standard notation.
5050

51-
### Supported operators
51+
### Supported functions and operators
5252

53-
`tr`, `sum`, `sin`, `cos`, `+`, `-`, `'`, `*`, `.*`, `.^`, `^`, `abs`
53+
`tr`, `sum`, `sin`, `cos`, `+`, `-`, `'`, `*`, `.*`, `.^`, `^`, `abs`, `norm2`, `norm1`
5454

5555
### Installation
5656

src/ricci.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import LinearAlgebra.tr
66

77
export tr
88
export sum
9+
export norm1, norm2
910

1011
abstract type Tensor end
1112

@@ -257,6 +258,28 @@ function tr(arg::Tensor)
257258
return evaluate(BinaryOperation{Mult}(arg, KrD(flip(free_ids[2]), flip(free_ids[1]))))
258259
end
259260

261+
function norm2(arg::Tensor)
262+
free_ids = get_free_indices(arg)
263+
264+
if length(free_ids) != 1
265+
throw(DomainError("Norms are currently implemented only for vectors."))
266+
end
267+
268+
p = 2
269+
270+
return sum(arg .^ p)^(1//p)
271+
end
272+
273+
function norm1(arg::Tensor)
274+
free_ids = get_free_indices(arg)
275+
276+
if length(free_ids) != 1
277+
throw(DomainError("Norms are currently implemented only for vectors."))
278+
end
279+
280+
return sum(abs(arg))
281+
end
282+
260283
function Base.sum(arg::Tensor)
261284
free_ids = get_free_indices(arg)
262285

test/RicciTest.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,40 @@ end
110110
@test typeof(op.arg) == dc.BinaryOperation{dc.Mult}
111111
end
112112

113+
@testset "norm2 throws for inputs other than vectors" begin
114+
c = Variable("c")
115+
A = Variable("A", Upper(1), Lower(2))
116+
117+
@test_throws DomainError norm2(c)
118+
@test_throws DomainError norm2(A)
119+
end
120+
121+
@testset "norm1 throws for inputs other than vectors" begin
122+
c = Variable("c")
123+
A = Variable("A", Upper(1), Lower(2))
124+
125+
@test_throws DomainError norm1(c)
126+
@test_throws DomainError norm1(A)
127+
end
128+
129+
@testset "norm2 output" begin
130+
x = Variable("x", Upper(2))
131+
132+
op = norm2(x)
133+
134+
@test typeof(op) == dc.Power
135+
@test op.base == sum(x .^ 2)
136+
@test op.exponent == 1//2
137+
end
138+
139+
@testset "norm1 output" begin
140+
x = Variable("x", Upper(2))
141+
142+
op = norm1(x)
143+
144+
@test op == sum(abs(x))
145+
end
146+
113147
@testset "UnaryOperation equality operator" begin
114148
a = KrD(Upper(1), Lower(2))
115149
b = Variable("b", Upper(2))

0 commit comments

Comments
 (0)