@@ -134,17 +134,20 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
134
134
static_assert (has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
135
135
" Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
136
136
137
+ auto const & lexpr = expr.left_expr ();
138
+ auto const & rexpr = expr.right_expr ();
139
+
137
140
if constexpr ( same_exp<T,EL> )
138
- return expr. el .extents ();
141
+ return lexpr .extents ();
139
142
140
143
else if constexpr ( same_exp<T,ER> )
141
- return expr. er .extents ();
144
+ return rexpr .extents ();
142
145
143
146
else if constexpr ( has_tensor_types_v<T,EL> )
144
- return retrieve_extents (expr. el );
147
+ return retrieve_extents (lexpr );
145
148
146
149
else if constexpr ( has_tensor_types_v<T,ER> )
147
- return retrieve_extents (expr. er );
150
+ return retrieve_extents (rexpr );
148
151
}
149
152
150
153
#ifdef _MSC_VER
@@ -164,12 +167,14 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
164
167
165
168
static_assert (has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
166
169
" Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
170
+
171
+ auto const & uexpr = expr.expr ();
167
172
168
173
if constexpr ( same_exp<T,E> )
169
- return expr. e .extents ();
174
+ return uexpr .extents ();
170
175
171
176
else if constexpr ( has_tensor_types_v<T,E> )
172
- return retrieve_extents (expr. e );
177
+ return retrieve_extents (uexpr );
173
178
}
174
179
175
180
} // namespace boost::numeric::ublas::detail
@@ -221,20 +226,23 @@ constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& exp
221
226
using ::operator ==;
222
227
using ::operator !=;
223
228
229
+ auto const & lexpr = expr.left_expr ();
230
+ auto const & rexpr = expr.right_expr ();
231
+
224
232
if constexpr ( same_exp<T,EL> )
225
- if (e != expr. el .extents ())
233
+ if (e != lexpr .extents ())
226
234
return false ;
227
235
228
236
if constexpr ( same_exp<T,ER> )
229
- if (e != expr. er .extents ())
237
+ if (e != rexpr .extents ())
230
238
return false ;
231
239
232
240
if constexpr ( has_tensor_types_v<T,EL> )
233
- if (!all_extents_equal (expr. el , e))
241
+ if (!all_extents_equal (lexpr , e))
234
242
return false ;
235
243
236
244
if constexpr ( has_tensor_types_v<T,ER> )
237
- if (!all_extents_equal (expr. er , e))
245
+ if (!all_extents_equal (rexpr , e))
238
246
return false ;
239
247
240
248
return true ;
@@ -250,12 +258,14 @@ constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, ex
250
258
251
259
using ::operator ==;
252
260
261
+ auto const & uexpr = expr.expr ();
262
+
253
263
if constexpr ( same_exp<T,E> )
254
- if (e != expr. e .extents ())
264
+ if (e != uexpr .extents ())
255
265
return false ;
256
266
257
267
if constexpr ( has_tensor_types_v<T,E> )
258
- if (!all_extents_equal (expr. e , e))
268
+ if (!all_extents_equal (uexpr , e))
259
269
return false ;
260
270
261
271
return true ;
@@ -281,9 +291,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
281
291
if (!all_extents_equal (expr, lhs.extents () ))
282
292
throw std::runtime_error (" Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes." );
283
293
284
- #pragma omp parallel for
294
+ auto const & rhs = cast_tensor_exression (expr);
295
+
296
+ #pragma omp parallel for
285
297
for (auto i = 0u ; i < lhs.size (); ++i)
286
- lhs (i) = expr () (i);
298
+ lhs (i) = rhs (i);
287
299
}
288
300
289
301
/* * @brief Evaluates expression for a tensor_core
@@ -310,9 +322,11 @@ inline void eval(tensor_type& lhs, tensor_expression<other_tensor_type, derived_
310
322
throw std::runtime_error (" Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes." );
311
323
}
312
324
325
+ auto const & rhs = cast_tensor_exression (expr);
326
+
313
327
#pragma omp parallel for
314
328
for (auto i = 0u ; i < lhs.size (); ++i)
315
- lhs (i) = expr () (i);
329
+ lhs (i) = rhs (i);
316
330
}
317
331
318
332
/* * @brief Evaluates expression for a tensor_core
@@ -330,9 +344,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
330
344
if (!all_extents_equal ( expr, lhs.extents () ))
331
345
throw std::runtime_error (" Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes." );
332
346
347
+ auto const & rhs = cast_tensor_exression (expr);
348
+
333
349
#pragma omp parallel for
334
350
for (auto i = 0u ; i < lhs.size (); ++i)
335
- fn (lhs (i), expr () (i));
351
+ fn (lhs (i), rhs (i));
336
352
}
337
353
338
354
@@ -347,7 +363,7 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
347
363
template <class tensor_type , class unary_fn >
348
364
inline void eval (tensor_type& lhs, unary_fn const & fn)
349
365
{
350
- #pragma omp parallel for
366
+ #pragma omp parallel for
351
367
for (auto i = 0u ; i < lhs.size (); ++i)
352
368
fn (lhs (i));
353
369
}
0 commit comments