11#ifndef EKAT_EXPRESSION_BINARY_OP_HPP
22#define EKAT_EXPRESSION_BINARY_OP_HPP
33
4- #include " ekat_expression_base.hpp"
4+ #include " ekat_expression_meta.hpp"
5+
6+ #include < Kokkos_Core.hpp>
57
68namespace ekat {
79
810enum class BinOp {
911 Plus,
1012 Minus,
1113 Mult,
12- Div,
13- Max,
14- Min
14+ Div
1515};
1616
1717template <typename ELeft, typename ERight, BinOp OP>
18- class BinaryExpression : public Expression <BinaryExpression<ELeft,ERight,OP>> {
18+ class BinaryExpression {
1919public:
2020 static constexpr bool expr_l = is_expr_v<ELeft>;
2121 static constexpr bool expr_r = is_expr_v<ERight>;
2222
23+ using eval_left_t = eval_return_t <ELeft>;
24+ using eval_right_t = eval_return_t <ERight>;
25+
26+ using eval_t = std::common_type_t <eval_left_t ,eval_right_t >;
27+
28+ // Don't create an expression from builtin types, just combine them!
2329 static_assert (expr_l or expr_r,
24- " [CmpExpression ] At least one between ELeft and ERight must be an Expression type.\n " );
30+ " [BinaryExpression ] At least one between ELeft and ERight must be an Expression type.\n " );
2531
2632 BinaryExpression (const ELeft& left, const ERight& right)
2733 : m_left(left)
@@ -30,19 +36,27 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
3036 // Nothing to do here
3137 }
3238
33- int num_indices () const {
34- if constexpr (not expr_l) {
35- return m_right.num_indices ();
36- } else if constexpr (not expr_r) {
37- return m_left.num_indices ();
39+ static constexpr int rank () {
40+ if constexpr (expr_l) {
41+ if constexpr (expr_r) {
42+ static_assert (ELeft::rank ()==ERight::rank (),
43+ " [BinaryExpression] Error! ELeft and ERight are Expression types of different rank.\n " );
44+ }
45+ return ELeft::rank ();
3846 } else {
39- return std::max (m_left. num_indices (),m_right. num_indices () );
47+ return ERight::rank ( );
4048 }
4149 }
50+ int extent (int i) const {
51+ if constexpr (expr_l)
52+ return m_left.extent (i);
53+ else
54+ return m_right.extent (i);
55+ }
4256
4357 template <typename ... Args>
4458 KOKKOS_INLINE_FUNCTION
45- auto eval (Args... args) const {
59+ eval_t eval (Args... args) const {
4660 if constexpr (not expr_l) {
4761 return eval_impl (m_left,m_right.eval (args...));
4862 } else if constexpr (not expr_r) {
@@ -52,12 +66,10 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
5266 }
5367 }
5468
55- static auto ret_type () { return ELeft::ret_type () + ERight::ret_type (); }
5669protected:
5770
58- template <typename T1, typename T2>
5971 KOKKOS_INLINE_FUNCTION
60- auto eval_impl (const T1 l, const T2 r) const {
72+ eval_t eval_impl (const eval_left_t & l, const eval_right_t & r) const {
6173 if constexpr (OP==BinOp::Plus) {
6274 return l+r;
6375 } else if constexpr (OP==BinOp::Minus) {
@@ -66,71 +78,21 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
6678 return l*r;
6779 } else if constexpr (OP==BinOp::Div) {
6880 return l/r;
69- } else if constexpr (OP==BinOp::Max) {
70- return Kokkos::max (l,r);
71- } else if constexpr (OP==BinOp::Min) {
72- return Kokkos::min (l,r);
81+ return Kokkos::min (static_cast <const eval_t &>(l),static_cast <const eval_t &>(r));
7382 }
7483 }
7584
7685 ELeft m_left;
7786 ERight m_right;
7887};
7988
89+ // Specialize meta utils
8090template <typename ELeft, typename ERight, BinOp OP>
8191struct is_expr <BinaryExpression<ELeft,ERight,OP>> : std::true_type {};
82-
83- // Unary minus implemented as -1*expr
84- template <typename ERight>
85- BinaryExpression<int ,ERight,BinOp::Mult>
86- operator - (const Expression<ERight>& r)
87- {
88- return BinaryExpression<int ,ERight,BinOp::Mult>(-1 ,r.cast ());
89- }
90-
91- // Overload arithmetic operators
92- template <typename ELeft, typename ERight>
93- std::enable_if_t <is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Plus>>
94- operator + (const ELeft& l, const ERight& r)
95- {
96- return BinaryExpression<ELeft,ERight,BinOp::Plus>(l,r);
97- }
98-
99- template <typename ELeft, typename ERight>
100- std::enable_if_t <is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Minus>>
101- operator - (const ELeft& l, const ERight& r)
102- {
103- return BinaryExpression<ELeft,ERight,BinOp::Minus>(l,r);
104- }
105-
106- template <typename ELeft, typename ERight>
107- std::enable_if_t <is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Mult>>
108- operator * (const ELeft& l, const ERight& r)
109- {
110- return BinaryExpression<ELeft,ERight,BinOp::Mult>(l,r);
111- }
112-
113- template <typename ELeft, typename ERight>
114- std::enable_if_t <is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Div>>
115- operator / (const ELeft& l, const ERight& r)
116- {
117- return BinaryExpression<ELeft,ERight,BinOp::Div>(l,r);
118- }
119-
120- // Overload max/min functions
121- template <typename ELeft, typename ERight>
122- std::enable_if_t <is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Max>>
123- max (const ELeft& l, const ERight& r)
124- {
125- return BinaryExpression<ELeft,ERight,BinOp::Max>(l,r);
126- }
127-
128- template <typename ELeft, typename ERight>
129- std::enable_if_t <is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Min>>
130- min (const ELeft& l, const ERight& r)
131- {
132- return BinaryExpression<ELeft,ERight,BinOp::Min>(l,r);
133- }
92+ template <typename ELeft, typename ERight, BinOp OP>
93+ struct eval_return <BinaryExpression<ELeft,ERight,OP>> {
94+ using type = typename BinaryExpression<ELeft,ERight,OP>::eval_t ;
95+ };
13496
13597} // namespace ekat
13698
0 commit comments