Skip to content

Commit 31102a1

Browse files
committed
Simplify is_special for Pick or Map
* ra/expr.hh (is_special_def): True for any_has_len. Remove Pick and Map specializations. * test/len.cc: Extra tests.
1 parent d925304 commit 31102a1

File tree

2 files changed

+54
-66
lines changed

2 files changed

+54
-66
lines changed

ra/expr.hh

Lines changed: 51 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,15 @@ constexpr struct Len
9595
consteval static void mov(dim_t d) { len_outside_subscript_context(); }
9696
} len;
9797

98-
template <> constexpr bool is_special_def<Len> = true; // protect exprs with Len from reduction.
99-
template <class E> struct WLen; // defined in ply.hh.
98+
template <class E> struct WLen; // defined in ply.hh.
10099
template <class E> concept has_len = requires(int ln, E && e) { WLen<std::decay_t<E>>::f(ln, RA_FWD(e)); };
100+
template <has_len E> constexpr bool is_special_def<E> = true; // protect exprs with Len from reduction.
101101

102102
template <class Ln, class E>
103103
constexpr decltype(auto)
104104
wlen(Ln ln, E && e)
105105
{
106-
static_assert(std::is_integral_v<std::decay_t<Ln>> || is_constant<std::decay_t<Ln>>);
106+
static_assert(std::is_integral_v<Ln> || is_constant<Ln>);
107107
if constexpr (has_len<E>) {
108108
return WLen<std::decay_t<E>>::f(ln, RA_FWD(e));
109109
} else {
@@ -170,17 +170,16 @@ struct Ptr final
170170
static_assert(has_len<S> || is_constant<S> || std::is_integral_v<S>);
171171
static_assert(has_len<I> || std::bidirectional_iterator<I>);
172172
constexpr static dim_t nn = maybe_any<N>;
173+
static_assert(0<=nn || ANY==nn || UNB==nn);
173174
constexpr static bool constant = is_constant<N> && is_constant<S>;
174175

175176
I i;
176177
[[no_unique_address]] N const n = {};
177178
[[no_unique_address]] S const s = {};
178179
constexpr static S gets() requires (is_constant<S>) { return S {}; }
179180
constexpr I gets() const requires (!is_constant<S>) { return s; }
180-
181181
constexpr Ptr(I i, N n, S s): i(i), n(n), s(s)
182182
{
183-
static_assert(ANY==nn || 0<=nn || UNB==nn);
184183
if constexpr (std::is_integral_v<N>) { RA_CHECK(n>=0, "Bad Ptr length ", n, "."); }
185184
}
186185
RA_ASSIGNOPS_SELF(Ptr)
@@ -352,7 +351,7 @@ FOR_EACH(DEF_TENSORINDEX, 0, 1, 2, 3, 4);
352351
// ---------------------------
353352

354353
template <class cranks, class Op> struct Verb final { Op op; };
355-
template <class A> concept is_verb = requires (A a) { []<class cranks, class Op>(Verb<cranks, Op> const &){}(a); };
354+
template <class A> concept is_verb = requires (A & a) { []<class cranks, class Op>(Verb<cranks, Op> &){}(a); };
356355

357356
template <class cranks, class Op>
358357
constexpr auto
@@ -362,74 +361,32 @@ template <rank_t ... crank, class Op>
362361
constexpr auto
363362
wrank(Op && op) { return Verb<ilist_t<crank ...>, Op> { RA_FWD(op) }; }
364363

365-
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
364+
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, int skip=0>
366365
struct Framematch_def;
367366

368-
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
367+
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, int skip=0>
369368
using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
370369

371370
// Get a list (per argument) of lists of live axes. The last frame match is handled by standard prefix matching.
372-
template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
371+
template <class ... crank, class W, class ... Ti, class ... Ri, int skip>
373372
struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
374373
{
375-
static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
376-
// live = number of live axes on this frame, for each argument. // TODO crank negative, inf.
374+
// TODO crank negative, inf.
377375
using live = ilist_t<(rank_s<Ti>() - mp::len<Ri> - crank::value) ...>;
378376
using frameaxes = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri> - crank::value), skip>> ...>;
379377
using FM = Framematch<W, std::tuple<Ti ...>, frameaxes, skip + std::ranges::max(tuple2array<int, live>)>;
380378
using R = typename FM::R;
381379
template <class U> constexpr static decltype(auto) op(U && v) { return FM::op(RA_FWD(v).op); } // cf [ra31]
382380
};
383381

384-
// Terminal case where V doesn't have rank (is a raw op()).
385-
template <class V, class ... Ti, class ... Ri, rank_t skip>
382+
template <class V, class ... Ti, class ... Ri, int skip>
386383
struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
387384
{
388-
static_assert(sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
389385
// TODO -crank::value when the actual verb rank is used (eg to use CellBig<... that_rank> instead of just begin()).
390386
using R = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri>), skip>> ...>;
391387
template <class U> constexpr static decltype(auto) op(U && v) { return RA_FWD(v); }
392388
};
393389

394-
395-
// --------------
396-
// making Iterators
397-
// --------------
398-
399-
// TODO arbitrary exprs? runtime cr? ra::len in cr?
400-
template <int cr>
401-
constexpr auto
402-
iter(Slice auto && a) { return RA_FWD(a).template iter<cr>(); }
403-
404-
constexpr void
405-
start(auto && t) { static_assert(false, "Cannot start() type."); }
406-
407-
constexpr auto
408-
start(is_fov auto && t) { return ra::ptr(RA_FWD(t)); }
409-
410-
template <class T>
411-
constexpr auto
412-
start(std::initializer_list<T> v) { return ra::ptr(v.begin(), v.size()); }
413-
414-
constexpr auto
415-
start(is_scalar auto && t) { return ra::scalar(RA_FWD(t)); }
416-
417-
// forward declare for Match; implemented in small.hh.
418-
constexpr auto
419-
start(is_builtin_array auto && t);
420-
421-
// neither CellBig nor CellSmall will retain rvalues [ra4].
422-
constexpr auto
423-
start(Slice auto && t) { return iter<0>(RA_FWD(t)); }
424-
425-
// iterators need to be reset on each use [ra35].
426-
template <class A> requires (is_iterator<A> && !(requires (A a) { []<class C>(Scalar<C> const &){}(a); }))
427-
constexpr auto
428-
start(A & a) { return a; }
429-
430-
constexpr decltype(auto)
431-
start(is_iterator auto && a) { return RA_FWD(a); }
432-
433390

434391
// --------------------
435392
// prefix match
@@ -438,7 +395,7 @@ start(is_iterator auto && a) { return RA_FWD(a); }
438395
constexpr rank_t
439396
choose_rank(rank_t a, rank_t b) { return ANY==a ? a : ANY==b ? b : a>=0 ? (b>=0 ? std::max(a, b) : a) : b; }
440397

441-
// finite before ANY before UNB, assumes checks pass.
398+
// finite before ANY before UNB, assumes neither is MIS.
442399
constexpr dim_t
443400
choose_len(dim_t a, dim_t b) { return a>=0 ? (a==b ? a : b>=0 ? MIS : a) : UNB==a ? b : UNB==b ? a : b; }
444401

@@ -586,6 +543,45 @@ struct Match<std::tuple<P ...>, ilist_t<I ...>>
586543
constexpr void mov(auto const & s) { ((get<I>(t).mov(get<I>(s))), ...); }
587544
};
588545

