Skip to content

Commit 6392a91

Browse files
committed
New is_diag
1 parent 8ea52f6 commit 6392a91

2 files changed

Lines changed: 39 additions & 38 deletions

File tree

src/forward.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,43 @@ function is_elementwise_multiplication(arg1, arg2)
136136
return !isempty(indices_in_common(arg1, arg2))
137137
end
138138

139+
140+
function is_diag(arg::BinaryOperation{Mult})
141+
return is_diag(arg.arg1, arg.arg2)
142+
end
143+
144+
function is_diag(arg)
145+
return false
146+
end
147+
148+
function is_diag(arg1::KrD, arg2::TensorExpr)
149+
return is_diag(arg2, arg1)
150+
end
151+
152+
function is_diag(arg1::KrD, arg2::KrD)
153+
return false
154+
end
155+
156+
function is_diag(arg::Union{Tensor,KrD,Zero})
157+
return false
158+
end
159+
160+
function is_diag(arg1::TensorExpr, arg2::KrD)
161+
arg1_indices, arg2_indices = get_free_indices.((arg1, arg2))
162+
163+
return length(arg1_indices) == 1 && !isempty(intersect(arg1_indices, arg2_indices))
164+
end
165+
166+
function is_diag(arg1::Value, arg2::Value)
167+
if isempty(get_free_indices(arg1))
168+
return is_diag(arg2)
169+
elseif isempty(get_free_indices(arg2))
170+
return is_diag(arg1)
171+
end
172+
173+
return is_diag(arg1) || is_diag(arg2)
174+
end
175+
139176
function evaluate(::Mult, arg1::Tensor, arg2::BinaryOperation{Mult})
140177
return evaluate(Mult(), arg2, arg1)
141178
end
@@ -147,7 +184,7 @@ function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
147184
contracting_indices = eliminated_indices([arg1_indices; arg2_indices])
148185

149186
if is_elementwise &&
150-
is_diag2(arg1) &&
187+
is_diag(arg1) &&
151188
!isempty(contracting_indices) &&
152189
length(arg2_indices) == 1
153190
new_index = setdiff(arg1_indices, contracting_indices)

src/simplify.jl

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -43,42 +43,6 @@ function elementwise_indices(arg1, arg2)
4343
return intersect(arg1_indices, arg2_indices)
4444
end
4545

46-
function is_diag2(arg::BinaryOperation{Mult})
47-
return is_diag2(arg.arg1, arg.arg2)
48-
end
49-
50-
function is_diag2(arg)
51-
return false
52-
end
53-
54-
function is_diag2(arg1::KrD, arg2::TensorExpr)
55-
return is_diag2(arg2, arg1)
56-
end
57-
58-
function is_diag2(arg1::KrD, arg2::KrD)
59-
return false
60-
end
61-
62-
function is_diag2(arg::Union{Tensor,KrD,Zero})
63-
return false
64-
end
65-
66-
function is_diag2(arg1::TensorExpr, arg2::KrD)
67-
arg1_indices, arg2_indices = get_free_indices.((arg1, arg2))
68-
69-
return length(arg1_indices) == 1 && !isempty(intersect(arg1_indices, arg2_indices))
70-
end
71-
72-
function is_diag2(arg1::Value, arg2::Value)
73-
if isempty(get_free_indices(arg1))
74-
return is_diag2(arg2)
75-
elseif isempty(get_free_indices(arg2))
76-
return is_diag2(arg1)
77-
end
78-
79-
return is_diag2(arg1) || is_diag2(arg2)
80-
end
81-
8246
function get_diag_delta(arg::BinaryOperation{Mult})
8347
l = get_diag_delta(arg.arg1)
8448
r = get_diag_delta(arg.arg2)
@@ -130,7 +94,7 @@ function get_last_letter(indices::IndexList)
13094
end
13195

13296
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
133-
if is_diag2(arg1)
97+
if is_diag(arg1)
13498
d = get_diag_delta(arg1)
13599

136100
@assert !isnothing(d)

0 commit comments

Comments
 (0)