Skip to content

Commit c9cad45

Browse files
committed
Fix scalar conversions for views
* ra/big.hh (ViewBig): Scalar conversion returns *cp not T &. This fixes scalar conversion for sequence view, which returns rvalue. * ra/small.hh (ViewSmall): Likewise. * bench/bench-at.cc: Exercise the conversions here.
1 parent 51626cb commit c9cad45

File tree

5 files changed

+58
-55
lines changed

5 files changed

+58
-55
lines changed

bench/bench-at.cc

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,23 @@
1212
#include <string>
1313
#include <cstdlib>
1414
#include "ra/test.hh"
15+
#include "test/mpdebug.hh"
1516

1617
using std::cout, std::endl, std::flush, ra::TestRecorder, ra::Benchmark;
1718
using real = double;
1819
using ra::dim_t;
1920

21+
ra::TestRecorder tr(cout);
22+
2023
// FIXME bigd/bigd & bigs/bigd at loop
2124

2225
int main(int argc, char * * argv)
2326
{
2427
int reps = argc>1 ? std::stoi(argv[1]) : 10000;
2528
std::println(cout, "reps = {}", reps);
26-
ra::TestRecorder tr(cout);
2729
tr.section("rank 2");
2830
{
29-
auto test2 = [&tr](auto && C, auto && I, int reps, std::string tag)
31+
auto test2 = [](auto && C, auto && I, int reps, std::string tag)
3032
{
3133
if ("warmup"!=tag) tr.section(tag);
3234
int M = C.len(0);
@@ -86,11 +88,11 @@ int main(int argc, char * * argv)
8688
}
8789
tr.section("rank 1");
8890
{
89-
auto test1 = [&tr](auto && C, auto && I, int reps, std::string tag)
91+
auto test1 = [](auto && C, auto && I, int reps, std::string tag)
9092
{
9193
if ("warmup"!=tag) tr.section(tag);
92-
int M = C.len(0);
93-
int O = I.len(0);
94+
int const M = C.len(0);
95+
[[maybe_unused]] int const O = I.len(0);
9496
I(ra::all, 0) = map([&](auto && i) { return i%M; }, ra::_0 + (std::rand() & 1));
9597

9698
int ref0 = sum(at(C, iter<1>(I))), val0 = 0;
@@ -114,14 +116,14 @@ int main(int argc, char * * argv)
114116
}));
115117
};
116118

117-
// especially Ptr, if it differs from ViewSmall/ViewBig
118-
auto iotaa = ra::iota(100, 1, 4);
119-
ra::Big<int, 1> bigsa({100}, 4*ra::_0);
120-
ra::Big<int> bigda({100}, 4*ra::_0);
121-
ra::Small<int, 100> smola = 4*ra::_0;
122-
ra::Big<int, 2> bigsi({100, 1}, ra::none);
123-
ra::Big<int> bigdi({100, 1}, ra::none);
124-
ra::Small<int, 100, 1> smoli;
119+
auto iotav = ra::ii({100}); // view.at(i) rank depends on i's size. That can be var rank which is a lot slower.
120+
auto iotai = ra::iota(100); // iter.at(i) requires i's size to be iter's rank.
121+
[[maybe_unused]] ra::Big<int, 1> bigsa({100}, 4*ra::_0);
122+
[[maybe_unused]] ra::Big<int> bigda({100}, 4*ra::_0);
123+
[[maybe_unused]] ra::Small<int, 100> smola = 4*ra::_0;
124+
[[maybe_unused]] ra::Big<int, 2> bigsi({100, 1}, ra::none);
125+
[[maybe_unused]] ra::Big<int> bigdi({100, 1}, ra::none);
126+
[[maybe_unused]] ra::Small<int, 100, 1> smoli;
125127
test1(smola, smoli, reps, "warmup");
126128
test1(smola, smoli, reps, "small/small");
127129
test1(bigsa, smoli, reps, "bigs/small");
@@ -132,9 +134,12 @@ int main(int argc, char * * argv)
132134
test1(smola, bigdi, reps, "small/bigd");
133135
test1(bigsa, bigdi, reps, "bigs/bigd");
134136
test1(bigda, bigdi, reps, "bigd/bigd");
135-
test1(iotaa, smoli, reps, "iota/small");
136-
test1(iotaa, bigsi, reps, "iota/bigs");
137-
test1(iotaa, bigdi, reps, "iota/bigd");
137+
test1(iotav, smoli, reps, "iotv/small");
138+
test1(iotav, bigsi, reps, "iotv/bigs");
139+
test1(iotav, bigdi, reps, "iotv/bigd");
140+
test1(iotai, smoli, reps, "ioti/small");
141+
test1(iotai, bigsi, reps, "ioti/bigs");
142+
test1(iotai, bigdi, reps, "ioti/bigd");
138143
}
139144
return tr.summary();
140145
}

