|
| 1 | +#ifndef EKAT_EXPRESSION_COMPARE_HPP |
| 2 | +#define EKAT_EXPRESSION_COMPARE_HPP |
| 3 | + |
| 4 | +#include "ekat_expression_base.hpp" |
| 5 | + |
| 6 | +#include "ekat_std_utils.hpp" |
| 7 | +#include "ekat_kernel_assert.hpp" |
| 8 | +#include "ekat_assert.hpp" |
| 9 | + |
| 10 | +namespace ekat { |
| 11 | + |
| 12 | +enum class Comparison : int { |
| 13 | + EQ, // == |
| 14 | + NE, // != |
| 15 | + GT, // > |
| 16 | + GE, // >= |
| 17 | + LT, // < |
| 18 | + LE // <= |
| 19 | +}; |
| 20 | + |
| 21 | +template<typename ELeft, typename ERight> |
| 22 | +class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> { |
| 23 | +public: |
| 24 | + using ret_t = int; |
| 25 | + |
| 26 | + static constexpr bool expr_l = is_expr_v<ELeft>; |
| 27 | + static constexpr bool expr_r = is_expr_v<ERight>; |
| 28 | + |
| 29 | + static_assert(expr_l or expr_r, |
| 30 | + "[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n"); |
| 31 | + |
| 32 | + CmpExpression (const ELeft& left, const ERight& right, Comparison CMP) |
| 33 | + : m_left(left) |
| 34 | + , m_right(right) |
| 35 | + , m_cmp(CMP) |
| 36 | + { |
| 37 | + auto valid = {Comparison::EQ,Comparison::NE,Comparison::GT, |
| 38 | + Comparison::GE,Comparison::LT,Comparison::LE}; |
| 39 | + EKAT_REQUIRE_MSG (ekat::contains(valid,CMP), |
| 40 | + "[CmpExpression] Error! Unrecognized/unsupported Comparison value.\n"); |
| 41 | + } |
| 42 | + |
| 43 | + int num_indices () const { |
| 44 | + if constexpr (not expr_l) { |
| 45 | + return m_right.num_indices(); |
| 46 | + } else if constexpr (not expr_r) { |
| 47 | + return m_left.num_indices(); |
| 48 | + } else { |
| 49 | + return std::max(m_left.num_indices(),m_right.num_indices()); |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + template<typename... Args> |
| 54 | + KOKKOS_INLINE_FUNCTION |
| 55 | + ret_t eval(Args... args) const { |
| 56 | + if constexpr (not expr_l) { |
| 57 | + switch (m_cmp) { |
| 58 | + case Comparison::EQ: return m_left == m_right.eval(args...); |
| 59 | + case Comparison::NE: return m_left != m_right.eval(args...); |
| 60 | + case Comparison::GT: return m_left > m_right.eval(args...); |
| 61 | + case Comparison::GE: return m_left >= m_right.eval(args...); |
| 62 | + case Comparison::LT: return m_left < m_right.eval(args...); |
| 63 | + case Comparison::LE: return m_left <= m_right.eval(args...); |
| 64 | + default: |
| 65 | + EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n"); |
| 66 | + } |
| 67 | + } else if constexpr (not expr_r) { |
| 68 | + switch (m_cmp) { |
| 69 | + case Comparison::EQ: return m_left.eval(args...) == m_right; |
| 70 | + case Comparison::NE: return m_left.eval(args...) != m_right; |
| 71 | + case Comparison::GT: return m_left.eval(args...) > m_right; |
| 72 | + case Comparison::GE: return m_left.eval(args...) >= m_right; |
| 73 | + case Comparison::LT: return m_left.eval(args...) < m_right; |
| 74 | + case Comparison::LE: return m_left.eval(args...) <= m_right; |
| 75 | + default: |
| 76 | + EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n"); |
| 77 | + } |
| 78 | + } else { |
| 79 | + switch (m_cmp) { |
| 80 | + case Comparison::EQ: return m_left.eval(args...) == m_right.eval(args...); |
| 81 | + case Comparison::NE: return m_left.eval(args...) != m_right.eval(args...); |
| 82 | + case Comparison::GT: return m_left.eval(args...) > m_right.eval(args...); |
| 83 | + case Comparison::GE: return m_left.eval(args...) >= m_right.eval(args...); |
| 84 | + case Comparison::LT: return m_left.eval(args...) < m_right.eval(args...); |
| 85 | + case Comparison::LE: return m_left.eval(args...) <= m_right.eval(args...); |
| 86 | + default: |
| 87 | + EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n"); |
| 88 | + } |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + static int ret_type () { return 0; } |
| 93 | +protected: |
| 94 | + |
| 95 | + ELeft m_left; |
| 96 | + ERight m_right; |
| 97 | + |
| 98 | + Comparison m_cmp; |
| 99 | +}; |
| 100 | + |
| 101 | +template<typename ELeft, typename ERight> |
| 102 | +struct is_expr<CmpExpression<ELeft,ERight>> : std::true_type {}; |
| 103 | + |
| 104 | +// Overload cmp operators for Expression types |
| 105 | +template<typename ELeft, typename ERight> |
| 106 | +std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> |
| 107 | +operator== (const ELeft& l, const ERight& r) |
| 108 | +{ |
| 109 | + return CmpExpression<ELeft,ERight>(l,r,Comparison::EQ); |
| 110 | +} |
| 111 | + |
| 112 | +template<typename ELeft, typename ERight> |
| 113 | +std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> |
| 114 | +operator!= (const ELeft& l, const ERight& r) |
| 115 | +{ |
| 116 | + return CmpExpression<ELeft,ERight>(l,r,Comparison::NE); |
| 117 | +} |
| 118 | + |
| 119 | +template<typename ELeft, typename ERight> |
| 120 | +std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> |
| 121 | +operator> (const ELeft& l, const ERight& r) |
| 122 | +{ |
| 123 | + return CmpExpression<ELeft,ERight>(l,r,Comparison::GT); |
| 124 | +} |
| 125 | + |
| 126 | +template<typename ELeft, typename ERight> |
| 127 | +std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> |
| 128 | +operator>= (const ELeft& l, const ERight& r) |
| 129 | +{ |
| 130 | + return CmpExpression<ELeft,ERight>(l,r,Comparison::GE); |
| 131 | +} |
| 132 | + |
| 133 | +template<typename ELeft, typename ERight> |
| 134 | +std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> |
| 135 | +operator< (const ELeft& l, const ERight& r) |
| 136 | +{ |
| 137 | + return CmpExpression<ELeft,ERight>(l,r,Comparison::LT); |
| 138 | +} |
| 139 | + |
| 140 | +template<typename ELeft, typename ERight> |
| 141 | +std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> |
| 142 | +operator<= (const ELeft& l, const ERight& r) |
| 143 | +{ |
| 144 | + return CmpExpression<ELeft,ERight>(l,r,Comparison::LE); |
| 145 | +} |
| 146 | + |
| 147 | +} // namespace ekat |
| 148 | + |
| 149 | +#endif // EKAT_EXPRESSION_COMPARE_HPP |
0 commit comments