Skip to content

Commit a128dfd

Browse files
refactor(expression): add function cast_tensor_expression for casting
This function casts any `tensor_expression` to its child class, and it also handles recursive casting to get the real expression that is stored inside the layers of `tensor_expression`.
1 parent d70a701 commit a128dfd

File tree

2 files changed

+64
-49
lines changed

2 files changed

+64
-49
lines changed

include/boost/numeric/ublas/tensor/expression.hpp

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,27 @@ static constexpr bool does_exp_need_cast_v = does_exp_need_cast< std::decay_t<T>
4343
template<typename E, typename T>
4444
struct does_exp_need_cast< tensor_expression<T,E> > : std::true_type{};
4545

46+
/**
47+
* @brief Safely casts a `tensor_expression` to its child, unwrapping nested
48+
* `tensor_expression` layers recursively. Without the cast, callers usually go on
49+
* to invoke `operator()` with an index argument, which the `tensor_expression`
50+
* class itself does not support, so an expression that is not cast properly
51+
* might produce an error.
52+
*
53+
* @tparam T type of the tensor
54+
* @tparam E type of the child stored inside tensor_expression
55+
* @param e tensor_expression that needs to be cast
56+
* @return child of tensor_expression that is not itself a tensor_expression
57+
*/
58+
template<typename T, typename E>
59+
constexpr auto const& cast_tensor_exression(tensor_expression<T,E> const& e) noexcept{ // NOTE(review): name is spelled "exression"; the callers in this diff use the same spelling
60+
auto const& res = e(); // unwrap one layer: e() is documented above as yielding the stored child
61+
if constexpr(does_exp_need_cast_v<decltype(res)>) // the trait decays its argument, so the const& in decltype(res) does not matter
62+
return cast_tensor_exression(res); // child is still a bare tensor_expression<T,E>: keep unwrapping
63+
else
64+
return res; // child is a concrete expression: stop here
65+
}
66+
4667
template<typename E, typename T>
4768
constexpr auto is_tensor_expression_impl(tensor_expression<T,E> const*) -> std::true_type;
4869

@@ -137,33 +158,15 @@ struct binary_tensor_expression
137158
binary_tensor_expression(const binary_tensor_expression& l) = delete;
138159
binary_tensor_expression& operator=(binary_tensor_expression const& l) noexcept = delete;
139160

161+
constexpr auto const& left_expr() const noexcept{ return cast_tensor_exression(el); } // left operand `el`, passed through cast_tensor_exression
162+
constexpr auto const& right_expr() const noexcept{ return cast_tensor_exression(er); } // right operand `er`, passed through cast_tensor_exression
140163

141164
[[nodiscard]] inline
142-
constexpr decltype(auto) operator()(size_type i) const
143-
requires (does_exp_need_cast_v<expression_type_left> && does_exp_need_cast_v<expression_type_right>)
144-
{
145-
return op(el()(i), er()(i));
146-
}
147-
148-
[[nodiscard]] inline
149-
constexpr decltype(auto) operator()(size_type i) const
150-
requires (does_exp_need_cast_v<expression_type_left> && !does_exp_need_cast_v<expression_type_right>)
151-
{
152-
return op(el()(i), er(i));
153-
}
154-
155-
[[nodiscard]] inline
156-
constexpr decltype(auto) operator()(size_type i) const
157-
requires (!does_exp_need_cast_v<expression_type_left> && does_exp_need_cast_v<expression_type_right>)
158-
{
159-
return op(el(i), er()(i));
160-
}
161-
162-
[[nodiscard]] inline
163-
constexpr decltype(auto) operator()(size_type i) const {
164-
return op(el(i), er(i));
165+
constexpr decltype(auto) operator()(size_type i) const {
166+
return op(left_expr()(i), right_expr()(i));
165167
}
166168

169+
private:
167170
expression_type_left el;
168171
expression_type_right er;
169172
binary_operation op;
@@ -211,19 +214,15 @@ struct unary_tensor_expression
211214
constexpr unary_tensor_expression() = delete;
212215
unary_tensor_expression(unary_tensor_expression const& l) = delete;
213216
unary_tensor_expression& operator=(unary_tensor_expression const& l) noexcept = delete;
214-
215-
[[nodiscard]] inline constexpr
216-
decltype(auto) operator()(size_type i) const
217-
requires does_exp_need_cast_v<expression_type>
218-
{
219-
return op(e()(i));
220-
}
217+
218+
constexpr auto const& expr() const noexcept{ return cast_tensor_exression(e); } // operand `e`, passed through cast_tensor_exression
221219

222220
[[nodiscard]] inline constexpr
223-
decltype(auto) operator()(size_type i) const {
224-
return op(e(i));
221+
decltype(auto) operator()(size_type i) const {
222+
return op(expr()(i));
225223
}
226224

225+
private:
227226
expression_type e;
228227
unary_operation op;
229228
};

include/boost/numeric/ublas/tensor/expression_evaluation.hpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,20 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
134134
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
135135
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
136136

137+
auto const& lexpr = expr.left_expr();
138+
auto const& rexpr = expr.right_expr();
139+
137140
if constexpr ( same_exp<T,EL> )
138-
return expr.el.extents();
141+
return lexpr.extents();
139142

140143
else if constexpr ( same_exp<T,ER> )
141-
return expr.er.extents();
144+
return rexpr.extents();
142145

143146
else if constexpr ( has_tensor_types_v<T,EL> )
144-
return retrieve_extents(expr.el);
147+
return retrieve_extents(lexpr);
145148

146149
else if constexpr ( has_tensor_types_v<T,ER> )
147-
return retrieve_extents(expr.er);
150+
return retrieve_extents(rexpr);
148151
}
149152

