Skip to content

Commit e7fed82

Browse files
authored
Merge pull request #399 from E3SM-Project/bartgol/expression-templates-simplifications
Some simplifications to the expression templates design
2 parents 7140e22 + e1553cd commit e7fed82

11 files changed

+410
-353
lines changed

src/expression/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ target_include_directories(ekat_expression INTERFACE
99

1010
# Set the PUBLIC_HEADER property
1111
set (HEADERS
12-
ekat_expression_base.hpp
12+
ekat_expression_meta.hpp
1313
ekat_expression_binary_op.hpp
1414
ekat_expression_compare.hpp
1515
ekat_expression_conditional.hpp

src/expression/ekat_expression_base.hpp

Lines changed: 0 additions & 40 deletions
This file was deleted.

src/expression/ekat_expression_binary_op.hpp

Lines changed: 34 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,33 @@
11
#ifndef EKAT_EXPRESSION_BINARY_OP_HPP
22
#define EKAT_EXPRESSION_BINARY_OP_HPP
33

4-
#include "ekat_expression_base.hpp"
4+
#include "ekat_expression_meta.hpp"
5+
6+
#include <Kokkos_Core.hpp>
57

68
namespace ekat {
79

810
enum class BinOp {
911
Plus,
1012
Minus,
1113
Mult,
12-
Div,
13-
Max,
14-
Min
14+
Div
1515
};
1616

1717
template<typename ELeft, typename ERight, BinOp OP>
18-
class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
18+
class BinaryExpression {
1919
public:
2020
static constexpr bool expr_l = is_expr_v<ELeft>;
2121
static constexpr bool expr_r = is_expr_v<ERight>;
2222

23+
using eval_left_t = eval_return_t<ELeft>;
24+
using eval_right_t = eval_return_t<ERight>;
25+
26+
using eval_t = std::common_type_t<eval_left_t,eval_right_t>;
27+
28+
// Don't create an expression from builtin types, just combine them!
2329
static_assert (expr_l or expr_r,
24-
"[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n");
30+
"[BinaryExpression] At least one between ELeft and ERight must be an Expression type.\n");
2531

2632
BinaryExpression (const ELeft& left, const ERight& right)
2733
: m_left(left)
@@ -30,19 +36,27 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
3036
// Nothing to do here
3137
}
3238

33-
int num_indices () const {
34-
if constexpr (not expr_l) {
35-
return m_right.num_indices();
36-
} else if constexpr (not expr_r) {
37-
return m_left.num_indices();
39+
static constexpr int rank () {
40+
if constexpr (expr_l) {
41+
if constexpr (expr_r) {
42+
static_assert(ELeft::rank()==ERight::rank(),
43+
"[BinaryExpression] Error! ELeft and ERight are Expression types of different rank.\n");
44+
}
45+
return ELeft::rank();
3846
} else {
39-
return std::max(m_left.num_indices(),m_right.num_indices());
47+
return ERight::rank();
4048
}
4149
}
50+
int extent (int i) const {
51+
if constexpr (expr_l)
52+
return m_left.extent(i);
53+
else
54+
return m_right.extent(i);
55+
}
4256

4357
template<typename... Args>
4458
KOKKOS_INLINE_FUNCTION
45-
auto eval(Args... args) const {
59+
eval_t eval(Args... args) const {
4660
if constexpr (not expr_l) {
4761
return eval_impl(m_left,m_right.eval(args...));
4862
} else if constexpr (not expr_r) {
@@ -52,12 +66,10 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
5266
}
5367
}
5468

55-
static auto ret_type () { return ELeft::ret_type() + ERight::ret_type(); }
5669
protected:
5770

58-
template<typename T1, typename T2>
5971
KOKKOS_INLINE_FUNCTION
60-
auto eval_impl(const T1 l, const T2 r) const {
72+
eval_t eval_impl (const eval_left_t& l, const eval_right_t& r) const {
6173
if constexpr (OP==BinOp::Plus) {
6274
return l+r;
6375
} else if constexpr (OP==BinOp::Minus) {
@@ -66,71 +78,21 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
6678
return l*r;
6779
} else if constexpr (OP==BinOp::Div) {
6880
return l/r;
69-
} else if constexpr (OP==BinOp::Max) {
70-
return Kokkos::max(l,r);
71-
} else if constexpr (OP==BinOp::Min) {
72-
return Kokkos::min(l,r);
81+
return Kokkos::min(static_cast<const eval_t&>(l),static_cast<const eval_t&>(r));
7382
}
7483
}
7584

7685
ELeft m_left;
7786
ERight m_right;
7887
};
7988

89+
// Specialize meta utils
8090
template<typename ELeft, typename ERight, BinOp OP>
8191
struct is_expr<BinaryExpression<ELeft,ERight,OP>> : std::true_type {};
82-
83-
// Unary minus implemented as -1*expr
84-
template<typename ERight>
85-
BinaryExpression<int,ERight,BinOp::Mult>
86-
operator- (const Expression<ERight>& r)
87-
{
88-
return BinaryExpression<int,ERight,BinOp::Mult>(-1,r.cast());
89-
}
90-
91-
// Overload arithmetic operators
92-
template<typename ELeft, typename ERight>
93-
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Plus>>
94-
operator+ (const ELeft& l, const ERight& r)
95-
{
96-
return BinaryExpression<ELeft,ERight,BinOp::Plus>(l,r);
97-
}
98-
99-
template<typename ELeft, typename ERight>
100-
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Minus>>
101-
operator- (const ELeft& l, const ERight& r)
102-
{
103-
return BinaryExpression<ELeft,ERight,BinOp::Minus>(l,r);
104-
}
105-
106-
template<typename ELeft, typename ERight>
107-
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Mult>>
108-
operator* (const ELeft& l, const ERight& r)
109-
{
110-
return BinaryExpression<ELeft,ERight,BinOp::Mult>(l,r);
111-
}
112-
113-
template<typename ELeft, typename ERight>
114-
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Div>>
115-
operator/ (const ELeft& l, const ERight& r)
116-
{
117-
return BinaryExpression<ELeft,ERight,BinOp::Div>(l,r);
118-
}
119-
120-
// Overload max/min functions
121-
template<typename ELeft, typename ERight>
122-
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Max>>
123-
max (const ELeft& l, const ERight& r)
124-
{
125-
return BinaryExpression<ELeft,ERight,BinOp::Max>(l,r);
126-
}
127-
128-
template<typename ELeft, typename ERight>
129-
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Min>>
130-
min (const ELeft& l, const ERight& r)
131-
{
132-
return BinaryExpression<ELeft,ERight,BinOp::Min>(l,r);
133-
}
92+
template<typename ELeft, typename ERight, BinOp OP>
93+
struct eval_return<BinaryExpression<ELeft,ERight,OP>> {
94+
using type = typename BinaryExpression<ELeft,ERight,OP>::eval_t;
95+
};
13496

13597
} // namespace ekat
13698

0 commit comments

Comments
 (0)