Skip to content

Commit c62699a

Browse files
committed
Remove recursive adjoint
1 parent bf103d0 commit c62699a

1 file changed

Lines changed: 11 additions & 48 deletions

File tree

src/std.jl

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -586,15 +586,15 @@ function to_standard(arg::BinaryOperation{Op}) where {Op<:AdditiveOperation}
586586
if isempty(setdiff(get_free_indices(l), get_free_indices(r))) &&
587587
isempty(setdiff(get_free_indices(l), target_indices))
588588
return BinaryOperation{Op}(l, r)
589-
elseif isempty(setdiff(get_free_indices(radjoint(l)), get_free_indices(r))) &&
590-
isempty(setdiff(get_free_indices(radjoint(l)), target_indices))
591-
return BinaryOperation{Op}(radjoint(l), r)
592-
elseif isempty(setdiff(get_free_indices(l), get_free_indices(radjoint(r)))) &&
589+
elseif isempty(setdiff(get_free_indices(adjoint(l)), get_free_indices(r))) &&
590+
isempty(setdiff(get_free_indices(adjoint(l)), target_indices))
591+
return BinaryOperation{Op}(adjoint(l), r)
592+
elseif isempty(setdiff(get_free_indices(l), get_free_indices(adjoint(r)))) &&
593593
isempty(setdiff(get_free_indices(l), target_indices))
594-
return BinaryOperation{Op}(l, radjoint(r))
595-
elseif isempty(setdiff(get_free_indices(radjoint(l)), get_free_indices(radjoint(r)))) &&
596-
isempty(setdiff(get_free_indices(radjoint(l)), target_indices))
597-
return BinaryOperation{Op}(radjoint(l), radjoint(r))
594+
return BinaryOperation{Op}(l, adjoint(r))
595+
elseif isempty(setdiff(get_free_indices(adjoint(l)), get_free_indices(adjoint(r)))) &&
596+
isempty(setdiff(get_free_indices(adjoint(l)), target_indices))
597+
return BinaryOperation{Op}(adjoint(l), adjoint(r))
598598
end
599599

600600
throw_not_std(arg)
@@ -604,43 +604,6 @@ function to_standard(arg::Real)
604604
return arg
605605
end
606606

607-
# Recursive adjoint
608-
function radjoint(arg::T) where {T<:UnaryOperation}
609-
return T(arg.arg')
610-
end
611-
612-
function radjoint(arg::BinaryOperation{Pow})
613-
return BinaryOperation{Pow}(radjoint(arg.arg1), arg.arg2)
614-
end
615-
616-
function radjoint(arg::BinaryOperation{Mult})
617-
return BinaryOperation{Mult}(radjoint(arg.arg1), radjoint(arg.arg2))
618-
end
619-
620-
function radjoint(arg::BinaryOperation{Op}) where {Op<:AdditiveOperation}
621-
return BinaryOperation{Op}(radjoint(arg.arg1), radjoint(arg.arg2))
622-
end
623-
624-
function radjoint(arg::Monomial)
625-
indices = get_indices(arg)
626-
627-
if length(indices) > 2
628-
throw(DomainError(arg.id, "Adjoint is only defined for vectors and matrices"))
629-
end
630-
631-
return Monomial(arg.id, flip.(indices)...)
632-
end
633-
634-
function radjoint(arg::Union{KrD,Zero})
635-
indices = get_indices(arg)
636-
637-
if length(indices) > 2
638-
throw(DomainError(arg.id, "Adjoint is only defined for vectors and matrices"))
639-
end
640-
641-
return typeof(arg)(flip.(indices)...)
642-
end
643-
644607
function to_standard(arg::BinaryOperation{Mult})
645608
target_indices = unique(get_free_indices(arg))
646609
target_len = length(target_indices)
@@ -657,17 +620,17 @@ function to_standard(arg::BinaryOperation{Mult})
657620
return attempt
658621
end
659622

660-
attempt = BinaryOperation{Mult}(radjoint(l), r)
623+
attempt = BinaryOperation{Mult}(adjoint(l), r)
661624
if get_free_indices(attempt) == target_indices
662625
return attempt
663626
end
664627

665-
attempt = BinaryOperation{Mult}(l, radjoint(r))
628+
attempt = BinaryOperation{Mult}(l, adjoint(r))
666629
if get_free_indices(attempt) == target_indices
667630
return attempt
668631
end
669632

670-
attempt = BinaryOperation{Mult}(radjoint(l), radjoint(r))
633+
attempt = BinaryOperation{Mult}(adjoint(l), adjoint(r))
671634
if get_free_indices(attempt) == target_indices
672635
return attempt
673636
end

0 commit comments

Comments
 (0)