Skip to content

Commit 91f3975

Browse files
committed
Remove hard-coded return type of expression eval method
We can deduce everything
1 parent 8e9c3dc commit 91f3975

File tree

8 files changed

+49
-19
lines changed

8 files changed

+49
-19
lines changed

src/expression/ekat_expression_base.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55

66
namespace ekat {
77

8-
// For now. Later, template on return type maybe?
9-
using Real = double;
10-
118
template<typename Derived>
129
class Expression {
1310
public:
@@ -16,7 +13,7 @@ class Expression {
1613

1714
template<typename... Args>
1815
KOKKOS_INLINE_FUNCTION
19-
Real eval(Args... args) const {
16+
auto eval(Args... args) const {
2017
static_assert(std::conjunction_v<std::is_integral<Args>...>,
2118
"[Expression] All arguments must be integral types!");
2219
static_assert(sizeof...(Args) <= 7,

src/expression/ekat_expression_binary_op.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ enum class BinOp {
1717
template<typename ELeft, typename ERight, BinOp OP>
1818
class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
1919
public:
20+
2021
BinaryExpression (const ELeft& left, const ERight& right)
2122
: m_left(left)
2223
, m_right(right)
@@ -28,7 +29,7 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
2829

2930
template<typename... Args>
3031
KOKKOS_INLINE_FUNCTION
31-
Real eval(Args... args) const {
32+
auto eval(Args... args) const {
3233
if constexpr (OP==BinOp::Plus) {
3334
return m_left.eval(args...) + m_right.eval(args...);
3435
} else if constexpr (OP==BinOp::Minus) {
@@ -43,6 +44,8 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
4344
return Kokkos::min(m_left.eval(args...),m_right.eval(args...));
4445
}
4546
}
47+
48+
static auto ret_type () { return ELeft::ret_type() + ERight::ret_type(); }
4649
protected:
4750

4851
ELeft m_left;

src/expression/ekat_expression_compare.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ enum class Comparison : int {
2121
template<typename ELeft, typename ERight>
2222
class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> {
2323
public:
24+
using ret_t = int;
25+
2426
static constexpr bool scalar_left = std::is_arithmetic_v<ELeft>;
2527
static constexpr bool scalar_right = std::is_arithmetic_v<ERight>;
2628
static_assert(not scalar_left or not scalar_right,
@@ -49,7 +51,7 @@ class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> {
4951

5052
template<typename... Args>
5153
KOKKOS_INLINE_FUNCTION
52-
Real eval(Args... args) const {
54+
ret_t eval(Args... args) const {
5355
if constexpr (scalar_left) {
5456
switch (m_cmp) {
5557
case Comparison::EQ: return m_left == m_right.eval(args...);
@@ -85,6 +87,8 @@ class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> {
8587
}
8688
}
8789
}
90+
91+
static int ret_type () { return 0; }
8892
protected:
8993

9094
ELeft m_left;

src/expression/ekat_expression_conditional.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,21 @@ class ConditionalExpression : public Expression<ConditionalExpression<ECond,ELef
2222

2323
template<typename... Args>
2424
KOKKOS_INLINE_FUNCTION
25-
Real eval(Args... args) const {
25+
auto eval(Args... args) const
26+
-> std::common_type_t<decltype(ELeft::ret_type()),decltype(ERight::ret_type())>
27+
{
2628
if (m_cmp.eval(args...))
2729
return m_left.eval(args...);
2830
else
2931
return m_right.eval(args...);
3032
}
33+
34+
static auto ret_type () {
35+
auto ret_l = ELeft::ret_type();
36+
auto ret_r = ERight::ret_type();
37+
using type = std::common_type_t<decltype(ret_l),decltype(ret_r)>;
38+
return type(0);
39+
}
3140
protected:
3241

3342
ECond m_cmp;

src/expression/ekat_expression_math.hpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class PowExpression : public Expression<PowExpression<EBase,EExp>> {
3434

3535
template<typename... Args>
3636
KOKKOS_INLINE_FUNCTION
37-
Real eval(Args... args) const {
37+
auto eval(Args... args) const {
3838
if constexpr (scalar_base) {
3939
return Kokkos::pow(m_base,m_exp.eval(args...));
4040
} else if constexpr (scalar_exp) {
@@ -43,6 +43,19 @@ class PowExpression : public Expression<PowExpression<EBase,EExp>> {
4343
return Kokkos::pow(m_base.eval(args...),m_exp.eval(args...));
4444
}
4545
}
46+
47+
static auto ret_type () {
48+
if constexpr (scalar_base) {
49+
using type = decltype(Kokkos::pow(std::declval<EBase>(),EExp::ret_type()));
50+
return type(0);
51+
} else if constexpr (scalar_exp) {
52+
using type = decltype(Kokkos::pow(EBase::ret_type(),std::declval<EExp>()));
53+
return type(0);
54+
} else {
55+
using type = decltype(Kokkos::pow(EBase::ret_type(),EExp::ret_type()));
56+
return type(0);
57+
}
58+
}
4659
protected:
4760

4861
EBase m_base;
@@ -64,19 +77,20 @@ pow (const Expression<EBase>& b, const Expression<EExp>& e)
6477
public: \
6578
name##Expression (const EArg& arg) \
6679
: m_arg(arg) \
67-
{ \
68-
/* Nothing to do here */ \
69-
} \
80+
{} \
7081
\
7182
int num_indices () const { return m_arg.num_indices(); } \
7283
\
7384
template<typename... Args> \
7485
KOKKOS_INLINE_FUNCTION \
75-
Real eval(Args... args) const { \
86+
auto eval(Args... args) const { \
7687
return Kokkos::impl(m_arg.eval(args...)); \
7788
} \
89+
static auto ret_type () { \
90+
using type = decltype(Kokkos::impl(EArg::ret_type())); \
91+
return type(0); \
92+
} \
7893
protected: \
79-
\
8094
EArg m_arg; \
8195
}; \
8296
\

src/expression/ekat_expression_scalar.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
namespace ekat {
77

8-
// TODO: support 4+ dim. Also 0d?
98
template<typename ST>
109
class ScalarExpression : public Expression<ScalarExpression<ST>> {
1110
public:
@@ -18,10 +17,11 @@ class ScalarExpression : public Expression<ScalarExpression<ST>> {
1817

1918
template<typename... Args>
2019
KOKKOS_INLINE_FUNCTION
21-
Real eval(Args...) const {
20+
ST eval(Args...) const {
2221
return m_value;
2322
}
2423

24+
static ST ret_type () { return 0; }
2525
protected:
2626

2727
ST m_value;

src/expression/ekat_expression_view.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55

66
namespace ekat {
77

8-
// TODO: support 4+ dim. Also 0d?
98
template<typename ViewT>
109
class ViewExpression : public Expression<ViewExpression<ViewT>> {
1110
public:
1211
using view_t = ViewT;
13-
static_assert(view_t::rank >=1 and view_t::rank <=3,
14-
"Unsupported rank for ViewExpression");
12+
using value_t = typename ViewT::element_type;
1513

1614
ViewExpression (const view_t& v)
1715
: m_view(v)
@@ -23,10 +21,13 @@ class ViewExpression : public Expression<ViewExpression<ViewT>> {
2321

2422
template<typename... Args>
2523
KOKKOS_INLINE_FUNCTION
26-
Real eval(Args... args) const {
24+
value_t eval(Args... args) const {
25+
static_assert(sizeof...(Args)==ViewT::rank, "Something is off...\n");
2726
return m_view(args...);
2827
}
2928

29+
static value_t ret_type () { return 0; }
30+
3031
protected:
3132

3233
view_t m_view;

tests/expression/expressions.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ void conditionals (const ViewT& x, const ViewT& y, const ViewT& z)
7373

7474
TEST_CASE("expressions", "") {
7575

76+
using Real = double;
77+
7678
std::random_device rdev;
7779
const int catchRngSeed = Catch::rngSeed();
7880
int seed = catchRngSeed==0 ? rdev()/2 : catchRngSeed;

0 commit comments

Comments
 (0)