Skip to content

Commit d925304

Browse files
committed
Tests for reframe() used as transpose()
* ra/expr.hh (reframe): Take dest as function argument. * test/wrank.cc: As stated.
1 parent 5f2750d commit d925304

File tree

6 files changed

+33
-24
lines changed

6 files changed

+33
-24
lines changed

ra/big.hh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ template <class E> requires (ANY==size_s<E>())
492492
struct concrete_type_<E> { using type = Big<ncvalue_t<E>, rank_s<E>()>; };
493493

494494
template <class E> requires (0!=rank_s<E>() && ANY!=size_s<E>())
495-
struct concrete_type_<E> { using type = SmallArray<ncvalue_t<E>, ic_t<default_dims<shape_s<E>>>>; };
495+
struct concrete_type_<E> { using type = SmallArray<ncvalue_t<E>, ic_t<default_dims(shape_s<E>)>>; };
496496

497497
template <class E> using concrete_type = std::conditional_t<(0==rank_s<E>() && !is_ra<E>), std::decay_t<E>,
498498
typename concrete_type_<E>::type>;

ra/expr.hh

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ maybe_step()
223223

224224
template <class I, class N=dim_c<UNB>, class S=dim_c<1>>
225225
constexpr auto
226-
ptr(I && i, N && n = N {}, S && s = maybe_step<S>())
226+
ptr(I && i, N && n=N {}, S && s=maybe_step<S>())
227227
{
228228
if constexpr (std::ranges::bidirectional_range<std::remove_reference_t<I>>) {
229229
static_assert(std::is_same_v<dim_c<UNB>, N>, "Object has own length.");
@@ -283,8 +283,7 @@ struct Reframe<A, ilist_t<di ...>, ilist_t<i ...>>
283283
constexpr static int orig(int k) { int r=-1; (void)((di==k && (r=i, 1)) || ...); return r; }
284284
constexpr static dim_t len_s(int k)
285285
{
286-
int l=orig(k);
287-
return l>=0 ? std::decay_t<A>::len_s(l) : UNB;
286+
int l=orig(k); return l>=0 ? std::decay_t<A>::len_s(l) : UNB;
288287
}
289288
constexpr static dim_t
290289
len(int k) requires (requires { std::decay_t<A>::len(k); })
@@ -321,16 +320,15 @@ struct Reframe<A, ilist_t<di ...>, ilist_t<i ...>>
321320
constexpr decltype(auto) operator*() const { return *a; }
322321
constexpr auto save() const { return a.save(); }
323322
constexpr void load(auto const & p) { a.load(p); }
324-
// FIXME if Dest preserves axis order (?) which wrank does
325323
constexpr void mov(auto const & s) { a.mov(s); }
326324
};
327325

328326
// Optimize nop case. TODO If A is CellBig, etc. beat Dest on it, same for eventual transpose_expr<>.
329-
template <class Dest, class A>
327+
template <class A, class Dest>
330328
constexpr decltype(auto)
331-
reframe(A && a)
329+
reframe(A && a, Dest)
332330
{
333-
if constexpr (std::is_same_v<Dest, mp::iota<Reframe<A, Dest>::rank()>>) {
331+
if constexpr (std::is_same_v<Dest, mp::iota<mp::len<Dest>>>) {
334332
return RA_FWD(a);
335333
} else {
336334
return Reframe<A, Dest> { RA_FWD(a) };
@@ -339,9 +337,9 @@ reframe(A && a)
339337

340338
template <int w=0, class I=dim_t, class N=dim_c<UNB>, class S=dim_c<1>>
341339
constexpr auto
342-
iota(N && n = N {}, I && i = 0, S && s = maybe_step<S>())
340+
iota(N && n=N {}, I && i=0, S && s=maybe_step<S>())
343341
{
344-
return reframe<ilist_t<w>>(Ptr<Seq<sarg<I>>, sarg<N>, sarg<S>> { {RA_FWD(i)}, RA_FWD(n), RA_FWD(s) });
342+
return reframe(Ptr<Seq<sarg<I>>, sarg<N>, sarg<S>> { {RA_FWD(i)}, RA_FWD(n), RA_FWD(s) }, ilist_t<w>{});
345343
}
346344

347345
#define DEF_TENSORINDEX(w) constexpr auto JOIN(_, w) = iota<w>();
@@ -613,7 +611,7 @@ constexpr bool
613611
agree_verb(ilist_t<i ...>, V const & v, T const & ... t)
614612
{
615613
using FM = Framematch<V, std::tuple<T ...>>;
616-
return agree_op(FM::op(v), reframe<mp::ref<typename FM::R, i>>(ra::start(t)) ...);
614+
return agree_op(FM::op(v), reframe(ra::start(t), mp::ref<typename FM::R, i>{}) ...);
617615
}
618616

619617

@@ -656,7 +654,7 @@ constexpr auto
656654
map_verb(ilist_t<i ...>, Op && op, P && ... p)
657655
{
658656
using FM = Framematch<Op, std::tuple<P ...>>;
659-
return map_(FM::op(RA_FWD(op)), reframe<mp::ref<typename FM::R, i>>(RA_FWD(p)) ...);
657+
return map_(FM::op(RA_FWD(op)), reframe(RA_FWD(p), mp::ref<typename FM::R, i>{}) ...);
660658
}
661659

662660
constexpr auto

ra/ply.hh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ struct STLIterator
301301
over = true;
302302
}
303303
constexpr STLIterator & operator++() { if constexpr (ANY==rank_s(a)) { next(rank(a)-1); } else { next<rank_s(a)-1>(); } return *this; }
304-
constexpr void operator++(int) { ++(*this); } // see p0541 and p2550. Or just avoid.
304+
constexpr void operator++(int) { ++(*this); }
305305
};
306306

307307
template <class A> STLIterator(A &&) -> STLIterator<A>;

ra/small.hh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ filldim(auto && shape, auto & dimv)
6868
return s;
6969
}
7070

71-
template <auto lenv> constexpr auto
72-
default_dims = []{ std::array<Dim, ra::size(lenv)> dimv; filldim(lenv, dimv); return dimv; }();
71+
consteval auto
72+
default_dims(auto const & lv) { std::array<Dim, ra::size(lv)> dv; filldim(lv, dv); return dv; };
7373

7474
constexpr auto
7575
shape(auto const & v, auto && e)
@@ -349,7 +349,7 @@ struct nested_arg<T, Dimv>
349349
{
350350
constexpr static auto sn = ssize(Dimv::value)-1;
351351
constexpr static auto s = std::apply([](auto ... i) { return std::array<dim_t, sn> { Dimv::value[i].len ... }; }, mp::iota<sn, 1> {});
352-
using sub = std::conditional_t<0==sn, T, SmallArray<T, ic_t<default_dims<s>>>>;
352+
using sub = std::conditional_t<0==sn, T, SmallArray<T, ic_t<default_dims(s)>>>;
353353
};
354354

355355
template <class P> struct reconst_t { using type = void; };
@@ -559,7 +559,7 @@ SmallArray<T, Dimv, std::tuple<nested_args ...>>
559559
};
560560

