@@ -60,70 +60,62 @@ struct is_equality_functional_object< std::not_equal_to<> >
60
60
: std::true_type
61
61
{};
62
62
63
- template <class T1 , class T2 , class L , class R , class BinaryPred >
64
- [[nodiscard]] inline
65
- constexpr bool compare (tensor_expression<T1,L> const & lhs, tensor_expression<T2,R> const & rhs, BinaryPred&& pred ) noexcept
66
- requires is_equality_functional_object_v<BinaryPred >
63
+ template <integral SizeType, typename LE, typename RE >
64
+ [[nodiscard]]
65
+ constexpr auto compare_helper (LE const & le, RE const & re, std::true_type /* unused */ ) noexcept
66
+ -> std::pair<bool, SizeType >
67
67
{
68
-
69
- auto const & lexpr = cast_tensor_expression (lhs);
70
- auto const & rexpr = cast_tensor_expression (rhs);
71
-
72
- using lvalue_type = decltype (lexpr (0 ));
73
- using rvalue_type = decltype (rexpr (0 ));
74
-
75
- static_assert ( same_exp< lvalue_type, rvalue_type >,
76
- " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
77
- " both LHS and RHS should have the same value type"
78
- );
79
-
80
- static_assert (
81
- std::is_invocable_r_v<bool , BinaryPred, lvalue_type, rvalue_type>,
82
- " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
83
- " the predicate must be a binary predicate, and it must return a bool"
84
- );
85
-
86
- auto const & le = retrieve_extents (lexpr);
87
- auto const & re = retrieve_extents (rexpr);
88
-
89
- using size_type = typename T1::size_type;
68
+ using ::operator ==;
69
+
70
+ constexpr auto zero = SizeType{};
71
+
72
+ if constexpr ( is_static_v< LE > && is_static_v< RE > ){
73
+ constexpr bool is_same = std::is_same_v<LE, RE>;
74
+ constexpr SizeType size = ( is_same ? SizeType{ product_v< LE > } : zero );
75
+ return { is_same, size };
76
+ }else {
77
+ bool const is_same = ( le == re );
78
+ SizeType const size = ( is_same ? SizeType{ product (le) } : zero );
79
+ return { is_same, size };
80
+ }
81
+ }
90
82
91
- // returns the pair containing false if extents are not equal
92
- // else true, and the size of the container.
93
- constexpr auto cal_size = [](auto const & le, auto const & re)
94
- -> std::pair<bool , size_type>
95
- {
96
- using lex_t = std::decay_t < decltype (le) >;
97
- using rex_t = std::decay_t < decltype (re) >;
98
-
99
- if constexpr (is_static_v< lex_t > && is_static_v< rex_t >){
100
- constexpr bool is_same = same_exp< lex_t , rex_t >;
101
- return { is_same, is_same ? product_v< lex_t > : size_type{} };
102
- } else {
103
- bool const is_same = ::operator ==(le,re);
104
- return { is_same, is_same ? product ( le ) : size_type{} };
83
+ template <typename SizeType, typename LE, typename RE>
84
+ [[nodiscard]]
85
+ constexpr auto compare_helper (LE const & le, RE const & re, std::false_type /* unused*/ )
86
+ noexcept ( is_static_v< LE> && is_static_v< RE > ) -> std::pair<bool, SizeType>
87
+ {
88
+ using ::operator !=;
89
+
90
+ if constexpr ( is_static_v< LE > && is_static_v< RE > ){
91
+ static_assert (std::is_same_v< LE, RE >,
92
+ " boost::numeric::ublas::detail::compare_helper(Lextents const& lhs, Rextents const& rhs) : "
93
+ " cannot compare tensors with different shapes."
94
+ );
95
+
96
+ constexpr SizeType size = product_v< LE >;
97
+ return { true , size };
98
+ }else {
99
+ if (le != re){
100
+ throw std::runtime_error (
101
+ " boost::numeric::ublas::detail::compare_helper(Lextents const& lhs, Rextents const& rhs) : "
102
+ " cannot compare tensors with different shapes."
103
+ );
105
104
}
106
- };
107
105
108
- auto const [status, size] = cal_size (le, re);
109
-
110
- for (auto i = size_type{}; i < size; ++i){
111
- if (!std::invoke (pred, lexpr (i), rexpr (i)))
112
- return false ;
106
+ SizeType const size = product ( le );
107
+ return { true , size };
113
108
}
114
-
115
- // return false if the status is false
116
- return ( true & status );
117
109
}
118
110
119
111
template <class T1 , class T2 , class L , class R , class BinaryPred >
120
112
[[nodiscard]] inline
121
113
constexpr bool compare (tensor_expression<T1,L> const & lhs, tensor_expression<T2,R> const & rhs, BinaryPred&& pred)
122
114
noexcept (
123
- is_static_v< std::decay_t < decltype(retrieve_extents(lhs)) > > &&
124
- is_static_v< std::decay_t< decltype(retrieve_extents(rhs)) > >
115
+ ( is_static_v< std::decay_t < decltype(retrieve_extents(lhs)) > > &&
116
+ is_static_v< std::decay_t< decltype(retrieve_extents(rhs)) > >
117
+ ) || is_equality_functional_object_v<BinaryPred>
125
118
)
126
- requires ( not is_equality_functional_object_v<BinaryPred> )
127
119
{
128
120
auto const & lexpr = cast_tensor_expression (lhs);
129
121
auto const & rexpr = cast_tensor_expression (rhs);
@@ -146,41 +138,16 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
146
138
auto const & re = retrieve_extents (rexpr);
147
139
148
140
using size_type = typename T1::size_type;
141
+ using is_eq_t = std::conditional_t < is_equality_functional_object_v<BinaryPred>, std::true_type, std::false_type >;
149
142
150
- // returns the size of the container
151
- constexpr auto cal_size = [](auto const & le, auto const & re)
152
- -> size_type
153
- {
154
- using lex_t = std::decay_t < decltype (le) >;
155
- using rex_t = std::decay_t < decltype (re) >;
156
-
157
- if constexpr (is_static_v< lex_t > && is_static_v< rex_t >){
158
- static_assert (same_exp< lex_t , rex_t >,
159
- " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
160
- " cannot compare tensors with different shapes."
161
- );
162
-
163
- return product_v< lex_t >;
164
- }else {
165
- if (::operator !=(le,re)){
166
- throw std::runtime_error (
167
- " boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
168
- " cannot compare tensors with different shapes."
169
- );
170
- }
171
-
172
- return product ( le );
173
- }
174
- };
175
-
176
- size_type const size = cal_size (le, re);
177
-
143
+ auto const [status, size] = compare_helper<size_type>(le, re, is_eq_t {});
144
+
178
145
for (auto i = size_type{}; i < size; ++i){
179
146
if (!std::invoke (pred, lexpr (i), rexpr (i)))
180
147
return false ;
181
148
}
182
149
183
- return true ;
150
+ return status ;
184
151
}
185
152
186
153
@@ -201,7 +168,7 @@ constexpr bool compare(tensor_expression<T,D> const& expr, UnaryPred&& pred) noe
201
168
" the predicate must be an unary predicate, and it must return a bool"
202
169
);
203
170
204
- size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product ( e );;
171
+ size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product ( e );
205
172
206
173
for (auto i = size_type{}; i < size; ++i){
207
174
if (!std::invoke (pred, ue (i)))
0 commit comments