ra/big.hh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,10 @@ struct ViewBig
143143
{
144144
return *reinterpret_cast<ViewBig<reconst<P>, RANK> const *>(this);
145145
}
146-
constexpr operator T & () const { return to_scalar(*this); }
146+
constexpr operator decltype(*cp) () const { return to_scalar(*this); }
147147
constexpr decltype(auto) operator()(this auto && self, auto && ... i) { return from(RA_FW(self), RA_FW(i) ...); }
148148
constexpr decltype(auto) operator[](this auto && self, auto && ... i) { return from(RA_FW(self), RA_FW(i) ...); }
149-
constexpr decltype(auto) at(auto const & i) const { return at1(*this, i); }
149+
constexpr decltype(auto) at(auto const & i) const { return at_view(*this, i); }
150150
};
151151

152152
template <class V>

ra/expr.hh

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,29 @@ constexpr auto start(is_builtin_array auto && t);
176176

177177
template <class T> constexpr auto start(std::initializer_list<T> v);
178178

179+
template <class A>
180+
constexpr decltype(auto)
181+
VAL(A && a)
182+
{
183+
if constexpr (is_scalar<A>) { return RA_FW(a); } // [ra8]
184+
else if constexpr (is_iterator<A>) { return *a; } // no need to start()
185+
else if constexpr (requires { *ra::start(RA_FW(a)); }) { return *ra::start(RA_FW(a)); }
186+
// else void
187+
}
188+
189+
template <class A> using value_t = std::remove_volatile_t<std::remove_reference_t<decltype(VAL(std::declval<A>()))>>;
190+
template <class A> using ncvalue_t = std::remove_const_t<value_t<A>>;
191+
192+
constexpr decltype(auto)
193+
to_scalar(auto && e)
194+
{
195+
if constexpr (constexpr dim_t s=size_s(e); 1!=s) {
196+
static_assert(ANY==s, "Bad scalar conversion.");
197+
RA_CK(1==ra::size(e), "Bad scalar conversion from shape [", fmt(nstyle, ra::shape(e)), "].");
198+
}
199+
return VAL(e);
200+
}
201+
179202

180203
// --------------------
181204
// view iterators
@@ -293,14 +316,13 @@ struct Cell: public std::conditional_t<is_constant<Dimv>, CellSmall<P, Dimv, Spe
293316
using View = decltype(std::declval<Base>().c);
294317
static_assert((cellr>=0 || cellr==ANY) && (framer>=0 || framer==ANY), "Bad cell/frame ranks.");
295318
RA_ASSIGNOPS_ITER(Cell)
296-
297319
constexpr static dim_t len_s(int k) { if constexpr (is_constant<Dimv>) return len(k); else return ANY; }
298320
constexpr void adv(rank_t k, dim_t d) { mov(step(k)*d); }
299-
constexpr decltype(auto) at(auto && i) const requires (0==cellr) { return *indexer(*this, c.cp, start(RA_FW(i))); }
300-
constexpr auto at(auto && i) const requires (0!=cellr) { View d(c); d.cp=indexer(*this, d.cp, start(RA_FW(i))); return d; }
301-
constexpr decltype(auto) operator*() const requires (0==cellr) { return *(c.cp); }
321+
constexpr decltype(*c.cp) at(auto && i) const requires (0==cellr) { return *indexer(*this, c.cp, start(RA_FW(i))); }
322+
constexpr View at(auto && i) const requires (0!=cellr) { View d(c); d.cp=indexer(*this, d.cp, start(RA_FW(i))); return d; }
323+
constexpr decltype(*c.cp) operator*() const requires (0==cellr) { return *(c.cp); }
302324
constexpr View const & operator*() const requires (0!=cellr) { return c; }
303-
constexpr operator decltype(c.cp[0]) () const { return to_scalar(*this); }
325+
constexpr operator decltype(*c.cp) () const { return to_scalar(*this); }
304326
constexpr auto save() const { return c.cp; }
305327
constexpr void load(P p) { c.cp = p; }
306328
#pragma GCC diagnostic push
@@ -330,7 +352,6 @@ struct Ptr final
330352
if constexpr (std::is_integral_v<N>) { RA_CK(n>=0, "Bad Ptr length ", n, "."); }
331353
}
332354
RA_ASSIGNOPS_ITER(Ptr)
333-
334355
consteval static rank_t rank() { return 1; }
335356
constexpr static dim_t len_s(int k) { return nn; }
336357
constexpr static dim_t len(int k) requires (is_constant<N>) { return nn; }
@@ -340,8 +361,8 @@ struct Ptr final
340361
constexpr static bool keep(dim_t st, int z, int j) requires (is_constant<S>) { return st*step(z)==step(j); }
341362
constexpr bool keep(dim_t st, int z, int j) const requires (!is_constant<S>) { return st*step(z)==step(j); }
342363
constexpr void adv(rank_t k, dim_t d) { mov(step(k)*d); }
343-
constexpr decltype(auto) at(auto && i) const { return *indexer(*this, cp, start(RA_FW(i))); }
344-
constexpr decltype(auto) operator*() const { return *cp; }
364+
constexpr decltype(*cp) at(auto && i) const { return *indexer(*this, cp, start(RA_FW(i))); } // iter's not view's
365+
constexpr decltype(*cp) operator*() const { return *cp; }
345366
constexpr auto save() const { return cp; }
346367
constexpr void load(P p) { cp = p; }
347368
#pragma GCC diagnostic push
@@ -707,29 +728,6 @@ agree_verb(ilist_t<i ...>, V const & v, T const & ... t)
707728
// map and pick
708729
// ---------------------------
709730

