Skip to content

Commit a7bc65a

Browse files
committed
Add Literal type
1 parent e4ef389 commit a7bc65a

5 files changed

Lines changed: 102 additions & 13 deletions

File tree

src/forward.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ function diff(arg::Variable, wrt::Variable)
2121
return Zero(unique(indices)...)
2222
end
2323

24+
function diff(arg::Literal, wrt::Variable)
25+
indices = union(arg.indices, [flip(i) for i wrt.indices])
26+
27+
return Zero(unique(indices)...)
28+
end
29+
2430
function diff(arg::KrD, wrt::Variable)
2531
indices = union(arg.indices, [flip(i) for i wrt.indices])
2632

@@ -69,7 +75,7 @@ function collect_factors(arg)
6975
return Value[arg]
7076
end
7177

72-
function evaluate(arg::Union{Variable,KrD,Zero,Real})
78+
function evaluate(arg::Union{Variable,Literal,KrD,Zero,Real})
7379
return arg
7480
end
7581

@@ -123,7 +129,7 @@ function is_diag(arg1::KrD, arg2::KrD)
123129
return false
124130
end
125131

126-
function is_diag(arg::Union{Variable,KrD,Zero})
132+
function is_diag(arg::Union{Variable,Literal,KrD,Zero})
127133
return false
128134
end
129135

@@ -143,11 +149,11 @@ function is_diag(arg1::Value, arg2::Value)
143149
return is_diag(arg1) || is_diag(arg2)
144150
end
145151

146-
function evaluate(::Mult, arg1::Variable, arg2::BinaryOperation{Mult})
152+
function evaluate(::Mult, arg1::Union{Variable,Literal}, arg2::BinaryOperation{Mult})
147153
return evaluate(Mult(), arg2, arg1)
148154
end
149155

150-
function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Variable)
156+
function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Union{Variable,Literal})
151157
if arg1.arg1 isa Real
152158
return BinaryOperation{Mult}(arg1.arg1, BinaryOperation{Mult}(arg1.arg2, arg2))
153159
end
@@ -351,19 +357,19 @@ function evaluate(::Mult, arg1::Power, arg2::KrD)
351357
return BinaryOperation{Mult}(evaluate(arg1), evaluate(arg2))
352358
end
353359

354-
function evaluate(::Mult, arg1::Variable, arg2::KrD)
360+
function evaluate(::Mult, arg1::Union{Variable,Literal}, arg2::KrD)
355361
return _multiply_with_krd(arg1, arg2)
356362
end
357363

358-
function evaluate(::Mult, arg1::KrD, arg2::Variable)
364+
function evaluate(::Mult, arg1::KrD, arg2::Union{Variable,Literal})
359365
return _multiply_with_krd(arg2, arg1)
360366
end
361367

362368
function evaluate(::Mult, arg1::KrD, arg2::KrD)
363369
return _multiply_with_krd(arg1, arg2)
364370
end
365371

366-
function _multiply_with_krd(arg1::Union{Variable,KrD}, arg2::KrD)
372+
function _multiply_with_krd(arg1::Union{Variable,Literal,KrD}, arg2::KrD)
367373
arg1_indices = get_free_indices(arg1)
368374
contracting_index = eliminated_indices([arg1_indices; get_indices(arg2)])
369375

