Skip to content

Commit d68b1ed

Browse files
committed
Simplify 'evaluate' involving additive ops
1 parent 87ef47d commit d68b1ed

1 file changed

Lines changed: 61 additions & 80 deletions

File tree

src/forward.jl

Lines changed: 61 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -262,30 +262,31 @@ function evaluate(::Mult, arg1::Zero, arg2::UnaryOperation)
262262
return evaluate(Mult(), arg2, arg1)
263263
end
264264

265-
function evaluate(::Mult, arg1::UnaryOperation, arg2::Zero)
265+
# Assumes one argument is of type 'Zero'
266+
function _multiply_by_zero(arg1, arg2)
266267
free_indices = unique(eliminate_indices([get_indices(arg1); get_indices(arg2)]))
267268

268269
return Zero(free_indices...)
269270
end
270271

272+
function evaluate(::Mult, arg1::UnaryOperation, arg2::Zero)
273+
return _multiply_by_zero(arg1, arg2)
274+
end
275+
271276
function evaluate(::Mult, arg1::Zero, arg2::Tensor)
272277
return evaluate(Mult(), arg2, arg1)
273278
end
274279

275280
function evaluate(::Mult, arg1::Tensor, arg2::Zero)
276-
free_indices = unique(eliminate_indices([get_indices(arg1); get_indices(arg2)]))
277-
278-
return Zero(free_indices...)
281+
return _multiply_by_zero(arg1, arg2)
279282
end
280283

281284
function evaluate(::Mult, arg1::KrD, arg2::Zero)
282285
return evaluate(Mult(), arg2, arg1)
283286
end
284287

285288
function evaluate(::Mult, arg1::Zero, arg2::KrD)
286-
free_indices = unique(eliminate_indices([get_indices(arg1); get_indices(arg2)]))
287-
288-
return Zero(free_indices...)
289+
return _multiply_by_zero(arg1, arg2)
289290
end
290291

291292
function evaluate(::Mult, arg1::KrD, arg2::UnaryOperation)
@@ -337,10 +338,7 @@ function evaluate(::Mult, arg1::Zero, arg2::Power)
337338
end
338339

339340
function evaluate(::Mult, arg1::Power, arg2::Zero)
340-
new_indices =
341-
unique(eliminate_indices([get_free_indices(arg1); get_free_indices(arg2)]))
342-
343-
return Zero(new_indices...)
341+
return _multiply_by_zero(arg1, arg2)
344342
end
345343

346344
function evaluate(::Mult, arg1::Union{Variable,Literal}, arg2::KrD)
@@ -394,22 +392,30 @@ function _multiply_with_krd(arg1::Union{Variable,Literal,KrD}, arg2::KrD)
394392
return newarg
395393
end
396394

397-
function evaluate(
398-
::Mult,
399-
arg1::BinaryOperation{Op},
400-
arg2::Tensor,
401-
) where {Op<:AdditiveOperation}
395+
function evaluate(::Mult, arg1::BinaryOperation{Add}, arg2::Tensor)
396+
return evaluate(Mult(), arg2, arg1)
397+
end
398+
399+
function evaluate(::Mult, arg1::BinaryOperation{Sub}, arg2::Tensor)
402400
return evaluate(Mult(), arg2, arg1)
403401
end
404402