150153
#ifdef _MSC_VER
@@ -164,12 +167,14 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
164167

165168
static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
166169
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
170+
171+
auto const& uexpr = expr.expr();
167172

168173
if constexpr ( same_exp<T,E> )
169-
return expr.e.extents();
174+
return uexpr.extents();
170175

171176
else if constexpr ( has_tensor_types_v<T,E> )
172-
return retrieve_extents(expr.e);
177+
return retrieve_extents(uexpr);
173178
}
174179

175180
} // namespace boost::numeric::ublas::detail
@@ -221,20 +226,23 @@ constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& exp
221226
using ::operator==;
222227
using ::operator!=;
223228

229+
auto const& lexpr = expr.left_expr();
230+
auto const& rexpr = expr.right_expr();
231+
224232
if constexpr ( same_exp<T,EL> )
225-
if(e != expr.el.extents())
233+
if(e != lexpr.extents())
226234
return false;
227235

228236
if constexpr ( same_exp<T,ER> )
229-
if(e != expr.er.extents())
237+
if(e != rexpr.extents())
230238
return false;
231239

232240
if constexpr ( has_tensor_types_v<T,EL> )
233-
if(!all_extents_equal(expr.el, e))
241+
if(!all_extents_equal(lexpr, e))
234242
return false;
235243

236244
if constexpr ( has_tensor_types_v<T,ER> )
237-
if(!all_extents_equal(expr.er, e))
245+
if(!all_extents_equal(rexpr, e))
238246
return false;
239247

240248
return true;
@@ -250,12 +258,14 @@ constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, ex
250258

251259
using ::operator==;
252260

261+
auto const& uexpr = expr.expr();
262+
253263
if constexpr ( same_exp<T,E> )
254-
if(e != expr.e.extents())
264+
if(e != uexpr.extents())
255265
return false;
256266

257267
if constexpr ( has_tensor_types_v<T,E> )
258-
if(!all_extents_equal(expr.e, e))
268+
if(!all_extents_equal(uexpr, e))
259269
return false;
260270

261271
return true;
@@ -281,9 +291,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
281291
if(!all_extents_equal(expr, lhs.extents() ))
282292
throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");
283293

284-
#pragma omp parallel for
294+
auto const& rhs = cast_tensor_exression(expr);
295+
296+
#pragma omp parallel for
285297
for(auto i = 0u; i < lhs.size(); ++i)
286-
lhs(i) = expr()(i);
298+
lhs(i) = rhs(i);
287299
}
288300

289301
/** @brief Evaluates expression for a tensor_core
@@ -310,9 +322,11 @@ inline void eval(tensor_type& lhs, tensor_expression<other_tensor_type, derived_
310322
throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");
311323
}
312324

325+
auto const& rhs = cast_tensor_exression(expr);
326+
313327
#pragma omp parallel for
314328
for(auto i = 0u; i < lhs.size(); ++i)
315-
lhs(i) = expr()(i);
329+
lhs(i) = rhs(i);
316330
}
317331

318332
/** @brief Evaluates expression for a tensor_core
@@ -330,9 +344,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
330344
if(!all_extents_equal( expr, lhs.extents() ))
331345
throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");
332346

347+
auto const& rhs = cast_tensor_exression(expr);
348+
333349
#pragma omp parallel for
334350
for(auto i = 0u; i < lhs.size(); ++i)
335-
fn(lhs(i), expr()(i));
351+
fn(lhs(i), rhs(i));
336352
}
337353

338354

@@ -347,7 +363,7 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
347363
template<class tensor_type, class unary_fn>
348364
inline void eval(tensor_type& lhs, unary_fn const& fn)
349365
{
350-
#pragma omp parallel for
366+
#pragma omp parallel for
351367
for(auto i = 0u; i < lhs.size(); ++i)
352368
fn(lhs(i));
353369
}

0 commit comments

Comments
 (0)