@@ -73,14 +73,14 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
73
73
using rvalue_type = decltype (rexpr (0 ));
74
74
75
75
static_assert ( same_exp< lvalue_type, rvalue_type >,
76
- " boost::numeric::ublas::detail::compare : "
76
+ " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
77
77
" both LHS and RHS should have the same value type"
78
78
);
79
79
80
80
static_assert (
81
81
std::is_invocable_r_v<bool , BinaryPred, lvalue_type, rvalue_type>,
82
- " boost::numeric::ublas::detail::compare(lhs,rhs,pred) :"
83
- " predicate must be a binary predicate, and it must return a bool"
82
+ " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
83
+ " the predicate must be a binary predicate, and it must return a bool"
84
84
);
85
85
86
86
auto const & le = retrieve_extents (lexpr);
@@ -97,20 +97,16 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
97
97
using rex_t = std::decay_t < decltype (re) >;
98
98
99
99
if constexpr (is_static_v< lex_t > && is_static_v< rex_t >){
100
- if constexpr (!same_exp< lex_t , rex_t >)
101
- return { false , size_type{} };
102
-
103
- return { true , product_v< lex_t > };
100
+ constexpr bool is_same = same_exp< lex_t , rex_t >;
101
+ return { is_same, is_same ? product_v< lex_t > : size_type{} };
104
102
} else {
105
- if (::operator !=(le,re))
106
- return { false , size_type{} };
107
-
108
- return { true , product ( le ) };
103
+ bool const is_same = ::operator ==(le,re);
104
+ return { is_same, is_same ? product ( le ) : size_type{} };
109
105
}
110
106
};
111
107
112
108
auto const [status, size] = cal_size (le, re);
113
-
109
+
114
110
for (auto i = size_type{}; i < size; ++i){
115
111
if (!std::invoke (pred, lexpr (i), rexpr (i)))
116
112
return false ;
@@ -136,14 +132,14 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
136
132
using rvalue_type = decltype (rexpr (0 ));
137
133
138
134
static_assert ( same_exp< lvalue_type, rvalue_type >,
139
- " boost::numeric::ublas::detail::compare : "
135
+ " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
140
136
" both LHS and RHS should have the same value type"
141
137
);
142
138
143
139
static_assert (
144
140
std::is_invocable_r_v<bool , BinaryPred, lvalue_type, rvalue_type>,
145
- " boost::numeric::ublas::detail::compare(lhs,rhs,pred) :"
146
- " predicate must be a binary predicate, and it must return a bool"
141
+ " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
142
+ " the predicate must be a binary predicate, and it must return a bool"
147
143
);
148
144
149
145
auto const & le = retrieve_extents (lexpr);
@@ -160,15 +156,15 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
160
156
161
157
if constexpr (is_static_v< lex_t > && is_static_v< rex_t >){
162
158
static_assert (same_exp< lex_t , rex_t >,
163
- " boost::numeric::ublas::detail::compare : "
159
+ " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
164
160
" cannot compare tensors with different shapes."
165
161
);
166
162
167
163
return product_v< lex_t >;
168
164
}else {
169
165
if (::operator !=(le,re)){
170
166
throw std::runtime_error (
171
- " boost::numeric::ublas::detail::compare : "
167
+ " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
172
168
" cannot compare tensors with different shapes."
173
169
);
174
170
}
@@ -195,27 +191,17 @@ constexpr bool compare(tensor_expression<T,D> const& expr, UnaryPred&& pred) noe
195
191
auto const & ue = cast_tensor_expression (expr);
196
192
auto const & e = retrieve_extents (ue);
197
193
198
- using size_type = typename T::size_type;
194
+ using size_type = typename T::size_type;
195
+ using value_type = decltype (ue (0 ));
196
+ using extents_t = std::decay_t < decltype (e) >;
199
197
200
198
static_assert (
201
- std::is_invocable_r_v<bool , UnaryPred, decltype ( ue ( 0 )) >,
202
- " boost::numeric::ublas::detail::compare(expr,pred) :"
203
- " predicate must be an unary predicate, and it must return a bool"
199
+ std::is_invocable_r_v<bool , UnaryPred, value_type >,
200
+ " boost::numeric::ublas::detail::compare(tensor_expresion const& expr, UnaryPred&& pred) : "
201
+ " the predicate must be an unary predicate, and it must return a bool"
204
202
);
205
203
206
- // returns the size of the container
207
- constexpr auto cal_size = [](auto const & e)
208
- -> size_type
209
- {
210
- using extents_t = std::decay_t < decltype (e) >;
211
-
212
- if constexpr (is_static_v< extents_t >)
213
- return product_v< extents_t >;
214
- else
215
- return product ( e );
216
- };
217
-
218
- size_type const size = cal_size (e);
204
+ size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product ( e );;
219
205
220
206
for (auto i = size_type{}; i < size; ++i){
221
207
if (!std::invoke (pred, ue (i)))
0 commit comments