710-
template <class A>
711-
constexpr decltype(auto)
712-
VAL(A && a)
713-
{
714-
if constexpr (is_scalar<A>) { return RA_FW(a); } // [ra8]
715-
else if constexpr (is_iterator<A>) { return *a; } // no need to start()
716-
else if constexpr (requires { *ra::start(RA_FW(a)); }) { return *ra::start(RA_FW(a)); }
717-
// else void
718-
}
719-
720-
template <class A> using value_t = std::remove_volatile_t<std::remove_reference_t<decltype(VAL(std::declval<A>()))>>;
721-
template <class A> using ncvalue_t = std::remove_const_t<value_t<A>>;
722-
723-
constexpr decltype(auto)
724-
to_scalar(auto && e)
725-
{
726-
if constexpr (constexpr dim_t s=size_s(e); 1!=s) {
727-
static_assert(ANY==s, "Bad scalar conversion.");
728-
RA_CK(1==ra::size(e), "Bad scalar conversion from shape [", fmt(nstyle, ra::shape(e)), "].");
729-
}
730-
return VAL(e);
731-
}
732-
733731
template <class Op, class T, class K=mp::iota<mp::len<T>>> struct Map;
734732
template <class Op, Iterator ... P, int ... I>
735733
struct Map<Op, std::tuple<P ...>, ilist_t<I ...>>: public Match<std::tuple<P ...>>

ra/small.hh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ from(A && a, auto && ... i)
338338
}
339339

340340
constexpr decltype(auto)
341-
at1(auto const & a, auto const & i)
341+
at_view(auto const & a, auto const & i)
342342
{
343343
// can't say 'frame rank 0' so -size wouldn't work. FIXME What about ra::len
344344
if constexpr (constexpr rank_t crank = rank_diff(rank_s(a), ra::size_s(i)); ANY==crank) {
@@ -454,10 +454,10 @@ struct ViewSmall
454454
{
455455
return ViewSmall<reconst<P>, Dimv>(cp);
456456
}
457-
constexpr operator T & () const { return to_scalar(*this); }
457+
constexpr operator decltype(*cp) () const { return to_scalar(*this); }
458458
constexpr decltype(auto) operator()(this auto && self, auto && ... i) { return from(RA_FW(self), RA_FW(i) ...); }
459459
constexpr decltype(auto) operator[](this auto && self, auto && ... i) { return from(RA_FW(self), RA_FW(i) ...); }
460-
constexpr decltype(auto) at(auto const & i) const { return at1(*this, i); }
460+
constexpr decltype(auto) at(auto const & i) const { return at_view(*this, i); }
461461
};
462462

463463
#if defined (__clang__)

ra/test.hh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ struct Benchmark
329329
for (int k=0; k<runs_; ++k) {
330330
auto t0 = clock::now();
331331
for (int i=0; i<repeats_; ++i) {
332-
f(RA_FW(a) ...);
332+
f(a ...);
333333
}
334334
clock::duration full = clock::now()-t0;
335335
times.push_back(lapse(empty, full));
@@ -343,7 +343,7 @@ struct Benchmark
343343
g([&](auto && f) {
344344
auto t0 = clock::now();
345345
empty = clock::now()-t0;
346-
}, RA_FW(a) ...);
346+
}, a ...);
347347

348348
ra::Big<clock::duration, 1> times;
349349
for (int k=0; k<runs_; ++k) {
@@ -352,7 +352,7 @@ struct Benchmark
352352
for (int i=0; i<repeats_; ++i) { f(); }
353353
clock::duration full = clock::now()-t0;
354354
times.push_back(lapse(empty, full));
355-
}, RA_FW(a) ...);
355+
}, a ...);
356356
}
357357
return Value { name_, repeats_, empty, std::move(times) };
358358
}

0 commit comments

Comments
 (0)