@@ -377,41 +377,45 @@ namespace nda {
377
377
}
378
378
379
379
// ------------------------------- concatenate --------------------------------------------
380
- // slice in all dimensions but Axis
381
- template <auto Axis, Array A>
382
- auto all_view_except (A const &a, auto const &arg) {
383
- auto slice_or_arg = [&arg](auto x) {
384
- if constexpr (Axis == decltype (x)::value)
385
- return arg;
386
- else
387
- return range::all;
388
- };
389
-
390
- return [&]<auto ... Is>(std::index_sequence<Is...>) { return a (slice_or_arg (std::integral_constant<size_t , Is>{})...); }
391
- (std::make_index_sequence<A::rank>{});
392
- };
393
380
394
- // numpy style concatenation
395
- template <auto Axis, Array A0, Array... A>
381
+ /* *
382
+ * Join a sequence of arrays along an existing axis.
383
+ *
384
+ * The arrays must have the same value_type and also shape,
385
+ * except in the dimension corresponding to axis (the first, by default).
386
+ *
387
+ * @tparam Axis The axis along which to concatenate (default: 0)
388
+ * @tparam A0 Type of the first array
389
+ * @tparam A Types of the subsequent arrays
390
+ * @param a0 The first array
391
+ * @param a The subsequent arrays
392
+ * @return New array with the concatenated data
393
+ */
394
+ template <size_t Axis = 0 , Array A0, Array... A>
396
395
auto concatenate (A0 const &a0, A const &...a) {
397
396
// sanity checks
398
- static_assert (A0::rank >= Axis);
399
- static_assert (((A0::rank == A::rank) and ... and true ));
397
+ auto constexpr rank = A0::rank;
398
+ static_assert (Axis < rank);
399
+ static_assert (((rank == A::rank) and ... and true ));
400
400
static_assert (((std::is_same_v<get_value_t <A0>, get_value_t <A>>) and ... and true ));
401
+ for (auto ax [[maybe_unused]] : range (rank)) { EXPECTS (ax == Axis or ((a0.extent (ax) == a.extent (ax)) and ... and true )); }
401
402
402
- for (auto const ax : range (A0::rank)) {
403
- if (not (ax == Axis)) { assert (((a0.shape ()[ax] == a.shape ()[ax]) and ... and true )); }
404
- }
405
-
406
- // build concatenated array
403
+ // construct concatenated array
407
404
auto new_shape = a0.shape ();
408
- long offset = 0 ;
409
- new_shape[Axis] = new_shape[Axis] + ((a.shape ()[Axis] + ... + 0 ));
410
- array<get_value_t <A0>, A0::rank> new_array (new_shape);
405
+ new_shape[Axis] = (a.extent (Axis) + ... + new_shape[Axis]);
406
+ auto new_array = array<get_value_t <A0>, rank>(new_shape);
407
+
408
+ // slicing helper function
409
+ auto slice_Axis = [](Array auto &a, range r) {
410
+ auto all_or_range = std::make_tuple (range::all, r);
411
+ return [&]<auto ... Is>(std::index_sequence<Is...>) { return a (std::get<Is == Axis>(all_or_range)...); }(std::make_index_sequence<rank>{});
412
+ };
411
413
414
+ // initialize concatenated array
415
+ long offset = 0 ;
412
416
for (auto const &a_view : {basic_array_view (a0), basic_array_view (a)...}) {
413
- all_view_except<Axis> (new_array, range (offset, offset + a_view.shape ()[ Axis] )) = a_view;
414
- offset += a_view.shape ()[ Axis] ;
417
+ slice_Axis (new_array, range (offset, offset + a_view.extent ( Axis) )) = a_view;
418
+ offset += a_view.extent ( Axis) ;
415
419
}
416
420
417
421
return new_array;
0 commit comments