@@ -377,41 +377,45 @@ namespace nda {
377377 }
378378
379379 // ------------------------------- concatenate --------------------------------------------
380- // slice in all dimensions but Axis
381- template <auto Axis, Array A>
382- auto all_view_except (A const &a, auto const &arg) {
383- auto slice_or_arg = [&arg](auto x) {
384- if constexpr (Axis == decltype (x)::value)
385- return arg;
386- else
387- return range::all;
388- };
389-
390- return [&]<auto ... Is>(std::index_sequence<Is...>) { return a (slice_or_arg (std::integral_constant<size_t , Is>{})...); }
391- (std::make_index_sequence<A::rank>{});
392- };
393380
394- // numpy style concatenation
395- template <auto Axis, Array A0, Array... A>
381+ /* *
382+ * Join a sequence of arrays along an existing axis.
383+ *
384+ * The arrays must have the same value_type and also shape,
385+ * except in the dimension corresponding to axis (the first, by default).
386+ *
387+ * @tparam Axis The axis along which to concatenate (default: 0)
388+ * @tparam A0 Type of the first array
389+ * @tparam A Types of the subsequent arrays
390+ * @param a0 The first array
391+ * @param a The subsequent arrays
392+ * @return New array with the concatenated data
393+ */
394+ template <size_t Axis = 0 , Array A0, Array... A>
396395 auto concatenate (A0 const &a0, A const &...a) {
397396 // sanity checks
398- static_assert (A0::rank >= Axis);
399- static_assert (((A0::rank == A::rank) and ... and true ));
397+ auto constexpr rank = A0::rank;
398+ static_assert (Axis < rank);
399+ static_assert (((rank == A::rank) and ... and true ));
400400 static_assert (((std::is_same_v<get_value_t <A0>, get_value_t <A>>) and ... and true ));
401+ for (auto ax [[maybe_unused]] : range (rank)) { EXPECTS (ax == Axis or ((a0.extent (ax) == a.extent (ax)) and ... and true )); }
401402
402- for (auto const ax : range (A0::rank)) {
403- if (not (ax == Axis)) { assert (((a0.shape ()[ax] == a.shape ()[ax]) and ... and true )); }
404- }
405-
406- // build concatenated array
403+ // construct concatenated array
407404 auto new_shape = a0.shape ();
408- long offset = 0 ;
409- new_shape[Axis] = new_shape[Axis] + ((a.shape ()[Axis] + ... + 0 ));
410- array<get_value_t <A0>, A0::rank> new_array (new_shape);
405+ new_shape[Axis] = (a.extent (Axis) + ... + new_shape[Axis]);
406+ auto new_array = array<get_value_t <A0>, rank>(new_shape);
407+
408+ // slicing helper function
409+ auto slice_Axis = [](Array auto &a, range r) {
410+ auto all_or_range = std::make_tuple (range::all, r);
411+ return [&]<auto ... Is>(std::index_sequence<Is...>) { return a (std::get<Is == Axis>(all_or_range)...); }(std::make_index_sequence<rank>{});
412+ };
411413
414+ // initialize concatenated array
415+ long offset = 0 ;
412416 for (auto const &a_view : {basic_array_view (a0), basic_array_view (a)...}) {
413- all_view_except<Axis> (new_array, range (offset, offset + a_view.shape ()[ Axis] )) = a_view;
414- offset += a_view.shape ()[ Axis] ;
417+ slice_Axis (new_array, range (offset, offset + a_view.extent ( Axis) )) = a_view;
418+ offset += a_view.extent ( Axis) ;
415419 }
416420
417421 return new_array;
0 commit comments