Skip to content

Commit 549f8a0

Browse files
refactor(expression): improve error message and simplified if-else.
1 parent 00357a0 commit 549f8a0

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

include/boost/numeric/ublas/tensor/expression_evaluation.hpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -267,33 +267,41 @@ inline void eval(tensor_core<TensorEngine>& lhs, TensorExpression auto const& ex
267267
is_static_v< std::decay_t< decltype(retrieve_extents(expr)) > >
268268
)
269269
{
270-
using rtensor_t = typename std::decay_t<decltype(expr)>::tensor_type;
270+
auto const& rhs = cast_tensor_expression(expr);
271+
271272
using ltensor_t = tensor_core<TensorEngine>;
272273
using lvalue_type = typename ltensor_t::value_type;
273-
using rvalue_type = typename rtensor_t::value_type;
274274
using lextents_t = typename ltensor_t::extents_type;
275-
using rextents_t = typename rtensor_t::extents_type;
275+
using rvalue_type = std::decay_t< decltype(rhs(0)) >;
276+
using rextents_t = std::decay_t< decltype(retrieve_extents(expr)) >;
276277

277278
static_assert(std::is_same_v<lvalue_type, rvalue_type>,
278-
"boost::numeric::ublas::detail::eval(tensor_core<TensorEngine>&, TensorExpression auto const&) : "
279+
"boost::numeric::ublas::detail::eval(tensor_core& lhs, tensor_expresion const& rhs, BinaryFn&& fn) : "
279280
"both LHS and RHS tensors should have same value type"
280281
);
281282

282283
if constexpr(is_static_v<lextents_t> && is_static_v<rextents_t>){
283284
static_assert(std::is_same_v<lextents_t,rextents_t>,
284-
"boost::numeric::ublas::tensor_core: "
285+
"boost::numeric::ublas::detail::eval(tensor_core& lhs, tensor_expresion const& rhs, BinaryFn&& fn) : "
285286
"both LHS and RHS tensors should have same shape."
286287
);
287288
}else{
288289
if ( !all_extents_equal( expr, lhs.extents() ) ){
289-
throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");
290+
throw std::runtime_error(
291+
"boost::numeric::ublas::detail::eval(tensor_core& lhs, tensor_expresion const& rhs, BinaryFn&& fn) : "
292+
"both LHS and RHS tensors should have same shape."
293+
);
290294
}
291295
}
292296

293-
auto const& rhs = cast_tensor_expression(expr);
297+
auto const size = lhs.size();
294298

299+
/// FIXME: add 'simd' clause and 'if' clause that will be used as a starting point
300+
/// for threads to start, otherwise, it's very expansive to use threads for small
301+
/// sized containers.
302+
/// @code #pragma omp parallel for simd if(size > SOME_SIZE) @endcode
295303
#pragma omp parallel for
296-
for(auto i = 0u; i < lhs.size(); ++i)
304+
for(auto i = 0u; i < size; ++i)
297305
std::invoke(fn, lhs(i), rhs(i));
298306
}
299307

@@ -329,8 +337,14 @@ template<class TensorEngine, class UnaryFn>
329337
inline void eval(tensor_core<TensorEngine>& lhs, UnaryFn&& fn)
330338
noexcept( is_static_v< std::decay_t< decltype(retrieve_extents(lhs)) > > )
331339
{
340+
auto const size = lhs.size();
341+
342+
/// FIXME: add 'simd' clause and 'if' clause that will be used as a starting point
343+
/// for threads to start, otherwise, it's very expansive to use threads for small
344+
/// sized containers.
345+
/// @code #pragma omp parallel for simd if(size > SOME_SIZE) @endcode
332346
#pragma omp parallel for
333-
for(auto i = 0u; i < lhs.size(); ++i)
347+
for(auto i = 0u; i < size; ++i)
334348
std::invoke(fn, lhs(i));
335349
}
336350

include/boost/numeric/ublas/tensor/operators_comparison.hpp

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
7373
using rvalue_type = decltype(rexpr(0));
7474

7575
static_assert( same_exp< lvalue_type, rvalue_type >,
76-
"boost::numeric::ublas::detail::compare : "
76+
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
7777
"both LHS and RHS should have the same value type"
7878
);
7979

