Skip to content

Commit 23ba8fc

Browse files
refactor(compare): combine two compare function into one for easier maintainability
1 parent 549f8a0 commit 23ba8fc

File tree

1 file changed

+49
-82
lines changed

1 file changed

+49
-82
lines changed

include/boost/numeric/ublas/tensor/operators_comparison.hpp

+49-82
Original file line numberDiff line numberDiff line change
@@ -60,70 +60,62 @@ struct is_equality_functional_object< std::not_equal_to<> >
6060
: std::true_type
6161
{};
6262

63-
template<class T1, class T2, class L, class R, class BinaryPred>
64-
[[nodiscard]] inline
65-
constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,R> const& rhs, BinaryPred&& pred) noexcept
66-
requires is_equality_functional_object_v<BinaryPred>
63+
template<integral SizeType, typename LE, typename RE>
64+
[[nodiscard]]
65+
constexpr auto compare_helper(LE const& le, RE const& re, std::true_type /*unused*/) noexcept
66+
-> std::pair<bool, SizeType>
6767
{
68-
69-
auto const& lexpr = cast_tensor_expression(lhs);
70-
auto const& rexpr = cast_tensor_expression(rhs);
71-
72-
using lvalue_type = decltype(lexpr(0));
73-
using rvalue_type = decltype(rexpr(0));
74-
75-
static_assert( same_exp< lvalue_type, rvalue_type >,
76-
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
77-
"both LHS and RHS should have the same value type"
78-
);
79-
80-
static_assert(
81-
std::is_invocable_r_v<bool, BinaryPred, lvalue_type, rvalue_type>,
82-
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
83-
"the predicate must be a binary predicate, and it must return a bool"
84-
);
85-
86-
auto const& le = retrieve_extents(lexpr);
87-
auto const& re = retrieve_extents(rexpr);
88-
89-
using size_type = typename T1::size_type;
68+
using ::operator==;
69+
70+
constexpr auto zero = SizeType{};
71+
72+
if constexpr( is_static_v< LE > && is_static_v< RE > ){
73+
constexpr bool is_same = std::is_same_v<LE, RE>;
74+
constexpr SizeType size = ( is_same ? SizeType{ product_v< LE > } : zero );
75+
return { is_same, size };
76+
}else{
77+
bool const is_same = ( le == re );
78+
SizeType const size = ( is_same ? SizeType{ product(le) } : zero );
79+
return { is_same, size };
80+
}
81+
}
9082

91-
// returns the pair containing false if extents are not equal
92-
// else true, and the size of the container.
93-
constexpr auto cal_size = [](auto const& le, auto const& re)
94-
-> std::pair<bool, size_type>
95-
{
96-
using lex_t = std::decay_t< decltype(le) >;
97-
using rex_t = std::decay_t< decltype(re) >;
98-
99-
if constexpr(is_static_v< lex_t > && is_static_v< rex_t >){
100-
constexpr bool is_same = same_exp< lex_t, rex_t >;
101-
return { is_same, is_same ? product_v< lex_t > : size_type{} };
102-
} else {
103-
bool const is_same = ::operator==(le,re);
104-
return { is_same, is_same ? product( le ) : size_type{} };
83+
template<typename SizeType, typename LE, typename RE>
84+
[[nodiscard]]
85+
constexpr auto compare_helper(LE const& le, RE const& re, std::false_type /*unused*/)
86+
noexcept( is_static_v< LE> && is_static_v< RE > ) -> std::pair<bool, SizeType>
87+
{
88+
using ::operator!=;
89+
90+
if constexpr( is_static_v< LE > && is_static_v< RE > ){
91+
static_assert(std::is_same_v< LE, RE >,
92+
"boost::numeric::ublas::detail::compare_helper(Lextents const& lhs, Rextents const& rhs) : "
93+
"cannot compare tensors with different shapes."
94+
);
95+
96+
constexpr SizeType size = product_v< LE >;
97+
return { true, size };
98+
}else{
99+
if(le != re){
100+
throw std::runtime_error(
101+
"boost::numeric::ublas::detail::compare_helper(Lextents const& lhs, Rextents const& rhs) : "
102+
"cannot compare tensors with different shapes."
103+
);
105104
}
106-
};
107105

108-
auto const [status, size] = cal_size(le, re);
109-
110-
for(auto i = size_type{}; i < size; ++i){
111-
if(!std::invoke(pred, lexpr(i), rexpr(i)))
112-
return false;
106+
SizeType const size = product( le );
107+
return { true, size };
113108
}
114-
115-
// return false if the status is false
116-
return ( true & status );
117109
}
118110

119111
template<class T1, class T2, class L, class R, class BinaryPred>
120112
[[nodiscard]] inline
121113
constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,R> const& rhs, BinaryPred&& pred)
122114
noexcept(
123-
is_static_v< std::decay_t< decltype(retrieve_extents(lhs)) > > &&
124-
is_static_v< std::decay_t< decltype(retrieve_extents(rhs)) > >
115+
( is_static_v< std::decay_t< decltype(retrieve_extents(lhs)) > > &&
116+
is_static_v< std::decay_t< decltype(retrieve_extents(rhs)) > >
117+
) || is_equality_functional_object_v<BinaryPred>
125118
)
126-
requires ( not is_equality_functional_object_v<BinaryPred> )
127119
{
128120
auto const& lexpr = cast_tensor_expression(lhs);
129121
auto const& rexpr = cast_tensor_expression(rhs);
@@ -146,41 +138,16 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
146138
auto const& re = retrieve_extents(rexpr);
147139

148140
using size_type = typename T1::size_type;
141+
using is_eq_t = std::conditional_t< is_equality_functional_object_v<BinaryPred>, std::true_type, std::false_type >;
149142

150-
// returns the size of the container
151-
constexpr auto cal_size = [](auto const& le, auto const& re)
152-
-> size_type
153-
{
154-
using lex_t = std::decay_t< decltype(le) >;
155-
using rex_t = std::decay_t< decltype(re) >;
156-
157-
if constexpr(is_static_v< lex_t > && is_static_v< rex_t >){
158-
static_assert(same_exp< lex_t, rex_t >,
159-
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
160-
"cannot compare tensors with different shapes."
161-
);
162-
163-
return product_v< lex_t >;
164-
}else{
165-
if(::operator!=(le,re)){
166-
throw std::runtime_error(
167-
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
168-
"cannot compare tensors with different shapes."
169-
);
170-
}
171-
172-
return product( le );
173-
}
174-
};
175-
176-
size_type const size = cal_size(le, re);
177-
143+
auto const [status, size] = compare_helper<size_type>(le, re, is_eq_t{});
144+
178145
for(auto i = size_type{}; i < size; ++i){
179146
if(!std::invoke(pred, lexpr(i), rexpr(i)))
180147
return false;
181148
}
182149

183-
return true;
150+
return status;
184151
}
185152

186153

@@ -201,7 +168,7 @@ constexpr bool compare(tensor_expression<T,D> const& expr, UnaryPred&& pred) noe
201168
"the predicate must be an unary predicate, and it must return a bool"
202169
);
203170

204-
size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product( e );;
171+
size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product( e );
205172

206173
for(auto i = size_type{}; i < size; ++i){
207174
if(!std::invoke(pred, ue(i)))

0 commit comments

Comments
 (0)