Skip to content

Commit f40bc27

Browse files
committed
Remove ir.PartialSum
1 parent 0588ef3 commit f40bc27

3 files changed

Lines changed: 6 additions & 29 deletions

File tree

src/ir.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,6 @@ struct Sum <: IR
9595
arg::IR
9696
end
9797

98-
struct PartialSum <: IR
99-
arg::IR
100-
dim::Int
101-
end
102-
10398
function _get_variables(arg::Mat)
10499
return _get_variables(arg.id)
105100
end
@@ -399,18 +394,6 @@ function to_ir(arg::BinaryOperation{Mult})
399394
else
400395
return ir.Transpose(ir.Vec(ir.Literal(1)))
401396
end
402-
elseif arg.arg1 isa Literal || arg.arg2 isa Literal
403-
tensor = if arg.arg1 isa Literal
404-
arg.arg2
405-
else
406-
arg.arg1
407-
end
408-
409-
if typeof(target_indices[1]) == Upper
410-
return ir.PartialSum(to_ir(tensor), 2)
411-
else
412-
return ir.PartialSum(to_ir(tensor), 1)
413-
end
414397
end
415398
end
416399

src/julia.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ function to_julia(arg::ir.Quotient)
7676
end
7777

7878
function to_julia(arg::ir.Product)
79+
if arg.l isa ir.Transpose && arg.l.arg isa ir.Vec && arg.l.arg.id isa ir.Literal
80+
return :($(to_julia(arg.l.arg.id)) .* (sum($(to_julia(arg.r)), dims = 1)))
81+
elseif arg.r isa ir.Vec && arg.r.id isa ir.Literal
82+
return :($(to_julia(arg.r.id)) .* (sum($(to_julia(arg.l)), dims = 2)))
83+
end
84+
7985
return :($(to_julia(arg.l)) * $(to_julia(arg.r)))
8086
end
8187

src/stdstr.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,3 @@ end
200200
function to_std_str(arg::ir.Sum)
201201
return "sum(" * to_std_str(arg.arg) * ")"
202202
end
203-
204-
function to_std_str(arg::ir.PartialSum)
205-
if arg.dim == 1
206-
product = ir.Product(ir.Transpose(ir.Vec(ir.Literal(1))), arg.arg) # Only used for dispatch
207-
return "vec(1)ᵀ" * parenthesize(product, to_std_str, arg.arg)
208-
elseif arg.dim == 2
209-
product = ir.Product(arg.arg, ir.Vec(ir.Literal(1)))
210-
return parenthesize(product, to_std_str, arg.arg) * "vec(1)"
211-
end
212-
213-
throw(RuntimeError("Encountered a sum over an unsupported index"))
214-
end

0 commit comments

Comments
 (0)