Skip to content

Commit b2e66d0

Browse files
authored
Merge pull request #395 from E3SM-Project/bartgol/expression-templates
Framework for expression-templates evaluation based on views
2 parents f158d9d + 37d63f9 commit b2e66d0

12 files changed

+873
-0
lines changed

src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ endif()
99

1010
option(EKAT_ENABLE_ALGORITHM "Enable EKAT algorithm utilities" ${PKG_DEFAULT})
1111
option(EKAT_ENABLE_KOKKOS "Enable EKAT kokkos utilities" ${PKG_DEFAULT})
12+
option(EKAT_ENABLE_EXPRESSION "Enable EKAT expression templates" ${PKG_DEFAULT})
1213
option(EKAT_ENABLE_LOGGING "Enable EKAT spdlog utilities" ${PKG_DEFAULT})
1314
option(EKAT_ENABLE_PACK "Enable EKAT packs utilities" ${PKG_DEFAULT})
1415
option(EKAT_ENABLE_YAML_PARSER "Enable EKAT yaml parsing utilities" ${PKG_DEFAULT})
@@ -45,6 +46,9 @@ endif()
4546
if (EKAT_ENABLE_KOKKOS)
4647
add_subdirectory(kokkos)
4748
endif()
49+
if (EKAT_ENABLE_EXPRESSION)
50+
add_subdirectory(expression)
51+
endif()
4852
if (EKAT_ENABLE_LOGGING)
4953
add_subdirectory(logging)
5054
endif()

