File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -95,11 +95,6 @@ struct Sum <: IR
9595 arg:: IR
9696end
9797
98- struct PartialSum <: IR
99- arg:: IR
100- dim:: Int
101- end
102-
10398function _get_variables (arg:: Mat )
10499 return _get_variables (arg. id)
105100end
@@ -399,18 +394,6 @@ function to_ir(arg::BinaryOperation{Mult})
399394 else
400395 return ir. Transpose (ir. Vec (ir. Literal (1 )))
401396 end
402- elseif arg. arg1 isa Literal || arg. arg2 isa Literal
403- tensor = if arg. arg1 isa Literal
404- arg. arg2
405- else
406- arg. arg1
407- end
408-
409- if typeof (target_indices[1 ]) == Upper
410- return ir. PartialSum (to_ir (tensor), 2 )
411- else
412- return ir. PartialSum (to_ir (tensor), 1 )
413- end
414397 end
415398 end
416399
Original file line number Diff line number Diff line change @@ -76,6 +76,12 @@ function to_julia(arg::ir.Quotient)
7676end
7777
7878function to_julia (arg:: ir.Product )
79+ if arg. l isa ir. Transpose && arg. l. arg isa ir. Vec && arg. l. arg. id isa ir. Literal
80+ return :($ (to_julia (arg. l. arg. id)) .* (sum ($ (to_julia (arg. r)), dims = 1 )))
81+ elseif arg. r isa ir. Vec && arg. r. id isa ir. Literal
82+ return :($ (to_julia (arg. r. id)) .* (sum ($ (to_julia (arg. l)), dims = 2 )))
83+ end
84+
7985 return :($ (to_julia (arg. l)) * $ (to_julia (arg. r)))
8086end
8187
Original file line number Diff line number Diff line change 200200function to_std_str (arg:: ir.Sum )
201201 return " sum(" * to_std_str (arg. arg) * " )"
202202end
203-
204- function to_std_str (arg:: ir.PartialSum )
205- if arg. dim == 1
206- product = ir. Product (ir. Transpose (ir. Vec (ir. Literal (1 ))), arg. arg) # Only used for dispatch
207- return " vec(1)ᵀ" * parenthesize (product, to_std_str, arg. arg)
208- elseif arg. dim == 2
209- product = ir. Product (arg. arg, ir. Vec (ir. Literal (1 )))
210- return parenthesize (product, to_std_str, arg. arg) * " vec(1)"
211- end
212-
213- throw (RuntimeError (" Encountered a sum over an unsupported index" ))
214- end
You can’t perform that action at this time.
0 commit comments