Skip to content

Commit af1399b

Browse files
committed
Use variadic templates to support arbitrary number of indices
1 parent b6f5b57 commit af1399b

File tree

7 files changed

+53
-215
lines changed

7 files changed

+53
-215
lines changed

src/expression/ekat_expression_base.hpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,14 @@ class Expression {
1414

1515
int num_indices () const { return cast().num_indices(); }
1616

17+
template<typename... Args>
1718
KOKKOS_INLINE_FUNCTION
18-
Real eval(int i) const {
19-
return cast().eval(i);
20-
}
21-
22-
KOKKOS_INLINE_FUNCTION
23-
Real eval(int i, int j) const {
24-
return cast().eval(i,j);
25-
}
26-
27-
KOKKOS_INLINE_FUNCTION
28-
Real eval(int i, int j, int k) const {
29-
return cast().eval(i,j,k);
19+
Real eval(Args... args) const {
20+
static_assert(std::conjunction_v<std::is_integral<Args>...>,
21+
"[Expression] All arguments must be integral types!");
22+
static_assert(sizeof...(Args) <= 7,
23+
"[Expression] The number of arguments must be between 0 and 7.");
24+
return cast().eval(args...);
3025
}
3126

3227
KOKKOS_INLINE_FUNCTION

src/expression/ekat_expression_binary_op.hpp

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,52 +26,21 @@ class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
2626

2727
int num_indices () const { return std::max(m_left.num_indices(),m_right.num_indices()); }
2828

29+
template<typename... Args>
2930
KOKKOS_INLINE_FUNCTION
30-
Real eval (int i) const {
31+
Real eval(Args... args) const {
3132
if constexpr (OP==BinOp::Plus) {
32-
return m_left.eval(i) + m_right.eval(i);
33+
return m_left.eval(args...) + m_right.eval(args...);
3334
} else if constexpr (OP==BinOp::Minus) {
34-
return m_left.eval(i) - m_right.eval(i);
35+
return m_left.eval(args...) - m_right.eval(args...);
3536
} else if constexpr (OP==BinOp::Mult) {
36-
return m_left.eval(i) * m_right.eval(i);
37+
return m_left.eval(args...) * m_right.eval(args...);
3738
} else if constexpr (OP==BinOp::Div) {
38-
return m_left.eval(i) / m_right.eval(i);
39+
return m_left.eval(args...) / m_right.eval(args...);
3940
} else if constexpr (OP==BinOp::Max) {
40-
return Kokkos::max(m_left.eval(i),m_right.eval(i));
41+
return Kokkos::max(m_left.eval(args...),m_right.eval(args...));
4142
} else if constexpr (OP==BinOp::Div) {
42-
return Kokkos::min(m_left.eval(i),m_right.eval(i));
43-
}
44-
}
45-
KOKKOS_INLINE_FUNCTION
46-
Real eval (int i, int j) const {
47-
if constexpr (OP==BinOp::Plus) {
48-
return m_left.eval(i,j) + m_right.eval(i,j);
49-
} else if constexpr (OP==BinOp::Minus) {
50-
return m_left.eval(i,j) - m_right.eval(i,j);
51-
} else if constexpr (OP==BinOp::Mult) {
52-
return m_left.eval(i,j) * m_right.eval(i,j);
53-
} else if constexpr (OP==BinOp::Div) {
54-
return m_left.eval(i,j) / m_right.eval(i,j);
55-
} else if constexpr (OP==BinOp::Max) {
56-
return Kokkos::max(m_left.eval(i,j),m_right.eval(i,j));
57-
} else if constexpr (OP==BinOp::Div) {
58-
return Kokkos::min(m_left.eval(i,j),m_right.eval(i,j));
59-
}
60-
}
61-
KOKKOS_INLINE_FUNCTION
62-
Real eval (int i, int j, int k) const {
63-
if constexpr (OP==BinOp::Plus) {
64-
return m_left.eval(i,j,k) + m_right.eval(i,j,k);
65-
} else if constexpr (OP==BinOp::Minus) {
66-
return m_left.eval(i,j,k) - m_right.eval(i,j,k);
67-
} else if constexpr (OP==BinOp::Mult) {
68-
return m_left.eval(i,j,k) * m_right.eval(i,j,k);
69-
} else if constexpr (OP==BinOp::Div) {
70-
return m_left.eval(i,j,k) / m_right.eval(i,j,k);
71-
} else if constexpr (OP==BinOp::Max) {
72-
return Kokkos::max(m_left.eval(i,j,k),m_right.eval(i,j,k));
73-
} else if constexpr (OP==BinOp::Div) {
74-
return Kokkos::min(m_left.eval(i,j,k),m_right.eval(i,j,k));
43+
return Kokkos::min(m_left.eval(args...),m_right.eval(args...));
7544
}
7645
}
7746
protected:

src/expression/ekat_expression_compare.hpp

Lines changed: 20 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -47,112 +47,39 @@ class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> {
4747
}
4848
}
4949

50+
template<typename... Args>
5051
KOKKOS_INLINE_FUNCTION
51-
Real eval (int i) const {
52+
Real eval(Args... args) const {
5253
if constexpr (scalar_left) {
5354
switch (m_cmp) {
54-
case Comparison::EQ: return m_left == m_right.eval(i);
55-
case Comparison::NE: return m_left != m_right.eval(i);
56-
case Comparison::GT: return m_left > m_right.eval(i);
57-
case Comparison::GE: return m_left >= m_right.eval(i);
58-
case Comparison::LT: return m_left < m_right.eval(i);
59-
case Comparison::LE: return m_left <= m_right.eval(i);
55+
case Comparison::EQ: return m_left == m_right.eval(args...);
56+
case Comparison::NE: return m_left != m_right.eval(args...);
57+
case Comparison::GT: return m_left > m_right.eval(args...);
58+
case Comparison::GE: return m_left >= m_right.eval(args...);
59+
case Comparison::LT: return m_left < m_right.eval(args...);
60+
case Comparison::LE: return m_left <= m_right.eval(args...);
6061
default:
6162
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
6263
}
6364
} else if constexpr (scalar_right) {
6465
switch (m_cmp) {
65-
case Comparison::EQ: return m_left.eval(i) == m_right;
66-
case Comparison::NE: return m_left.eval(i) != m_right;
67-
case Comparison::GT: return m_left.eval(i) > m_right;
68-
case Comparison::GE: return m_left.eval(i) >= m_right;
69-
case Comparison::LT: return m_left.eval(i) < m_right;
70-
case Comparison::LE: return m_left.eval(i) <= m_right;
66+
case Comparison::EQ: return m_left.eval(args...) == m_right;
67+
case Comparison::NE: return m_left.eval(args...) != m_right;
68+
case Comparison::GT: return m_left.eval(args...) > m_right;
69+
case Comparison::GE: return m_left.eval(args...) >= m_right;
70+
case Comparison::LT: return m_left.eval(args...) < m_right;
71+
case Comparison::LE: return m_left.eval(args...) <= m_right;
7172
default:
7273
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
7374
}
7475
} else {
7576
switch (m_cmp) {
76-
case Comparison::EQ: return m_left.eval(i) == m_right.eval(i);
77-
case Comparison::NE: return m_left.eval(i) != m_right.eval(i);
78-
case Comparison::GT: return m_left.eval(i) > m_right.eval(i);
79-
case Comparison::GE: return m_left.eval(i) >= m_right.eval(i);
80-
case Comparison::LT: return m_left.eval(i) < m_right.eval(i);
81-
case Comparison::LE: return m_left.eval(i) <= m_right.eval(i);
82-
default:
83-
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
84-
}
85-
}
86-
}
87-
KOKKOS_INLINE_FUNCTION
88-
Real eval (int i,int j) const {
89-
if constexpr (scalar_left) {
90-
switch (m_cmp) {
91-
case Comparison::EQ: return m_left == m_right.eval(i,j);
92-
case Comparison::NE: return m_left != m_right.eval(i,j);
93-
case Comparison::GT: return m_left > m_right.eval(i,j);
94-
case Comparison::GE: return m_left >= m_right.eval(i,j);
95-
case Comparison::LT: return m_left < m_right.eval(i,j);
96-
case Comparison::LE: return m_left <= m_right.eval(i,j);
97-
default:
98-
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
99-
}
100-
} else if constexpr (scalar_right) {
101-
switch (m_cmp) {
102-
case Comparison::EQ: return m_left.eval(i,j) == m_right;
103-
case Comparison::NE: return m_left.eval(i,j) != m_right;
104-
case Comparison::GT: return m_left.eval(i,j) > m_right;
105-
case Comparison::GE: return m_left.eval(i,j) >= m_right;
106-
case Comparison::LT: return m_left.eval(i,j) < m_right;
107-
case Comparison::LE: return m_left.eval(i,j) <= m_right;
108-
default:
109-
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
110-
}
111-
} else {
112-
switch (m_cmp) {
113-
case Comparison::EQ: return m_left.eval(i,j) == m_right.eval(i,j);
114-
case Comparison::NE: return m_left.eval(i,j) != m_right.eval(i,j);
115-
case Comparison::GT: return m_left.eval(i,j) > m_right.eval(i,j);
116-
case Comparison::GE: return m_left.eval(i,j) >= m_right.eval(i,j);
117-
case Comparison::LT: return m_left.eval(i,j) < m_right.eval(i,j);
118-
case Comparison::LE: return m_left.eval(i,j) <= m_right.eval(i,j);
119-
default:
120-
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
121-
}
122-
}
123-
}
124-
KOKKOS_INLINE_FUNCTION
125-
Real eval (int i, int j, int k) const {
126-
if constexpr (scalar_left) {
127-
switch (m_cmp) {
128-
case Comparison::EQ: return m_left == m_right.eval(i,j,k);
129-
case Comparison::NE: return m_left != m_right.eval(i,j,k);
130-
case Comparison::GT: return m_left > m_right.eval(i,j,k);
131-
case Comparison::GE: return m_left >= m_right.eval(i,j,k);
132-
case Comparison::LT: return m_left < m_right.eval(i,j,k);
133-
case Comparison::LE: return m_left <= m_right.eval(i,j,k);
134-
default:
135-
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
136-
}
137-
} else if constexpr (scalar_right) {
138-
switch (m_cmp) {
139-
case Comparison::EQ: return m_left.eval(i,j,k) == m_right;
140-
case Comparison::NE: return m_left.eval(i,j,k) != m_right;
141-
case Comparison::GT: return m_left.eval(i,j,k) > m_right;
142-
case Comparison::GE: return m_left.eval(i,j,k) >= m_right;
143-
case Comparison::LT: return m_left.eval(i,j,k) < m_right;
144-
case Comparison::LE: return m_left.eval(i,j,k) <= m_right;
145-
default:
146-
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
147-
}
148-
} else {
149-
switch (m_cmp) {
150-
case Comparison::EQ: return m_left.eval(i,j,k) == m_right.eval(i,j,k);
151-
case Comparison::NE: return m_left.eval(i,j,k) != m_right.eval(i,j,k);
152-
case Comparison::GT: return m_left.eval(i,j,k) > m_right.eval(i,j,k);
153-
case Comparison::GE: return m_left.eval(i,j,k) >= m_right.eval(i,j,k);
154-
case Comparison::LT: return m_left.eval(i,j,k) < m_right.eval(i,j,k);
155-
case Comparison::LE: return m_left.eval(i,j,k) <= m_right.eval(i,j,k);
77+
case Comparison::EQ: return m_left.eval(args...) == m_right.eval(args...);
78+
case Comparison::NE: return m_left.eval(args...) != m_right.eval(args...);
79+
case Comparison::GT: return m_left.eval(args...) > m_right.eval(args...);
80+
case Comparison::GE: return m_left.eval(args...) >= m_right.eval(args...);
81+
case Comparison::LT: return m_left.eval(args...) < m_right.eval(args...);
82+
case Comparison::LE: return m_left.eval(args...) <= m_right.eval(args...);
15683
default:
15784
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
15885
}

src/expression/ekat_expression_conditional.hpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,13 @@ class ConditionalExpression : public Expression<ConditionalExpression<ECond,ELef
2020
return std::max(m_cmp.num_indices(),std::max(m_left.num_indices(),m_right.num_indices()));
2121
}
2222

23+
template<typename... Args>
2324
KOKKOS_INLINE_FUNCTION
24-
Real eval (int i) const {
25-
if (m_cmp.eval(i))
26-
return m_left.eval(i);
25+
Real eval(Args... args) const {
26+
if (m_cmp.eval(args...))
27+
return m_left.eval(args...);
2728
else
28-
return m_right.eval(i);
29-
}
30-
KOKKOS_INLINE_FUNCTION
31-
Real eval (int i, int j) const {
32-
if (m_cmp.eval(i,j))
33-
return m_left.eval(i,j);
34-
else
35-
return m_right.eval(i,j);
36-
}
37-
KOKKOS_INLINE_FUNCTION
38-
Real eval (int i, int j, int k) const {
39-
if (m_cmp.eval(i,j,k))
40-
return m_left.eval(i,j,k);
41-
else
42-
return m_right.eval(i,j,k);
29+
return m_right.eval(args...);
4330
}
4431
protected:
4532

src/expression/ekat_expression_math.hpp

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,34 +32,15 @@ class PowExpression : public Expression<PowExpression<EBase,EExp>> {
3232
}
3333
}
3434

35+
template<typename... Args>
3536
KOKKOS_INLINE_FUNCTION
36-
Real eval (int i) const {
37+
Real eval(Args... args) const {
3738
if constexpr (scalar_base) {
38-
return Kokkos::pow(m_base,m_exp.eval(i));
39+
return Kokkos::pow(m_base,m_exp.eval(args...));
3940
} else if constexpr (scalar_exp) {
40-
return Kokkos::pow(m_base.eval(i),m_exp);
41+
return Kokkos::pow(m_base.eval(args...),m_exp);
4142
} else {
42-
return Kokkos::pow(m_base.eval(i),m_exp.eval(i));
43-
}
44-
}
45-
KOKKOS_INLINE_FUNCTION
46-
Real eval (int i, int j) const {
47-
if constexpr (scalar_base) {
48-
return Kokkos::pow(m_base,m_exp.eval(i,j));
49-
} else if constexpr (scalar_exp) {
50-
return Kokkos::pow(m_base.eval(i,j),m_exp);
51-
} else {
52-
return Kokkos::pow(m_base.eval(i,j),m_exp.eval(i,j));
53-
}
54-
}
55-
KOKKOS_INLINE_FUNCTION
56-
Real eval (int i, int j, int k) const {
57-
if constexpr (scalar_base) {
58-
return Kokkos::pow(m_base,m_exp.eval(i,j,k));
59-
} else if constexpr (scalar_exp) {
60-
return Kokkos::pow(m_base.eval(i,j,k),m_exp);
61-
} else {
62-
return Kokkos::pow(m_base.eval(i,j,k),m_exp.eval(i,j,k));
43+
return Kokkos::pow(m_base.eval(args...),m_exp.eval(args...));
6344
}
6445
}
6546
protected:
@@ -89,17 +70,10 @@ pow (const Expression<EBase>& b, const Expression<EExp>& e)
8970
\
9071
int num_indices () const { return m_arg.num_indices(); } \
9172
\
73+
template<typename... Args> \
9274
KOKKOS_INLINE_FUNCTION \
93-
Real eval (int i) const { \
94-
return Kokkos::impl(m_arg.eval(i)); \
95-
} \
96-
KOKKOS_INLINE_FUNCTION \
97-
Real eval (int i,int j) const { \
98-
return Kokkos::impl(m_arg.eval(i,j)); \
99-
} \
100-
KOKKOS_INLINE_FUNCTION \
101-
Real eval (int i,int j,int k) const { \
102-
return Kokkos::impl(m_arg.eval(i,j,k)); \
75+
Real eval(Args... args) const { \
76+
return Kokkos::impl(m_arg.eval(args...)); \
10377
} \
10478
protected: \
10579
\

src/expression/ekat_expression_scalar.hpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,9 @@ class ScalarExpression : public Expression<ScalarExpression<ST>> {
1616

1717
int num_indices () const { return 0; }
1818

19+
template<typename... Args>
1920
KOKKOS_INLINE_FUNCTION
20-
Real eval(int) const {
21-
return m_value;
22-
}
23-
KOKKOS_INLINE_FUNCTION
24-
Real eval(int,int) const {
25-
return m_value;
26-
}
27-
KOKKOS_INLINE_FUNCTION
28-
Real eval(int,int,int) const {
21+
Real eval(Args...) const {
2922
return m_value;
3023
}
3124

src/expression/ekat_expression_view.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,10 @@ class ViewExpression : public Expression<ViewExpression<ViewT>> {
2121

2222
int num_indices () const { return ViewT::rank; }
2323

24+
template<typename... Args>
2425
KOKKOS_INLINE_FUNCTION
25-
Real eval(int i) const {
26-
return m_view(i);
27-
}
28-
KOKKOS_INLINE_FUNCTION
29-
Real eval(int i,int j) const {
30-
return m_view(i,j);
31-
}
32-
KOKKOS_INLINE_FUNCTION
33-
Real eval(int i,int j,int k) const {
34-
return m_view(i,j,k);
26+
Real eval(Args... args) const {
27+
return m_view(args...);
3528
}
3629

3730
protected:

0 commit comments

Comments
 (0)