@@ -405,7 +411,7 @@ end
405411
function evaluate(
406412
::Mult,
407413
arg1::BinaryOperation{Op},
408-
arg2::KrD,
414+
arg2::Union{Literal,KrD},
409415
) where {Op<:AdditiveOperation}
410416
return evaluate(
411417
Op(),
@@ -416,7 +422,7 @@ end
416422

417423
function evaluate(
418424
::Mult,
419-
arg1::KrD,
425+
arg1::Union{Literal,KrD},
420426
arg2::BinaryOperation{Op},
421427
) where {Op<:AdditiveOperation}
422428
return evaluate(

src/ir.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,30 @@ function to_ir(arg::Variable)
214214
return ir.Scal(ir.Var(arg.id))
215215
end
216216

217+
function to_ir(arg::Literal)
218+
@assert is_standard_form(arg)
219+
220+
ids = get_indices(arg)
221+
222+
if length(ids) == 2
223+
if flip(ids[1]) == ids[2]
224+
return ir.Trace(ir.Mat(ir.Const(arg.value)))
225+
elseif typeof(ids[1]) == Upper && typeof(ids[2]) == Lower
226+
return ir.Mat(ir.Const(arg.value))
227+
elseif typeof(ids[1]) == Lower && typeof(ids[2]) == Upper
228+
return ir.Transpose(ir.Mat(ir.Const(arg.value)))
229+
end
230+
elseif length(ids) == 1
231+
if typeof(ids[1]) == Upper
232+
return ir.Vec(ir.Const(arg.value))
233+
elseif typeof(ids[1]) == Lower
234+
return ir.Transpose(ir.Vec(ir.Const(arg.value)))
235+
end
236+
end
237+
238+
return ir.Scal(ir.Const(arg.value))
239+
end
240+
217241
function to_ir(arg::KrD)
218242
@assert is_standard_form(arg)
219243

src/ricci.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,29 @@ end
4343

4444
Base.hash(m::Variable, h::UInt) = hash(Variable, hash(m.id, hash(m.indices, h)))
4545

46+
struct Literal <: Tensor
47+
value::Real
48+
indices::IndexList
49+
50+
function Literal(value::Real, indices::LowerOrUpperIndex...)
51+
# Convert type
52+
indices = LowerOrUpperIndex[i for i indices]
53+
54+
if length(unique(indices)) != length(indices)
55+
throw(
56+
DomainError(
57+
indices,
58+
"Indices of literal with value '$(string(value))' are invalid",
59+
),
60+
)
61+
end
62+
63+
new(value, indices)
64+
end
65+
end
66+
67+
Base.hash(l::Literal, h::UInt) = hash(Literal, hash(l.value, hash(l.indices, h)))
68+
4669
function are_unique(arg::AbstractArray)
4770
return length(unique(arg)) == length(arg)
4871
end
@@ -204,7 +227,7 @@ function is_permutation(arg1::Tensor, arg2::Tensor)
204227
return is_permutation(unique(arg1_indices), unique(arg2_indices))
205228
end
206229

207-
function get_indices(arg::Union{Variable,KrD,Zero})
230+
function get_indices(arg::Union{Variable,Literal,KrD,Zero})
208231
@assert length(unique(arg.indices)) == length(arg.indices)
209232

210233
return arg.indices
@@ -419,7 +442,7 @@ function replace_letters(
419442
)
420443
end
421444

422-
function replace_letters(arg::Union{Variable,Zero,KrD}, letter_map::Dict)
445+
function replace_letters(arg::Union{Variable,Literal,Zero,KrD}, letter_map::Dict)
423446
new_indices = LowerOrUpperIndex[]
424447

425448
for i arg.indices
@@ -671,7 +694,7 @@ function Base.adjoint(arg::BinaryOperation{Op}) where {Op}
671694
return evaluate(BinaryOperation{Op}(adjoint(arg.arg1), adjoint(arg.arg2)))
672695
end
673696

674-
function Base.adjoint(arg::Union{Variable,KrD,Zero})
697+
function Base.adjoint(arg::Union{Variable,Literal,KrD,Zero})
675698
free_indices = unique(get_free_indices(arg))
676699

677700
if length(free_indices) > 2
@@ -744,6 +767,12 @@ function to_string(arg::Variable)
744767
return arg.id * join(scripts)
745768
end
746769

770+
function to_string(arg::Literal)
771+
scripts = [script(i) for i arg.indices]
772+
773+
return string(arg.value) * join(scripts)
774+
end
775+
747776
function to_string(arg::KrD)
748777
scripts = [script(i) for i arg.indices]
749778

src/std.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,24 @@ function to_standard(term::Variable)
239239
throw_not_std(term)
240240
end
241241

242+
function to_standard(term::Literal)
243+
ids = term.indices
244+
245+
if length(ids) == 2
246+
if typeof(last(term.indices)) == Lower
247+
return Literal(term.value, Upper(ids[1].letter), Lower(ids[2].letter))
248+
else
249+
return Literal(term.value, Lower(ids[1].letter), Upper(ids[2].letter))
250+
end
251+
elseif length(ids) == 1
252+
return term
253+
elseif isempty(ids)
254+
return Literal(term.value)
255+
end
256+
257+
throw_not_std(term)
258+
end
259+
242260
function to_standard(term::Union{KrD,Zero})
243261
ids = term.indices
244262

test/RicciTest.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
using DiffMatic
66
using Test
77

8-
using DiffMatic: Variable, KrD, Zero
8+
using DiffMatic: Variable, Literal, KrD, Zero
99
using DiffMatic: evaluate
1010
using DiffMatic: Upper, Lower
1111

@@ -24,6 +24,18 @@ end
2424
@test !isnothing(Variable("z"))
2525
end
2626

27+
@testset "Literal constructor throws on invalid input" begin
28+
@test_throws DomainError Literal(2, Lower(2), Lower(2))
29+
end
30+
31+
@testset "Literal constructor succeeds on valid input" begin
32+
@test Literal(0.6, Upper(1), Lower(2)) isa Literal
33+
@test Literal(5//1, Lower(1), Lower(2)) isa Literal
34+
@test Literal(4.0, Upper(1)) isa Literal
35+
@test Literal(3, Lower(1)) isa Literal
36+
@test Literal(2) isa Literal
37+
end
38+
2739
@testset "index equality operator" begin
2840
left = Lower(3)
2941
right = Lower(3)

0 commit comments

Comments
 (0)