@@ -262,30 +262,31 @@ function evaluate(::Mult, arg1::Zero, arg2::UnaryOperation)
262262 return evaluate (Mult (), arg2, arg1)
263263end
264264
265- function evaluate (:: Mult , arg1:: UnaryOperation , arg2:: Zero )
265+ # Assumes one argument is of type 'Zero'
266+ function _multiply_by_zero (arg1, arg2)
266267 free_indices = unique (eliminate_indices ([get_indices (arg1); get_indices (arg2)]))
267268
268269 return Zero (free_indices... )
269270end
270271
272+ function evaluate (:: Mult , arg1:: UnaryOperation , arg2:: Zero )
273+ return _multiply_by_zero (arg1, arg2)
274+ end
275+
271276function evaluate (:: Mult , arg1:: Zero , arg2:: Tensor )
272277 return evaluate (Mult (), arg2, arg1)
273278end
274279
275280function evaluate (:: Mult , arg1:: Tensor , arg2:: Zero )
276- free_indices = unique (eliminate_indices ([get_indices (arg1); get_indices (arg2)]))
277-
278- return Zero (free_indices... )
281+ return _multiply_by_zero (arg1, arg2)
279282end
280283
281284function evaluate (:: Mult , arg1:: KrD , arg2:: Zero )
282285 return evaluate (Mult (), arg2, arg1)
283286end
284287
285288function evaluate (:: Mult , arg1:: Zero , arg2:: KrD )
286- free_indices = unique (eliminate_indices ([get_indices (arg1); get_indices (arg2)]))
287-
288- return Zero (free_indices... )
289+ return _multiply_by_zero (arg1, arg2)
289290end
290291
291292function evaluate (:: Mult , arg1:: KrD , arg2:: UnaryOperation )
@@ -337,10 +338,7 @@ function evaluate(::Mult, arg1::Zero, arg2::Power)
337338end
338339
339340function evaluate (:: Mult , arg1:: Power , arg2:: Zero )
340- new_indices =
341- unique (eliminate_indices ([get_free_indices (arg1); get_free_indices (arg2)]))
342-
343- return Zero (new_indices... )
341+ return _multiply_by_zero (arg1, arg2)
344342end
345343
346344function evaluate (:: Mult , arg1:: Union{Variable,Literal} , arg2:: KrD )
@@ -394,22 +392,30 @@ function _multiply_with_krd(arg1::Union{Variable,Literal,KrD}, arg2::KrD)
394392 return newarg
395393end
396394
397- function evaluate (
398- :: Mult ,
399- arg1 :: BinaryOperation{Op} ,
400- arg2 :: Tensor ,
401- ) where {Op <: AdditiveOperation }
395+ function evaluate (:: Mult , arg1 :: BinaryOperation{Add} , arg2 :: Tensor )
396+ return evaluate ( Mult (), arg2, arg1)
397+ end
398+
399+ function evaluate ( :: Mult , arg1 :: BinaryOperation{Sub} , arg2 :: Tensor )
402400 return evaluate (Mult (), arg2, arg1)
403401end
404402
405- function evaluate (
406- :: Mult ,
407- arg1:: Tensor ,
408- arg2:: BinaryOperation{Op} ,
409- ) where {Op<: AdditiveOperation }
403+ function evaluate (:: Mult , arg1:: Tensor , arg2:: BinaryOperation{Add} )
404+ if length (get_free_indices (arg1)) > 2 || length (get_free_indices (arg2)) > 2
405+ return evaluate (
406+ Add (),
407+ evaluate (Mult (), arg1, evaluate (arg2. arg1)),
408+ evaluate (Mult (), arg1, evaluate (arg2. arg2)),
409+ )
410+ end
411+
412+ return BinaryOperation {Mult} (arg1, arg2)
413+ end
414+
415+ function evaluate (:: Mult , arg1:: Tensor , arg2:: BinaryOperation{Sub} )
410416 if length (get_free_indices (arg1)) > 2 || length (get_free_indices (arg2)) > 2
411417 return evaluate (
412- Op (),
418+ Sub (),
413419 evaluate (Mult (), arg1, evaluate (arg2. arg1)),
414420 evaluate (Mult (), arg1, evaluate (arg2. arg2)),
415421 )
@@ -418,63 +424,36 @@ function evaluate(
418424 return BinaryOperation {Mult} (arg1, arg2)
419425end
420426
421- function evaluate (
422- :: Mult ,
423- arg1:: Union{KrD,Literal} ,
424- arg2:: BinaryOperation{Op} ,
425- ) where {Op<: AdditiveOperation }
427+ function evaluate (:: Mult , arg1:: Union{KrD,Literal} , arg2:: BinaryOperation{Add} )
426428 return evaluate (
427- Op (),
429+ Add (),
428430 evaluate (Mult (), arg1, evaluate (arg2. arg1)),
429431 evaluate (Mult (), arg1, evaluate (arg2. arg2)),
430432 )
431433end
432434
433- function evaluate (
434- :: Mult ,
435- arg1 :: Zero ,
436- arg2:: BinaryOperation{Op} ,
437- ) where {Op <: AdditiveOperation }
438- return evaluate ( Mult (), arg2, arg1 )
435+ function evaluate (:: Mult , arg1 :: Union{KrD,Literal} , arg2 :: BinaryOperation{Sub} )
436+ return evaluate (
437+ Sub () ,
438+ evaluate ( Mult (), arg1, evaluate ( arg2. arg1)) ,
439+ evaluate ( Mult (), arg1, evaluate (arg2 . arg2)),
440+ )
439441end
440442
441- function evaluate (
442- :: Mult ,
443- arg1:: BinaryOperation{Op} ,
444- arg2:: Zero ,
445- ) where {Op<: AdditiveOperation }
446- free_indices = unique (eliminate_indices ([get_indices (arg1); get_indices (arg2)]))
443+ function evaluate (:: Mult , arg1:: Zero , arg2:: BinaryOperation{Add} )
444+ return evaluate (Mult (), arg2, arg1)
445+ end
447446
448- return Zero (free_indices... )
447+ function evaluate (:: Mult , arg1:: Zero , arg2:: BinaryOperation{Sub} )
448+ return evaluate (Mult (), arg2, arg1)
449449end
450450
451- # TODO : Why are these needed?
452- function evaluate (
453- :: Mult ,
454- arg1:: BinaryOperation{Op1} ,
455- arg2:: BinaryOperation{Op2} ,
456- ) where {Op1<: AdditiveOperation ,Op2<: AdditiveOperation }
457- return invoke (
458- evaluate,
459- Tuple{Mult,BinaryOperation{Op1},BinaryOperation{Op2}},
460- Mult (),
461- arg1,
462- arg2,
463- )
451+ function evaluate (:: Mult , arg1:: BinaryOperation{Add} , arg2:: Zero )
452+ return _multiply_by_zero (arg1, arg2)
464453end
465454
466- function evaluate (
467- :: Mult ,
468- arg1:: BinaryOperation{Op} ,
469- arg2:: BinaryOperation{Op} ,
470- ) where {Op<: AdditiveOperation }
471- return invoke (
472- evaluate,
473- Tuple{Mult,BinaryOperation{Op},BinaryOperation{Op}},
474- Mult (),
475- arg1,
476- arg2,
477- )
455+ function evaluate (:: Mult , arg1:: BinaryOperation{Sub} , arg2:: Zero )
456+ return _multiply_by_zero (arg1, arg2)
478457end
479458
480459function evaluate (:: Mult , arg1:: BinaryOperation{Add} , arg2:: BinaryOperation{Add} )
@@ -520,20 +499,24 @@ function evaluate(::Mult, arg1::BinaryOperation{Sub}, arg2::BinaryOperation{Sub}
520499 )
521500end
522501
523- function evaluate (
524- :: Mult ,
525- arg1:: BinaryOperation{Op} ,
526- arg2:: Power ,
527- ) where {Op<: AdditiveOperation }
502+ function evaluate (:: Mult , arg1:: BinaryOperation{Add} , arg2:: Power )
528503 return evaluate (Mult (), arg2, arg1)
529504end
530505
531- function evaluate (
532- :: Mult ,
533- arg1:: Power ,
534- arg2:: BinaryOperation{Op} ,
535- ) where {Op<: AdditiveOperation }
536- return BinaryOperation {Op} (
506+ function evaluate (:: Mult , arg1:: BinaryOperation{Sub} , arg2:: Power )
507+ return evaluate (Mult (), arg2, arg1)
508+ end
509+
510+
511+ function evaluate (:: Mult , arg1:: Power , arg2:: BinaryOperation{Add} )
512+ return BinaryOperation {Add} (
513+ evaluate (Mult (), arg1, arg2. arg1),
514+ evaluate (Mult (), arg1, arg2. arg2),
515+ )
516+ end
517+
518+ function evaluate (:: Mult , arg1:: Power , arg2:: BinaryOperation{Sub} )
519+ return BinaryOperation {Sub} (
537520 evaluate (Mult (), arg1, arg2. arg1),
538521 evaluate (Mult (), arg1, arg2. arg2),
539522 )
@@ -561,9 +544,7 @@ function evaluate(::Mult, arg1::T, arg2::Tensor) where {T<:Real}
561544end
562545
563546function evaluate (:: Mult , arg1:: Zero , arg2:: Zero )
564- new_indices = eliminate_indices ([get_free_indices (arg1); get_free_indices (arg2)])
565-
566- return Zero (new_indices... )
547+ return _multiply_by_zero (arg1, arg2)
567548end
568549
569550function evaluate (:: Mult , arg1:: Zero , arg2:: Real )
0 commit comments