8080
static_assert(
8181
std::is_invocable_r_v<bool, BinaryPred, lvalue_type, rvalue_type>,
82-
"boost::numeric::ublas::detail::compare(lhs,rhs,pred) :"
83-
"predicate must be a binary predicate, and it must return a bool"
82+
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
83+
"the predicate must be a binary predicate, and it must return a bool"
8484
);
8585

8686
auto const& le = retrieve_extents(lexpr);
@@ -97,20 +97,16 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
9797
using rex_t = std::decay_t< decltype(re) >;
9898

9999
if constexpr(is_static_v< lex_t > && is_static_v< rex_t >){
100-
if constexpr(!same_exp< lex_t, rex_t >)
101-
return { false, size_type{} };
102-
103-
return { true, product_v< lex_t > };
100+
constexpr bool is_same = same_exp< lex_t, rex_t >;
101+
return { is_same, is_same ? product_v< lex_t > : size_type{} };
104102
} else {
105-
if(::operator!=(le,re))
106-
return { false, size_type{} };
107-
108-
return { true, product( le ) };
103+
bool const is_same = ::operator==(le,re);
104+
return { is_same, is_same ? product( le ) : size_type{} };
109105
}
110106
};
111107

112108
auto const [status, size] = cal_size(le, re);
113-
109+
114110
for(auto i = size_type{}; i < size; ++i){
115111
if(!std::invoke(pred, lexpr(i), rexpr(i)))
116112
return false;
@@ -136,14 +132,14 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
136132
using rvalue_type = decltype(rexpr(0));
137133

138134
static_assert( same_exp< lvalue_type, rvalue_type >,
139-
"boost::numeric::ublas::detail::compare : "
135+
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
140136
"both LHS and RHS should have the same value type"
141137
);
142138

143139
static_assert(
144140
std::is_invocable_r_v<bool, BinaryPred, lvalue_type, rvalue_type>,
145-
"boost::numeric::ublas::detail::compare(lhs,rhs,pred) :"
146-
"predicate must be a binary predicate, and it must return a bool"
141+
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
142+
"the predicate must be a binary predicate, and it must return a bool"
147143
);
148144

149145
auto const& le = retrieve_extents(lexpr);
@@ -160,15 +156,15 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
160156

161157
if constexpr(is_static_v< lex_t > && is_static_v< rex_t >){
162158
static_assert(same_exp< lex_t, rex_t >,
163-
"boost::numeric::ublas::detail::compare : "
159+
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
164160
"cannot compare tensors with different shapes."
165161
);
166162

167163
return product_v< lex_t >;
168164
}else{
169165
if(::operator!=(le,re)){
170166
throw std::runtime_error(
171-
"boost::numeric::ublas::detail::compare : "
167+
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
172168
"cannot compare tensors with different shapes."
173169
);
174170
}
@@ -195,27 +191,17 @@ constexpr bool compare(tensor_expression<T,D> const& expr, UnaryPred&& pred) noe
195191
auto const& ue = cast_tensor_expression(expr);
196192
auto const& e = retrieve_extents(ue);
197193

198-
using size_type = typename T::size_type;
194+
using size_type = typename T::size_type;
195+
using value_type = decltype(ue(0));
196+
using extents_t = std::decay_t< decltype(e) >;
199197

200198
static_assert(
201-
std::is_invocable_r_v<bool, UnaryPred, decltype(ue(0))>,
202-
"boost::numeric::ublas::detail::compare(expr,pred) :"
203-
"predicate must be an unary predicate, and it must return a bool"
199+
std::is_invocable_r_v<bool, UnaryPred, value_type>,
200+
"boost::numeric::ublas::detail::compare(tensor_expresion const& expr, UnaryPred&& pred) : "
201+
"the predicate must be an unary predicate, and it must return a bool"
204202
);
205203

206-
// returns the size of the container
207-
constexpr auto cal_size = [](auto const& e)
208-
-> size_type
209-
{
210-
using extents_t = std::decay_t< decltype(e) >;
211-
212-
if constexpr(is_static_v< extents_t >)
213-
return product_v< extents_t >;
214-
else
215-
return product( e );
216-
};
217-
218-
size_type const size = cal_size(e);
204+
size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product( e );;
219205

220206
for(auto i = size_type{}; i < size; ++i){
221207
if(!std::invoke(pred, ue(i)))

0 commit comments

Comments
 (0)