Skip to content

Commit ccf399b

Browse files
committed
Add internal types for division
1 parent 276fd50 commit ccf399b

7 files changed

Lines changed: 54 additions & 0 deletions

File tree

src/forward.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,10 @@ function evaluate(op::BinaryOperation{Mult})
883883
evaluate(Mult(), evaluate(op.arg1), evaluate(op.arg2))
884884
end
885885

886+
function evaluate(op::BinaryOperation{Div})
887+
BinaryOperation{Div}(evaluate(op.arg1), evaluate(op.arg2))
888+
end
889+
886890
function evaluate(op::BinaryOperation{Op}) where {Op<:AdditiveOperation}
887891
evaluate(Op(), evaluate(op.arg1), evaluate(op.arg2))
888892
end

src/ir.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ struct Product <: IR
5555
r::IR
5656
end
5757

58+
struct Quotient <: IR
59+
num::IR
60+
den::IR
61+
end
62+
5863
struct HadamardProduct <: IR
5964
l::IR
6065
r::IR
@@ -500,6 +505,10 @@ function to_ir(arg::BinaryOperation{Mult})
500505
throw_not_std(arg)
501506
end
502507

508+
function to_ir(arg::BinaryOperation{Div})
509+
return ir.Quotient(to_ir(arg.arg1), to_ir(arg.arg2))
510+
end
511+
503512
function to_ir(arg::Power)
504513
return ir.Power(to_ir(arg.base), arg.exponent)
505514
end

src/ricci.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ abstract type AdditiveOperation end
109109
struct Add <: AdditiveOperation end
110110
struct Sub <: AdditiveOperation end
111111
struct Mult end
112+
struct Div end
112113

113114
struct Power <: Tensor
114115
base::Value
@@ -237,6 +238,10 @@ function get_indices(arg::BinaryOperation{Mult})
237238
return [get_indices(arg.arg1); get_indices(arg.arg2)]
238239
end
239240

241+
function get_indices(arg::BinaryOperation{Div})
242+
return [get_indices(arg.arg1); get_indices(arg.arg2)]
243+
end
244+
240245
function get_indices(arg::Power)
241246
return get_indices(arg.base)
242247
end
@@ -849,6 +854,10 @@ function to_string(arg::BinaryOperation{Mult})
849854
return parenthesize(arg.arg1) * parenthesize(arg.arg2)
850855
end
851856

857+
function to_string(arg::BinaryOperation{Div})
858+
return parenthesize(arg.arg1) * "/" * parenthesize(arg.arg2)
859+
end
860+
852861
function to_string(arg::BinaryOperation{Add})
853862
return to_string(arg.arg1) * " + " * to_string(arg.arg2)
854863
end

src/simplify.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,12 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Variable)
165165
return BinaryOperation{Mult}(arg1, arg2)
166166
end
167167

168+
function simplify(::Div, arg1::Value, arg2::Value)
169+
return evaluate(
170+
BinaryOperation{Div}(simplify(evaluate(arg1)), simplify(evaluate(arg2))),
171+
)
172+
end
173+
168174
function simplify(::Mult, arg1::Value, arg2::Value)
169175
return evaluate(
170176
BinaryOperation{Mult}(simplify(evaluate(arg1)), simplify(evaluate(arg2))),

src/std.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ function to_standard(arg::BinaryOperation{Mult})
327327
throw_not_std(arg)
328328
end
329329

330+
function to_standard(arg::BinaryOperation{Div})
331+
return BinaryOperation{Div}(to_standard(arg.arg1), to_standard(arg.arg2))
332+
end
333+
330334
function standardize(arg)
331335
arg = simplify(arg)
332336
free_indices = unique(get_free_indices(arg))

src/stdstr.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ function to_std_str(arg::ir.Product)
9999
return parenthesize(to_std_str, arg.l) * parenthesize(to_std_str, arg.r)
100100
end
101101

102+
function to_std_str(arg::ir.Quotient)
103+
return parenthesize(to_std_str, arg.num) * "" * parenthesize(to_std_str, arg.den)
104+
end
105+
102106
function to_std_str(arg::ir.HadamardProduct)
103107
return to_std_str(arg.l) * "" * to_std_str(arg.r)
104108
end

test/StdTest.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,24 @@ end
315315
@test to_std(dc.UnaryOperation{dc.Sgn}(mul(A, x))) == "sgn(Ax)"
316316
end
317317

318+
@testset "to_std output is correct with quotient" begin
319+
x = Variable("x", Upper(1))
320+
y = Variable("y", Upper(1))
321+
z = Variable("z", Upper(2))
322+
a = Variable("a")
323+
lv = Literal(2, Upper(1))
324+
325+
function div(l, r)
326+
return dc.BinaryOperation{dc.Div}(l, r)
327+
end
328+
329+
@test to_std(div(x, y)) == "x ⊘ y"
330+
@test to_std(div(lv, x)) == "vec(2) ⊘ x"
331+
@test to_std(div(x, lv)) == "x ⊘ vec(2)"
332+
# @test_throws to_std(div(a, x))
333+
# @test_throws to_std(div(x, z))
334+
end
335+
318336
@testset "to_std output is correct with KrD-KrD and one free index" begin
319337
l = KrD(Upper(1), Lower(2))
320338
u = KrD(Upper(2), Lower(1))

0 commit comments

Comments
 (0)