561561
template <class T, dim_t ... lens>
562-
using Small = SmallArray<T, ic_t<default_dims<std::array<dim_t, sizeof...(lens)> {lens ...}>>>;
562+
using Small = SmallArray<T, ic_t<default_dims(std::array<dim_t, sizeof...(lens)> {lens ...})>>;
563563

564564
template <class A0, class ... A> SmallArray(A0, A ...) -> Small<A0, 1+sizeof...(A)>;
565565

@@ -714,7 +714,7 @@ constexpr auto
714714
start(is_builtin_array auto && t)
715715
{
716716
using T = std::remove_all_extents_t<std::remove_reference_t<decltype(t)>>; // preserve const
717-
return ViewSmall<T *, ic_t<default_dims<ra::shape(t)>>>(peel(t)).iter();
717+
return ViewSmall<T *, ic_t<default_dims(ra::shape(t))>>(peel(t)).iter();
718718
}
719719

720720
} // namespace ra

test/small-1.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ int main()
7373
}
7474
tr.section("static step computation");
7575
{
76-
auto dims = ra::default_dims<std::array<ra::dim_t, 3> {3, 4, 5}>;
76+
auto dims = ra::default_dims(std::array<ra::dim_t, 3> {3, 4, 5});
7777
tr.info("step 0").test_eq(20, dims[0].step);
7878
tr.info("step 1").test_eq(5, dims[1].step);
7979
tr.info("step 2").test_eq(1, dims[2].step);
@@ -466,7 +466,7 @@ int main()
466466
tr.test_eq(a, 1+ra::Small<int, 3, 2> {{0, 1}, {2, 3}, {4, 5}});
467467
}
468468
tr.section("ViewSmall as iota<w>");
469-
// in order to replace Ptr<>, we must support Len both in P and in Dimv.
469+
// in order to replace Ptr<>, we must support Len in View::P and View::Dimv.
470470
{
471471
constexpr ra::ViewSmall<ra::Seq<ra::dim_t>, ra::ic_t<std::array {ra::Dim {ra::UNB, 1}}>>
472472
i0(ra::Seq<ra::dim_t> {0});

test/wrank.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ void nested_wrank_demo(V && v, A && a, B && b)
4747
using FM = ra::Framematch<V, tuple<decltype(a.iter()), decltype(b.iter())>>;
4848
cout << "width of fm: " << ra::mp::len<typename FM::R> << endl;
4949
cout << ra::mp::print_ilist_t<typename FM::R> {} << endl;
50-
auto af0 = ra::reframe<ra::mp::ref<typename FM::R, 0>>(a.iter());
51-
auto af1 = ra::reframe<ra::mp::ref<typename FM::R, 1>>(b.iter());
50+
auto af0 = reframe(a.iter(), ra::mp::ref<typename FM::R, 0>{});
51+
auto af1 = reframe(b.iter(), ra::mp::ref<typename FM::R, 1>{});
5252
cout << "af0: " << sizeof(af0) << endl;
5353
cout << "af1: " << sizeof(af1) << endl;
5454
{
@@ -107,8 +107,8 @@ int main()
107107
using FM = ra::Framematch<decltype(v), tuple<decltype(a.iter()), decltype(b.iter())>>;
108108
cout << "width of fm: " << ra::mp::len<FM::R> << endl;
109109
cout << ra::mp::print_ilist_t<FM::R> {} << endl;
110-
auto af0 = ra::reframe<ra::mp::ref<FM::R, 0>>(a.iter());
111-
auto af1 = ra::reframe<ra::mp::ref<FM::R, 1>>(b.iter());
110+
auto af0 = reframe(a.iter(), ra::mp::ref<FM::R, 0>{});
111+
auto af1 = reframe(b.iter(), ra::mp::ref<FM::R, 1>{});
112112
cout << "af0: " << sizeof(af0) << endl;
113113
cout << "af1: " << sizeof(af1) << endl;
114114
auto ewv = ra::map_(FM::op(v), af0, af1);
@@ -354,5 +354,16 @@ int main()
354354
.test_eq(c, from([o = from(std::multiplies<>(), a, b)](auto i, auto j) { return o.at(std::array {i, j}); },
355355
ra::iota(3), ra::iota(4)));
356356
}
357+
tr.section("Reframe as transpose");
358+
{
359+
ra::ViewSmall<ra::Seq<int>, ra::ic_t<ra::default_dims(std::array<int, 3> {2, 3, 4})>> a(ra::Seq<int> {0});
360+
auto test = [&](auto dest) { tr.strict().test_eq(transpose(a, dest), reframe(start(a), dest)); };
361+
test(ra::ilist_t<0, 1, 2> {});
362+
test(ra::ilist_t<1, 2, 0> {});
363+
test(ra::ilist_t<2, 0, 1> {});
364+
test(ra::ilist_t<2, 1, 0> {});
365+
test(ra::ilist_t<1, 0, 2> {});
366+
test(ra::ilist_t<0, 2, 1> {});
367+
}
357368
return tr.summary();
358369
}

0 commit comments

Comments
 (0)