546+
547+
// --------------
548+
// making Iterators
549+
// --------------
550+
551+
// TODO arbitrary exprs? runtime cr? ra::len in cr?
552+
template <int cr>
553+
constexpr auto
554+
iter(Slice auto && a) { return RA_FWD(a).template iter<cr>(); }
555+
556+
constexpr void
557+
start(auto && t) { static_assert(false, "Cannot start() type."); }
558+
559+
constexpr auto
560+
start(is_fov auto && t) { return ra::ptr(RA_FWD(t)); }
561+
562+
template <class T>
563+
constexpr auto
564+
start(std::initializer_list<T> v) { return ra::ptr(v.begin(), v.size()); }
565+
566+
constexpr auto
567+
start(is_scalar auto && t) { return ra::scalar(RA_FWD(t)); }
568+
569+
// forward declare for Match; implemented in small.hh.
570+
constexpr auto
571+
start(is_builtin_array auto && t);
572+
573+
// neither CellBig nor CellSmall will retain rvalues [ra4].
574+
constexpr auto
575+
start(Slice auto && t) { return iter<0>(RA_FWD(t)); }
576+
577+
// iterators need to be reset on each use [ra35].
578+
template <class A> requires (is_iterator<A> && !(requires (A a) { []<class C>(Scalar<C> const &){}(a); }))
579+
constexpr auto
580+
start(A & a) { return a; }
581+
582+
constexpr decltype(auto)
583+
start(is_iterator auto && a) { return RA_FWD(a); }
584+
589585

