Skip to content

Commit 36cb7b5

Browse files
committed
Review and simplify nda::concatenate implementation
1 parent 20b5056 commit 36cb7b5

File tree

2 files changed

+32
-36
lines changed

2 files changed

+32
-36
lines changed

c++/nda/basic_functions.hpp

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -377,41 +377,45 @@ namespace nda {
377377
}
378378

379379
// ------------------------------- concatenate --------------------------------------------
380-
// slice in all dimensions but Axis
381-
template <auto Axis, Array A>
382-
auto all_view_except(A const &a, auto const &arg) {
383-
auto slice_or_arg = [&arg](auto x) {
384-
if constexpr (Axis == decltype(x)::value)
385-
return arg;
386-
else
387-
return range::all;
388-
};
389-
390-
return [&]<auto... Is>(std::index_sequence<Is...>) { return a(slice_or_arg(std::integral_constant<size_t, Is>{})...); }
391-
(std::make_index_sequence<A::rank>{});
392-
};
393380

394-
// numpy style concatenation
395-
template <auto Axis, Array A0, Array... A>
381+
/**
382+
* Join a sequence of arrays along an existing axis.
383+
*
384+
* The arrays must have the same value_type and also shape,
385+
* except in the dimension corresponding to axis (the first, by default).
386+
*
387+
* @tparam Axis The axis along which to concatenate (default: 0)
388+
* @tparam A0 Type of the first array
389+
* @tparam A Types of the subsequent arrays
390+
* @param a0 The first array
391+
* @param a The subsequent arrays
392+
* @return New array with the concatenated data
393+
*/
394+
template <size_t Axis = 0, Array A0, Array... A>
396395
auto concatenate(A0 const &a0, A const &...a) {
397396
// sanity checks
398-
static_assert(A0::rank >= Axis);
399-
static_assert(((A0::rank == A::rank) and ... and true));
397+
auto constexpr rank = A0::rank;
398+
static_assert(Axis < rank);
399+
static_assert(((rank == A::rank) and ... and true));
400400
static_assert(((std::is_same_v<get_value_t<A0>, get_value_t<A>>) and ... and true));
401+
for (auto ax [[maybe_unused]] : range(rank)) { EXPECTS(ax == Axis or ((a0.extent(ax) == a.extent(ax)) and ... and true)); }
401402

402-
for (auto const ax : range(A0::rank)) {
403-
if (not (ax == Axis)) { assert(((a0.shape()[ax] == a.shape()[ax]) and ... and true)); }
404-
}
405-
406-
// build concatenated array
403+
// construct concatenated array
407404
auto new_shape = a0.shape();
408-
long offset = 0;
409-
new_shape[Axis] = new_shape[Axis] + ((a.shape()[Axis] + ... + 0));
410-
array<get_value_t<A0>, A0::rank> new_array(new_shape);
405+
new_shape[Axis] = (a.extent(Axis) + ... + new_shape[Axis]);
406+
auto new_array = array<get_value_t<A0>, rank>(new_shape);
407+
408+
// slicing helper function
409+
auto slice_Axis = [](Array auto &a, range r) {
410+
auto all_or_range = std::make_tuple(range::all, r);
411+
return [&]<auto... Is>(std::index_sequence<Is...>) { return a(std::get<Is == Axis>(all_or_range)...); }(std::make_index_sequence<rank>{});
412+
};
411413

414+
// initialize concatenated array
415+
long offset = 0;
412416
for (auto const &a_view : {basic_array_view(a0), basic_array_view(a)...}) {
413-
all_view_except<Axis>(new_array, range(offset, offset + a_view.shape()[Axis])) = a_view;
414-
offset += a_view.shape()[Axis];
417+
slice_Axis(new_array, range(offset, offset + a_view.extent(Axis))) = a_view;
418+
offset += a_view.extent(Axis);
415419
}
416420

417421
return new_array;

test/c++/nda_basic.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,14 +418,6 @@ TEST(Array, Concatenate) { //NOLINT
418418
for (int k = 0; k < 6; ++k) { c(i, j, k) = i + 10 * j + 102 * k; }
419419
}
420420

421-
// test all_view_except
422-
auto const a_view_except = all_view_except<1>(a, range(1, 3));
423-
EXPECT_EQ(a_view_except.shape()[1], 2);
424-
425-
for (int i = 0; i < 2; ++i)
426-
for (int j = 0; j < 2; ++j)
427-
for (int k = 0; k < 4; ++k) { EXPECT_EQ(a_view_except(i, j, k), a(i, j + 1, k)); }
428-
429421
// test concatenate
430422
auto const abc_axis2_concat = concatenate<2>(a, b, c);
431423
EXPECT_EQ(abc_axis2_concat.shape()[2], 15);
@@ -441,4 +433,4 @@ TEST(Array, Concatenate) { //NOLINT
441433
EXPECT_EQ(abc_axis2_concat(i, j, k), c(i, j, k - 9));
442434
}
443435
}
444-
}
436+
}

0 commit comments

Comments
 (0)