#include <stdexcept>
#include <type_traits>

+ #include "../multiplication.hpp"
#include "../extents.hpp"
#include "../type_traits.hpp"
#include "../tags.hpp"
@@ -49,6 +50,190 @@ using enable_ttv_if_extent_has_dynamic_rank = std::enable_if_t<is_dynamic_rank_v
} // namespace detail


+ namespace detail {
+ template<class TC, class TA, class V, class EC>
+ inline auto scalar_scalar_prod(TA const &a, V const &b, EC const& nc_base)
+ {
+   assert(ublas::is_scalar(a.extents()));
+   using tensor = TC;
+   using value  = typename tensor::value_type;
+   using shape  = typename tensor::extents_type;
+   return tensor(shape(nc_base), value(a[0]*b(0)));
+ }
+
+ template<class TC, class TA, class V, class EC>
+ inline auto vector_vector_prod(TA const &a, V const &b, EC& nc_base, std::size_t m)
+ {
+   auto const& na = a.extents();
+
+   assert( ublas::is_vector(na));
+   assert(!ublas::is_scalar(na));
+   assert( ublas::size(na) > 1u);
+   assert(m > 0);
+
+   using tensor = TC;
+   using value  = typename tensor::value_type;
+   using shape  = typename tensor::extents_type;
+
+   auto const n1 = na[0];
+   auto const n2 = na[1];
+   auto const s  = b.size();
+
+   // general
+   // [n1 n2 1 ... 1] xj [s 1] for any 1 <= j <= p with n1==1 or n2==1
+
+   // [n1 1 1 ... 1] x1 [n1 1] -> [1 1 1 ... 1]
+   // [1 n2 1 ... 1] x2 [n2 1] -> [1 1 1 ... 1]
+
+   assert(n1 > 1 || n2 > 1);
+
+   if( (n1 > 1u && m == 1u) || (n2 > 1u && m == 2u) ){
+     if(m == 1u) assert(n2 == 1u && n1 == s);
+     if(m == 2u) assert(n1 == 1u && n2 == s);
+     auto cc = std::inner_product(a.begin(), a.end(), b.begin(), value(0));
+     return tensor(shape(nc_base), value(cc));
+   }
+
+   // [n1 1 1 ... 1] xj [1 1] -> [n1 1 1 ... 1] with j != 1
+   // [1 n2 1 ... 1] xj [1 1] -> [1 n2 1 ... 1] with j != 2
+
+   // if( (n1>1u && m!=1u) && (n2>0u && m!=2u) ){
+
+   if(n1 > 1u) assert(m != 1u);
+   if(n2 > 1u) assert(m != 2u);
+   assert(s == 1u);
+
+   if(n1 > 1u) assert(n2 == 1u);
+   if(n2 > 1u) assert(n1 == 1u);
+
+   if(n1 > 1u) nc_base[0] = n1;
+   if(n2 > 1u) nc_base[1] = n2;
+
+   auto bb = b(0);
+   auto c  = tensor(shape(nc_base));
+   std::transform(a.begin(), a.end(), c.begin(), [bb](auto aa){ return aa*bb; });
+   return c;
+   // }
+ }
+
+
+ /** Computes a matrix-vector product.
+  *
+  * @note assumes stride 1 for specific dimensions and therefore requires refactoring for subtensor
+  */
+ template<class TC, class TA, class V, class EC>
+ inline auto matrix_vector_prod(TA const &a, V const &b, EC& nc_base, std::size_t m)
+ {
+   auto const& na = a.extents();
+
+   assert( ublas::is_matrix(na));
+   assert(!ublas::is_vector(na));
+   assert(!ublas::is_scalar(na));
+   assert( ublas::size(na) > 1u);
+   assert(m > 0);
+
+   using tensor = TC;
+   using shape  = typename tensor::extents_type;
+   using size_t = typename shape::value_type;
+
+   auto const n1 = na[0];
+   auto const n2 = na[1];
+   auto const s  = b.size();
+
+   // general
+   // [n1 n2 1 ... 1] xj [s 1] for any 1 <= j <= p with n1>1 and n2>1
+
+
+   // if [n1 n2 1 ... 1] xj [1 1] -> [n1 n2 1 ... 1] for j > 2
+   if(m > 2){
+     nc_base[0] = n1;
+     nc_base[1] = n2;
+     assert(s == 1);
+     auto c = tensor(shape(nc_base));
+     auto const bb = b(0);
+     std::transform(a.begin(), a.end(), c.begin(), [bb](auto aa){ return aa*bb; });
+     return c;
+   }
+
+
+   // [n1 n2 1 ... 1] x1 [n1 1] -> [n2 1 ... 1] -> vector-times-matrix
+   // [n1 n2 1 ... 1] x2 [n2 1] -> [n1 1 ... 1] -> matrix-times-vector
+
+   nc_base[0] = m == 1 ? n2 : n1;
+
+   auto c = tensor(shape(nc_base));
+   auto const& wa = a.strides();
+   auto const* bdata = &(b(0));
+
+   detail::recursive::mtv(m-1, n1, n2, c.data(), size_t(1), a.data(), wa[0], wa[1], bdata, size_t(1));
+
+   return c;
+ }
+
+
+
+ template<class TC, class TA, class V, class EC>
+ inline auto tensor_vector_prod(TA const &a, V const &b, EC& nc_base, std::size_t m)
+ {
+   auto const& na = a.extents();
+
+   assert( ublas::is_tensor(na));
+   assert( ublas::size(na) > 1u);
+   assert(m > 0);
+
+   using tensor = TC;
+   using shape  = typename tensor::extents_type;
+   using layout = typename tensor::layout_type;
+
+   auto const pa = a.rank();
+   auto const nm = na[m-1];
+   auto const s  = b.size();
+
+   auto nb = extents<2>{std::size_t(b.size()), std::size_t(1ul)};
+   auto wb = ublas::to_strides(nb, layout{});
+
+   // TODO: Include an outer product when legacy vector becomes a new vector.
+
+   for(auto i = 0ul, j = 0ul; i < pa; ++i)
+     if(i != m - 1)
+       nc_base[j++] = na.at(i);
+
+   auto c = tensor(shape(nc_base));
+
+   // [n1 n2 ... nm ... np] xm [1 1] -> [n1 n2 ... nm-1 nm+1 ... np]
+
+   if(s == 0){
+     assert(nm == 1);
+     auto const bb = b(0);
+     std::transform(a.begin(), a.end(), c.begin(), [bb](auto aa){ return aa*bb; });
+     return c;
+   }
+
+
+   // if [n1 n2 n3 ... np] xm [nm 1] -> [n1 n2 ... nm-1 nm+1 ... np]
+
+   auto const& nc = c.extents();
+   auto const& wc = c.strides();
+   auto const& wa = a.strides();
+   auto const* bp = &(b(0));
+
+   ttv(m, pa,
+       c.data(), nc.data(), wc.data(),
+       a.data(), na.data(), wa.data(),
+       bp, nb.data(), wb.data());
+
+   return c;
+ }
+
+ } // namespace detail
+
+
/** @brief Computes the m-mode tensor-times-vector product
 *
 * Implements C[i1,...,im-1,im+1,...,ip] = A[i1,i2,...,ip] * b[im]
@@ -63,45 +248,49 @@ using enable_ttv_if_extent_has_dynamic_rank = std::enable_if_t<is_dynamic_rank_v
 */
template<class TE, class A, class T = typename tensor_core< TE >::value,
         detail::enable_ttv_if_extent_has_dynamic_rank<TE> = true >
- inline decltype(auto) prod( tensor_core< TE > const &a, vector<T, A> const &b, const std::size_t m)
+ inline auto prod( tensor_core< TE > const &a, vector<T, A> const &b, const std::size_t m)
{

using tensor     = tensor_core< TE >;
using shape      = typename tensor::extents_type;
- using value      = typename tensor::value_type;
- using layout     = typename tensor::layout_type;
using resize_tag = typename tensor::resizable_tag;

- auto const p = a.rank();
+ auto const pa = a.rank();

static_assert(std::is_same_v<resize_tag,storage_resizable_container_tag>);
static_assert(is_dynamic_v<shape>);

if(m == 0ul)  throw std::length_error("error in boost::numeric::ublas::prod(ttv): contraction mode must be greater than zero.");
- if(p < m)     throw std::length_error("error in boost::numeric::ublas::prod(ttv): rank of tensor must be greater than or equal to the contraction mode.");
+ if(pa < m)    throw std::length_error("error in boost::numeric::ublas::prod(ttv): rank of tensor must be greater than or equal to the contraction mode.");
if(a.empty()) throw std::length_error("error in boost::numeric::ublas::prod(ttv): first argument tensor should not be empty.");
if(b.empty()) throw std::length_error("error in boost::numeric::ublas::prod(ttv): second argument vector should not be empty.");

auto const& na = a.extents();
- auto nb = extents<2>{ std::size_t(b.size()), std::size_t(1ul) };
- auto wb = ublas::to_strides(nb, layout{});
+
+ if(b.size() != na[m-1]) throw std::length_error("error in boost::numeric::ublas::prod(ttv): dimension mismatch of tensor and vector.");

auto const sz = std::max( std::size_t(ublas::size(na)-1u), std::size_t(2) );
auto nc_base  = typename shape::base_type(sz,1);

- for(auto i = 0ul, j = 0ul; i < p; ++i)
-   if(i != m - 1)
-     nc_base[j++] = na.at(i);
+ // output scalar tensor
+ if(ublas::is_scalar(na)){
+   return detail::scalar_scalar_prod<tensor>(a,b,nc_base);
+ }
+
+ // output scalar tensor or vector tensor
+ if(ublas::is_vector(na)){
+   return detail::vector_vector_prod<tensor>(a,b,nc_base,m);
+ }
+
+ // output scalar tensor or vector tensor
+ if(ublas::is_matrix(na)){
+   return detail::matrix_vector_prod<tensor>(a,b,nc_base,m);
+ }
+
+ assert(ublas::is_tensor(na));
+ return detail::tensor_vector_prod<tensor>(a,b,nc_base,m);

- auto nc = shape(nc_base);
- auto c  = tensor( nc, value{} );

- auto const* bb = &(b(0));
- ttv(m, p,
-     c.data(), c.extents().data(), c.strides().data(),
-     a.data(), a.extents().data(), a.strides().data(),
-     bb, nb.data(), wb.data());
- return c;
}

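For reference, a minimal standalone sketch (not part of this patch) of the contraction semantics the overload above implements, C[i1,...,im-1,im+1,...,ip] = A[i1,...,ip] * b[im], written with plain loops on a column-major 2x3x4 example contracted along mode m = 2; the extents and data are made up for illustration:

    #include <array>
    #include <cstddef>
    #include <iostream>

    int main()
    {
        constexpr std::size_t n1 = 2, n2 = 3, n3 = 4;
        std::array<double, n1*n2*n3> A{};
        std::array<double, n2> b{1.0, 2.0, 3.0};
        for (std::size_t k = 0; k < A.size(); ++k) A[k] = double(k);   // arbitrary test data

        // column-major layout: A[i1,i2,i3] sits at A[i1 + n1*i2 + n1*n2*i3]
        std::array<double, n1*n3> C{};
        for (std::size_t i3 = 0; i3 < n3; ++i3)
            for (std::size_t i2 = 0; i2 < n2; ++i2)
                for (std::size_t i1 = 0; i1 < n1; ++i1)
                    C[i1 + n1*i3] += A[i1 + n1*i2 + n1*n2*i3] * b[i2];

        std::cout << C[0] << '\n';  // C[0,0] = 0*1 + 2*2 + 4*3 = 16
        return 0;
    }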
@@ -143,7 +332,6 @@ inline auto prod( tensor_core< TE > const &a, vector<T, A> const &b, const std::
constexpr auto p  = std::tuple_size_v<shape>;
constexpr auto sz = std::max(std::size_t(std::tuple_size_v<shape>-1U), std::size_t(2));

- using shape_b  = ublas::extents<2>;
using shape_c  = ublas::extents<sz>;
using tensor_c = tensor_core<tensor_engine<shape_c,layout,container>>;

@@ -158,21 +346,25 @@ inline auto prod( tensor_core< TE > const &a, vector<T, A> const &b, const std::

auto nc_base = typename shape_c::base_type{};
std::fill(nc_base.begin(), nc_base.end(), std::size_t(1));
- for(auto i = 0ul, j = 0ul; i < p; ++i)
-   if(i != m - 1)
-     nc_base[j++] = na.at(i);

- auto nc = shape_c(std::move(nc_base));
- auto nb = shape_b{b.size(), 1UL};
- auto wb = ublas::to_strides(nb, layout{});
- auto c  = tensor_c( std::move(nc) );
- auto const* bb = &(b(0));

- ttv(m, p,
-     c.data(), c.extents().data(), c.strides().data(),
-     a.data(), a.extents().data(), a.strides().data(),
-     bb, nb.data(), wb.data());
- return c;
+ // output scalar tensor
+ if(ublas::is_scalar(na)){
+   return detail::scalar_scalar_prod<tensor_c>(a,b,nc_base);
+ }
+
+ // output scalar tensor or vector tensor
+ if(ublas::is_vector(na)){
+   return detail::vector_vector_prod<tensor_c>(a,b,nc_base,m);
+ }
+
+ // output scalar tensor or vector tensor
+ if(ublas::is_matrix(na)){
+   return detail::matrix_vector_prod<tensor_c>(a,b,nc_base,m);
+ }
+
+ assert(ublas::is_tensor(na));
+ return detail::tensor_vector_prod<tensor_c>(a,b,nc_base,m);
}

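A hypothetical usage sketch of calling the prod(ttv) overloads changed in this commit; the tensor_dynamic alias and its extent-list constructor are assumed from the current development branch and may differ in released Boost versions:

    #include <boost/numeric/ublas/tensor.hpp>
    #include <boost/numeric/ublas/vector.hpp>
    #include <numeric>

    int main()
    {
        namespace ublas = boost::numeric::ublas;

        auto A = ublas::tensor_dynamic<float>{3, 4, 2};   // rank-3 tensor with extents [3,4,2]
        auto b = ublas::vector<float>(4);
        std::iota(b.begin(), b.end(), 1.0f);              // b = (1,2,3,4)

        auto C = ublas::prod(A, b, 2);                    // contract mode 2: C has extents [3,2]
        return C.empty() ? 1 : 0;
    }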
@@ -201,7 +393,7 @@ inline auto prod( tensor_core< TE > const &a, vector<T, A> const &b)
using shape  = typename tensor::extents;
using layout = typename tensor::layout;
using shape_b = extents<2>;
- using shape_c = remove_element_t<m,shape>;
+ using shape_c = remove_element_t<m,shape>; // this is wrong
using container_c = rebind_storage_size_t<shape_c,container>;
using tensor_c = tensor_core<tensor_engine<shape_c,layout,container_c>>;