@@ -50,28 +50,6 @@ function get_diag_delta(arg)
5050 return nothing
5151end
5252
53- function reshape (term:: Variable , indices:: LowerOrUpperIndex... )
54- return Variable (term. id, indices... )
55- end
56-
57- function reshape (term:: UnaryOperation{Op} , indices:: LowerOrUpperIndex... ) where {Op}
58- return UnaryOperation {Op} (reshape (term. arg, indices... ))
59- end
60-
61- function reshape (
62- term:: BinaryOperation{Op} ,
63- indices:: LowerOrUpperIndex... ,
64- ) where {Op<: AdditiveOperation }
65- return BinaryOperation {Op} (
66- reshape (term. arg1, indices... ),
67- reshape (term. arg2, indices... ),
68- )
69- end
70-
71- function reshape (arg:: Power , indices:: LowerOrUpperIndex... )
72- return Power (reshape (arg. base, indices... ), arg. exponent)
73- end
74-
7553function get_last_letter (indices:: IndexList )
7654 current_last = Upper (0 )
7755
@@ -110,7 +88,15 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
11088 target_indices =
11189 eliminate_indices (vcat (get_free_indices (arg1), get_indices (arg2)))
11290 @assert length (target_indices) == 1
113- push! (reshaped, reshape (f, target_indices... ))
91+
92+ current_idx = intersect (free_ids, get_free_indices (d))
93+ f = update_index (
94+ f,
95+ only (current_idx),
96+ only (target_indices);
97+ allow_shape_change = true ,
98+ )
99+ push! (reshaped, f)
114100 end
115101 end
116102
@@ -159,14 +145,30 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Variable)
159145 push! (reshaped, f)
160146 elseif length (free_ids) == 1
161147 @assert length (target_indices) == 1
162- push! (reshaped, reshape (f, target_indices... ))
148+
149+ current_idx = intersect (free_ids, get_free_indices (d))
150+ f = update_index (
151+ f,
152+ only (current_idx),
153+ only (target_indices);
154+ allow_shape_change = true ,
155+ )
156+ push! (reshaped, f)
163157 else
164158 @assert false " Not implemented, please open an issue with your input"
165159 end
166160 end
167161
168- if length (get_free_indices (arg2)) == 1
169- push! (reshaped, reshape (arg2, target_indices... ))
162+ arg2_ids = get_free_indices (arg2)
163+
164+ if length (arg2_ids) == 1
165+ arg2 = update_index (
166+ arg2,
167+ only (arg2_ids),
168+ only (target_indices);
169+ allow_shape_change = true ,
170+ )
171+ push! (reshaped, arg2)
170172 else
171173 @assert false " Not implemented, please open an issue with your input"
172174 end
0 commit comments