405-
function evaluate(
406-
::Mult,
407-
arg1::Tensor,
408-
arg2::BinaryOperation{Op},
409-
) where {Op<:AdditiveOperation}
403+
function evaluate(::Mult, arg1::Tensor, arg2::BinaryOperation{Add})
404+
if length(get_free_indices(arg1)) > 2 || length(get_free_indices(arg2)) > 2
405+
return evaluate(
406+
Add(),
407+
evaluate(Mult(), arg1, evaluate(arg2.arg1)),
408+
evaluate(Mult(), arg1, evaluate(arg2.arg2)),
409+
)
410+
end
411+
412+
return BinaryOperation{Mult}(arg1, arg2)
413+
end
414+
415+
function evaluate(::Mult, arg1::Tensor, arg2::BinaryOperation{Sub})
410416
if length(get_free_indices(arg1)) > 2 || length(get_free_indices(arg2)) > 2
411417
return evaluate(
412-
Op(),
418+
Sub(),
413419
evaluate(Mult(), arg1, evaluate(arg2.arg1)),
414420
evaluate(Mult(), arg1, evaluate(arg2.arg2)),
415421
)
@@ -418,63 +424,36 @@ function evaluate(
418424
return BinaryOperation{Mult}(arg1, arg2)
419425
end
420426

421-
function evaluate(
422-
::Mult,
423-
arg1::Union{KrD,Literal},
424-
arg2::BinaryOperation{Op},
425-
) where {Op<:AdditiveOperation}
427+
function evaluate(::Mult, arg1::Union{KrD,Literal}, arg2::BinaryOperation{Add})
426428
return evaluate(
427-
Op(),
429+
Add(),
428430
evaluate(Mult(), arg1, evaluate(arg2.arg1)),
429431
evaluate(Mult(), arg1, evaluate(arg2.arg2)),
430432
)
431433
end
432434

433-
function evaluate(
434-
::Mult,
435-
arg1::Zero,
436-
arg2::BinaryOperation{Op},
437-
) where {Op<:AdditiveOperation}
438-
return evaluate(Mult(), arg2, arg1)
435+
function evaluate(::Mult, arg1::Union{KrD,Literal}, arg2::BinaryOperation{Sub})
436+
return evaluate(
437+
Sub(),
438+
evaluate(Mult(), arg1, evaluate(arg2.arg1)),
439+
evaluate(Mult(), arg1, evaluate(arg2.arg2)),
440+
)
439441
end
440442

441-
function evaluate(
442-
::Mult,
443-
arg1::BinaryOperation{Op},
444-
arg2::Zero,
445-
) where {Op<:AdditiveOperation}
446-
free_indices = unique(eliminate_indices([get_indices(arg1); get_indices(arg2)]))
443+
function evaluate(::Mult, arg1::Zero, arg2::BinaryOperation{Add})
444+
return evaluate(Mult(), arg2, arg1)
445+
end
447446

448-
return Zero(free_indices...)
447+
function evaluate(::Mult, arg1::Zero, arg2::BinaryOperation{Sub})
448+
return evaluate(Mult(), arg2, arg1)
449449
end
450450

451-
# TODO: Why are these needed?
452-
function evaluate(
453-
::Mult,
454-
arg1::BinaryOperation{Op1},
455-
arg2::BinaryOperation{Op2},
456-
) where {Op1<:AdditiveOperation,Op2<:AdditiveOperation}
457-
return invoke(
458-
evaluate,
459-
Tuple{Mult,BinaryOperation{Op1},BinaryOperation{Op2}},
460-
Mult(),
461-
arg1,
462-
arg2,
463-
)
451+
function evaluate(::Mult, arg1::BinaryOperation{Add}, arg2::Zero)
452+
return _multiply_by_zero(arg1, arg2)
464453
end
465454

466-
function evaluate(
467-
::Mult,
468-
arg1::BinaryOperation{Op},
469-
arg2::BinaryOperation{Op},
470-
) where {Op<:AdditiveOperation}
471-
return invoke(
472-
evaluate,
473-
Tuple{Mult,BinaryOperation{Op},BinaryOperation{Op}},
474-
Mult(),
475-
arg1,
476-
arg2,
477-
)
455+
function evaluate(::Mult, arg1::BinaryOperation{Sub}, arg2::Zero)
456+
return _multiply_by_zero(arg1, arg2)
478457
end
479458

480459
function evaluate(::Mult, arg1::BinaryOperation{Add}, arg2::BinaryOperation{Add})
@@ -520,20 +499,24 @@ function evaluate(::Mult, arg1::BinaryOperation{Sub}, arg2::BinaryOperation{Sub}
520499
)
521500
end
522501

523-
function evaluate(
524-
::Mult,
525-
arg1::BinaryOperation{Op},
526-
arg2::Power,
527-
) where {Op<:AdditiveOperation}
502+
function evaluate(::Mult, arg1::BinaryOperation{Add}, arg2::Power)
528503
return evaluate(Mult(), arg2, arg1)
529504
end
530505

531-
function evaluate(
532-
::Mult,
533-
arg1::Power,
534-
arg2::BinaryOperation{Op},
535-
) where {Op<:AdditiveOperation}
536-
return BinaryOperation{Op}(
506+
function evaluate(::Mult, arg1::BinaryOperation{Sub}, arg2::Power)
507+
return evaluate(Mult(), arg2, arg1)
508+
end
509+
510+
511+
function evaluate(::Mult, arg1::Power, arg2::BinaryOperation{Add})
512+
return BinaryOperation{Add}(
513+
evaluate(Mult(), arg1, arg2.arg1),
514+
evaluate(Mult(), arg1, arg2.arg2),
515+
)
516+
end
517+
518+
function evaluate(::Mult, arg1::Power, arg2::BinaryOperation{Sub})
519+
return BinaryOperation{Sub}(
537520
evaluate(Mult(), arg1, arg2.arg1),
538521
evaluate(Mult(), arg1, arg2.arg2),
539522
)
@@ -561,9 +544,7 @@ function evaluate(::Mult, arg1::T, arg2::Tensor) where {T<:Real}
561544
end
562545

563546
function evaluate(::Mult, arg1::Zero, arg2::Zero)
564-
new_indices = eliminate_indices([get_free_indices(arg1); get_free_indices(arg2)])
565-
566-
return Zero(new_indices...)
547+
return _multiply_by_zero(arg1, arg2)
567548
end
568549

569550
function evaluate(::Mult, arg1::Zero, arg2::Real)

0 commit comments

Comments
 (0)