src/expression/CMakeLists.txt

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
add_library(ekat_expression INTERFACE)
2+
3+
target_link_libraries (ekat_expression INTERFACE
4+
ekat::KokkosUtils)
5+
6+
target_include_directories(ekat_expression INTERFACE
7+
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
8+
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ekat>)
9+
10+
# Set the PUBLIC_HEADER property
11+
set (HEADERS
12+
ekat_expression_base.hpp
13+
ekat_expression_binary_op.hpp
14+
ekat_expression_compare.hpp
15+
ekat_expression_conditional.hpp
16+
ekat_expression_eval.hpp
17+
ekat_expression_math.hpp
18+
ekat_expression_view.hpp
19+
)
20+
set_target_properties(ekat_expression PROPERTIES PUBLIC_HEADER "${HEADERS}")
21+
22+
# Set the name to be used when exportin the target
23+
# This, together with the NAMESPACE property set in the main CMakeLists.txt
24+
# install call, will force user to link ekat::Expression
25+
set_target_properties(ekat_expression PROPERTIES
26+
EXPORT_NAME Expression
27+
PUBLIC_HEADER "${HEADERS}")
28+
29+
# Install the package
30+
install (TARGETS ekat_expression
31+
EXPORT EkatTargets
32+
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
33+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
34+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
35+
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ekat)
36+
37+
# Regardless of whether we use INSTALLED ekat, or BUILT (via add_subdirectory),
38+
# we want to be able to access ekat's targets via ekat::TARGET
39+
add_library(ekat::Expression ALIAS ekat_expression)
40+
41+
# Link to the all libs target
42+
target_link_libraries(ekat_all_libs INTERFACE ekat_expression)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef EKAT_EXPRESSION_HPP
2+
#define EKAT_EXPRESSION_HPP
3+
4+
#include <Kokkos_Core.hpp>
5+
6+
namespace ekat {
7+
8+
template<typename Derived>
9+
class Expression {
10+
public:
11+
12+
int num_indices () const { return cast().num_indices(); }
13+
14+
template<typename... Args>
15+
KOKKOS_INLINE_FUNCTION
16+
auto eval(Args... args) const {
17+
static_assert(std::conjunction_v<std::is_integral<Args>...>,
18+
"[Expression] All arguments must be integral types!");
19+
static_assert(sizeof...(Args) <= 7,
20+
"[Expression] The number of arguments must be between 0 and 7.");
21+
return cast().eval(args...);
22+
}
23+
24+
KOKKOS_INLINE_FUNCTION
25+
const Derived& cast () const { return static_cast<const Derived&>(*this); }
26+
};
27+
28+
// Some meta-utilities that will prove useful in derived classes
29+
30+
// Detect if a type is an Expression
31+
template<typename T>
32+
struct is_expr : std::false_type {};
33+
template<typename D>
34+
struct is_expr<Expression<D>> : std::true_type {};
35+
template<typename T>
36+
constexpr bool is_expr_v = is_expr<T>::value;
37+
38+
} // namespace ekat
39+
40+
#endif // EKAT_EXPRESSION_HPP
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#ifndef EKAT_EXPRESSION_BINARY_OP_HPP
2+
#define EKAT_EXPRESSION_BINARY_OP_HPP
3+
4+
#include "ekat_expression_base.hpp"
5+
6+
namespace ekat {
7+
8+
enum class BinOp {
9+
Plus,
10+
Minus,
11+
Mult,
12+
Div,
13+
Max,
14+
Min
15+
};
16+
17+
template<typename ELeft, typename ERight, BinOp OP>
18+
class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
19+
public:
20+
static constexpr bool expr_l = is_expr_v<ELeft>;
21+
static constexpr bool expr_r = is_expr_v<ERight>;
22+
23+
static_assert (expr_l or expr_r,
24+
"[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n");
25+
26+
BinaryExpression (const ELeft& left, const ERight& right)
27+
: m_left(left)
28+
, m_right(right)
29+
{
30+
// Nothing to do here
31+
}
32+
33+
int num_indices () const {
34+
if constexpr (not expr_l) {
35+
return m_right.num_indices();
36+
} else if constexpr (not expr_r) {
37+
return m_left.num_indices();
38+
} else {
39+
return std::max(m_left.num_indices(),m_right.num_indices());
40+
}
41+
}
42+
43+
template<typename... Args>
44+
KOKKOS_INLINE_FUNCTION
45+
auto eval(Args... args) const {
46+
if constexpr (not expr_l) {
47+
return eval_impl(m_left,m_right.eval(args...));
48+
} else if constexpr (not expr_r) {
49+
return eval_impl(m_left.eval(args...),m_right);
50+
} else {
51+
return eval_impl(m_left.eval(args...),m_right.eval(args...));
52+
}
53+
}
54+
55+
static auto ret_type () { return ELeft::ret_type() + ERight::ret_type(); }
56+
protected:
57+
58+
template<typename T1, typename T2>
59+
KOKKOS_INLINE_FUNCTION
60+
auto eval_impl(const T1 l, const T2 r) const {
61+
if constexpr (OP==BinOp::Plus) {
62+
return l+r;
63+
} else if constexpr (OP==BinOp::Minus) {
64+
return l-r;
65+
} else if constexpr (OP==BinOp::Mult) {
66+
return l*r;
67+
} else if constexpr (OP==BinOp::Div) {
68+
return l/r;
69+
} else if constexpr (OP==BinOp::Max) {
70+
return Kokkos::max(l,r);
71+
} else if constexpr (OP==BinOp::Min) {
72+
return Kokkos::min(l,r);
73+
}
74+
}
75+
76+
ELeft m_left;
77+
ERight m_right;
78+
};
79+
80+
template<typename ELeft, typename ERight, BinOp OP>
81+
struct is_expr<BinaryExpression<ELeft,ERight,OP>> : std::true_type {};
82+
83+
// Unary minus implemented as -1*expr
84+
template<typename ERight>
85+
BinaryExpression<int,ERight,BinOp::Mult>
86+
operator- (const Expression<ERight>& r)
87+
{
88+
return BinaryExpression<int,ERight,BinOp::Mult>(-1,r.cast());
89+
}
90+
91+
// Overload arithmetic operators
92+
template<typename ELeft, typename ERight>
93+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Plus>>
94+
operator+ (const ELeft& l, const ERight& r)
95+
{
96+
return BinaryExpression<ELeft,ERight,BinOp::Plus>(l,r);
97+
}
98+
99+
template<typename ELeft, typename ERight>
100+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Minus>>
101+
operator- (const ELeft& l, const ERight& r)
102+
{
103+
return BinaryExpression<ELeft,ERight,BinOp::Minus>(l,r);
104+
}
105+
106+
template<typename ELeft, typename ERight>
107+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Mult>>
108+
operator* (const ELeft& l, const ERight& r)
109+
{
110+
return BinaryExpression<ELeft,ERight,BinOp::Mult>(l,r);
111+
}
112+
113+
template<typename ELeft, typename ERight>
114+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Div>>
115+
operator/ (const ELeft& l, const ERight& r)
116+
{
117+
return BinaryExpression<ELeft,ERight,BinOp::Div>(l,r);
118+
}
119+
120+
// Overload max/min functions
121+
template<typename ELeft, typename ERight>
122+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Max>>
123+
max (const ELeft& l, const ERight& r)
124+
{
125+
return BinaryExpression<ELeft,ERight,BinOp::Max>(l,r);
126+
}
127+
128+
template<typename ELeft, typename ERight>
129+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Min>>
130+
min (const ELeft& l, const ERight& r)
131+
{
132+
return BinaryExpression<ELeft,ERight,BinOp::Min>(l,r);
133+
}
134+
135+
} // namespace ekat
136+
137+
#endif // EKAT_EXPRESSION_BINARY_OP_HPP
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#ifndef EKAT_EXPRESSION_COMPARE_HPP
2+
#define EKAT_EXPRESSION_COMPARE_HPP
3+
4+
#include "ekat_expression_base.hpp"
5+
6+
#include "ekat_std_utils.hpp"
7+
#include "ekat_kernel_assert.hpp"
8+
#include "ekat_assert.hpp"
9+
10+
namespace ekat {
11+
12+
enum class Comparison : int {
13+
EQ, // ==
14+
NE, // !=
15+
GT, // >
16+
GE, // >=
17+
LT, // <
18+
LE // <=
19+
};
20+
21+
template<typename ELeft, typename ERight>
22+
class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> {
23+
public:
24+
using ret_t = int;
25+
26+
static constexpr bool expr_l = is_expr_v<ELeft>;
27+
static constexpr bool expr_r = is_expr_v<ERight>;
28+
29+
static_assert(expr_l or expr_r,
30+
"[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n");
31+
32+
CmpExpression (const ELeft& left, const ERight& right, Comparison CMP)
33+
: m_left(left)
34+
, m_right(right)
35+
, m_cmp(CMP)
36+
{
37+
auto valid = {Comparison::EQ,Comparison::NE,Comparison::GT,
38+
Comparison::GE,Comparison::LT,Comparison::LE};
39+
EKAT_REQUIRE_MSG (ekat::contains(valid,CMP),
40+
"[CmpExpression] Error! Unrecognized/unsupported Comparison value.\n");
41+
}
42+
43+
int num_indices () const {
44+
if constexpr (not expr_l) {
45+
return m_right.num_indices();
46+
} else if constexpr (not expr_r) {
47+
return m_left.num_indices();
48+
} else {
49+
return std::max(m_left.num_indices(),m_right.num_indices());
50+
}
51+
}
52+
53+
template<typename... Args>
54+
KOKKOS_INLINE_FUNCTION
55+
ret_t eval(Args... args) const {
56+
if constexpr (not expr_l) {
57+
switch (m_cmp) {
58+
case Comparison::EQ: return m_left == m_right.eval(args...);
59+
case Comparison::NE: return m_left != m_right.eval(args...);
60+
case Comparison::GT: return m_left > m_right.eval(args...);
61+
case Comparison::GE: return m_left >= m_right.eval(args...);
62+
case Comparison::LT: return m_left < m_right.eval(args...);
63+
case Comparison::LE: return m_left <= m_right.eval(args...);
64+
default:
65+
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
66+
}
67+
} else if constexpr (not expr_r) {
68+
switch (m_cmp) {
69+
case Comparison::EQ: return m_left.eval(args...) == m_right;
70+
case Comparison::NE: return m_left.eval(args...) != m_right;
71+
case Comparison::GT: return m_left.eval(args...) > m_right;
72+
case Comparison::GE: return m_left.eval(args...) >= m_right;
73+
case Comparison::LT: return m_left.eval(args...) < m_right;
74+
case Comparison::LE: return m_left.eval(args...) <= m_right;
75+
default:
76+
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
77+
}
78+
} else {
79+
switch (m_cmp) {
80+
case Comparison::EQ: return m_left.eval(args...) == m_right.eval(args...);
81+
case Comparison::NE: return m_left.eval(args...) != m_right.eval(args...);
82+
case Comparison::GT: return m_left.eval(args...) > m_right.eval(args...);
83+
case Comparison::GE: return m_left.eval(args...) >= m_right.eval(args...);
84+
case Comparison::LT: return m_left.eval(args...) < m_right.eval(args...);
85+
case Comparison::LE: return m_left.eval(args...) <= m_right.eval(args...);
86+
default:
87+
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
88+
}
89+
}
90+
}
91+
92+
static int ret_type () { return 0; }
93+
protected:
94+
95+
ELeft m_left;
96+
ERight m_right;
97+
98+
Comparison m_cmp;
99+
};
100+
101+
template<typename ELeft, typename ERight>
102+
struct is_expr<CmpExpression<ELeft,ERight>> : std::true_type {};
103+
104+
// Overload cmp operators for Expression types
105+
template<typename ELeft, typename ERight>
106+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
107+
operator== (const ELeft& l, const ERight& r)
108+
{
109+
return CmpExpression<ELeft,ERight>(l,r,Comparison::EQ);
110+
}
111+
112+
template<typename ELeft, typename ERight>
113+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
114+
operator!= (const ELeft& l, const ERight& r)
115+
{
116+
return CmpExpression<ELeft,ERight>(l,r,Comparison::NE);
117+
}
118+
119+
template<typename ELeft, typename ERight>
120+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
121+
operator> (const ELeft& l, const ERight& r)
122+
{
123+
return CmpExpression<ELeft,ERight>(l,r,Comparison::GT);
124+
}
125+
126+
template<typename ELeft, typename ERight>
127+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
128+
operator>= (const ELeft& l, const ERight& r)
129+
{
130+
return CmpExpression<ELeft,ERight>(l,r,Comparison::GE);
131+
}
132+
133+
template<typename ELeft, typename ERight>
134+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
135+
operator< (const ELeft& l, const ERight& r)
136+
{
137+
return CmpExpression<ELeft,ERight>(l,r,Comparison::LT);
138+
}
139+
140+
template<typename ELeft, typename ERight>
141+
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
142+
operator<= (const ELeft& l, const ERight& r)
143+
{
144+
return CmpExpression<ELeft,ERight>(l,r,Comparison::LE);
145+
}
146+
147+
} // namespace ekat
148+
149+
#endif // EKAT_EXPRESSION_COMPARE_HPP

0 commit comments

Comments
 (0)