590586
// ---------------
591587
// explicit agreement checks
@@ -616,7 +612,7 @@ agree_verb(ilist_t<i ...>, V const & v, T const & ... t)
616612

617613

618614
// ---------------------------
619-
// operator expression
615+
// map and pick
620616
// ---------------------------
621617

622618
template <class E>
@@ -643,9 +639,6 @@ struct Map<Op, std::tuple<P ...>, ilist_t<I ...>>: public Match<std::tuple<P ...
643639
constexpr operator decltype(std::invoke(op, *get<I>(t) ...)) () const { return to_scalar(*this); }
644640
};
645641

646-
template <class Op, Iterator ... P>
647-
constexpr bool is_special_def<Map<Op, std::tuple<P ...>>> = (is_special<P> || ...);
648-
649642
template <class Op, class ... P>
650643
Map(Op && op, P && ... p) -> Map<Op, std::tuple<P ...>>;
651644

@@ -666,11 +659,6 @@ map_(auto && op, auto && ... p) { return Map(RA_FWD(op), RA_FWD(p) ...); }
666659
constexpr auto
667660
map(auto && op, auto && ... a) { return map_(RA_FWD(op), start(RA_FWD(a)) ...); }
668661

669-
670-
// ---------------------------
671-
// pick expression
672-
// ---------------------------
673-
674662
template <class J> struct type_at { template <class P> using type = decltype(std::declval<P>().at(std::declval<J>())); };
675663

676664
template <std::size_t I, class T, class J>
@@ -711,9 +699,6 @@ struct Pick<std::tuple<P ...>, ilist_t<I ...>>: public Match<std::tuple<P ...>>
711699
constexpr operator decltype(pick_star<0>(*get<0>(t), t)) () const { return to_scalar(*this); }
712700
};
713701

714-
template <Iterator ... P>
715-
constexpr bool is_special_def<Pick<std::tuple<P ...>>> = (is_special<P> || ...);
716-
717702
template <class ... P>
718703
Pick(P && ... p) -> Pick<std::tuple<P ...>>;
719704

test/len.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ int main()
3131
tr.test(ra::is_zero_or_scalar<ra::Len>);
3232
tr.test(!ra::is_ra_pos<ra::Len>);
3333
tr.test(ra::is_special<ra::Len>);
34+
tr.test(ra::is_special<decltype(ra::len + ra::len)>);
35+
tr.test(ra::is_special<decltype(ra::pick(std::array {0, 1, 0}, ra::len, 3))>);
36+
tr.test(!ra::is_special<decltype(ra::pick(std::array {0, 1, 0}, 1, 3))>);
3437
tr.test(ra::tomap<ra::Len>);
3538
tr.test(!ra::toreduce<ra::Len>);
3639
tr.test(ra::tomap<ra::Len, ra::Len>);

0 commit comments

Comments
 (0)