Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ endif()

option(EKAT_ENABLE_ALGORITHM "Enable EKAT algorithm utilities" ${PKG_DEFAULT})
option(EKAT_ENABLE_KOKKOS "Enable EKAT kokkos utilities" ${PKG_DEFAULT})
option(EKAT_ENABLE_EXPRESSION "Enable EKAT expression templates" ${PKG_DEFAULT})
option(EKAT_ENABLE_LOGGING "Enable EKAT spdlog utilities" ${PKG_DEFAULT})
option(EKAT_ENABLE_PACK "Enable EKAT packs utilities" ${PKG_DEFAULT})
option(EKAT_ENABLE_YAML_PARSER "Enable EKAT yaml parsing utilities" ${PKG_DEFAULT})
Expand Down Expand Up @@ -45,6 +46,9 @@ endif()
if (EKAT_ENABLE_KOKKOS)
add_subdirectory(kokkos)
endif()
if (EKAT_ENABLE_EXPRESSION)
add_subdirectory(expression)
endif()
if (EKAT_ENABLE_LOGGING)
add_subdirectory(logging)
endif()
Expand Down
42 changes: 42 additions & 0 deletions src/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
add_library(ekat_expression INTERFACE)

target_link_libraries (ekat_expression INTERFACE
ekat::KokkosUtils)

target_include_directories(ekat_expression INTERFACE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ekat>)

# Set the PUBLIC_HEADER property
set (HEADERS
ekat_expression_base.hpp
ekat_expression_binary_op.hpp
ekat_expression_compare.hpp
ekat_expression_conditional.hpp
ekat_expression_eval.hpp
ekat_expression_math.hpp
ekat_expression_view.hpp
)
set_target_properties(ekat_expression PROPERTIES PUBLIC_HEADER "${HEADERS}")

# Set the name to be used when exportin the target
# This, together with the NAMESPACE property set in the main CMakeLists.txt
# install call, will force user to link ekat::Expression
set_target_properties(ekat_expression PROPERTIES
EXPORT_NAME Expression
PUBLIC_HEADER "${HEADERS}")

# Install the package
install (TARGETS ekat_expression
EXPORT EkatTargets
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ekat)

# Regardless of whether we use INSTALLED ekat, or BUILT (via add_subdirectory),
# we want to be able to access ekat's targets via ekat::TARGET
add_library(ekat::Expression ALIAS ekat_expression)

# Link to the all libs target
target_link_libraries(ekat_all_libs INTERFACE ekat_expression)
40 changes: 40 additions & 0 deletions src/expression/ekat_expression_base.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef EKAT_EXPRESSION_HPP
#define EKAT_EXPRESSION_HPP

#include <Kokkos_Core.hpp>

namespace ekat {

template<typename Derived>
class Expression {
public:

int num_indices () const { return cast().num_indices(); }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be turned into a static fcn, as the number of indices used by an expression is fully determined by its type.

So far, this fcn is only used in the evaluate call, to make sure that the rank of the view passed to the evaluation fcn is compatible with the expression to evaluate. I may revisit this pattern in a follow-up pr...


template<typename... Args>
KOKKOS_INLINE_FUNCTION
auto eval(Args... args) const {
static_assert(std::conjunction_v<std::is_integral<Args>...>,
"[Expression] All arguments must be integral types!");
static_assert(sizeof...(Args) <= 7,
"[Expression] The number of arguments must be between 0 and 7.");
return cast().eval(args...);
}

KOKKOS_INLINE_FUNCTION
const Derived& cast () const { return static_cast<const Derived&>(*this); }
};

// Some meta-utilities that will prove useful in derived classes

// Detect if a type is an Expression
template<typename T>
struct is_expr : std::false_type {};
template<typename D>
struct is_expr<Expression<D>> : std::true_type {};
template<typename T>
constexpr bool is_expr_v = is_expr<T>::value;

} // namespace ekat

#endif // EKAT_EXPRESSION_HPP
137 changes: 137 additions & 0 deletions src/expression/ekat_expression_binary_op.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#ifndef EKAT_EXPRESSION_BINARY_OP_HPP
#define EKAT_EXPRESSION_BINARY_OP_HPP

#include "ekat_expression_base.hpp"

namespace ekat {

enum class BinOp {
Plus,
Minus,
Mult,
Div,
Max,
Min
};

template<typename ELeft, typename ERight, BinOp OP>
class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{
public:
static constexpr bool expr_l = is_expr_v<ELeft>;
static constexpr bool expr_r = is_expr_v<ERight>;

static_assert (expr_l or expr_r,
"[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n");

BinaryExpression (const ELeft& left, const ERight& right)
: m_left(left)
, m_right(right)
{
// Nothing to do here
}

int num_indices () const {
if constexpr (not expr_l) {
return m_right.num_indices();
} else if constexpr (not expr_r) {
return m_left.num_indices();
} else {
return std::max(m_left.num_indices(),m_right.num_indices());
}
}

template<typename... Args>
KOKKOS_INLINE_FUNCTION
auto eval(Args... args) const {
if constexpr (not expr_l) {
return eval_impl(m_left,m_right.eval(args...));
} else if constexpr (not expr_r) {
return eval_impl(m_left.eval(args...),m_right);
} else {
return eval_impl(m_left.eval(args...),m_right.eval(args...));
}
}

static auto ret_type () { return ELeft::ret_type() + ERight::ret_type(); }
protected:

template<typename T1, typename T2>
KOKKOS_INLINE_FUNCTION
auto eval_impl(const T1 l, const T2 r) const {
if constexpr (OP==BinOp::Plus) {
return l+r;
} else if constexpr (OP==BinOp::Minus) {
return l-r;
} else if constexpr (OP==BinOp::Mult) {
return l*r;
} else if constexpr (OP==BinOp::Div) {
return l/r;
} else if constexpr (OP==BinOp::Max) {
return Kokkos::max(l,r);
} else if constexpr (OP==BinOp::Min) {
return Kokkos::min(l,r);
}
}

ELeft m_left;
ERight m_right;
};

template<typename ELeft, typename ERight, BinOp OP>
struct is_expr<BinaryExpression<ELeft,ERight,OP>> : std::true_type {};

// Unary minus implemented as -1*expr
template<typename ERight>
BinaryExpression<int,ERight,BinOp::Mult>
operator- (const Expression<ERight>& r)
{
return BinaryExpression<int,ERight,BinOp::Mult>(-1,r.cast());
}

// Overload arithmetic operators
template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Plus>>
operator+ (const ELeft& l, const ERight& r)
{
return BinaryExpression<ELeft,ERight,BinOp::Plus>(l,r);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Minus>>
operator- (const ELeft& l, const ERight& r)
{
return BinaryExpression<ELeft,ERight,BinOp::Minus>(l,r);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Mult>>
operator* (const ELeft& l, const ERight& r)
{
return BinaryExpression<ELeft,ERight,BinOp::Mult>(l,r);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Div>>
operator/ (const ELeft& l, const ERight& r)
{
return BinaryExpression<ELeft,ERight,BinOp::Div>(l,r);
}

// Overload max/min functions
template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Max>>
max (const ELeft& l, const ERight& r)
{
return BinaryExpression<ELeft,ERight,BinOp::Max>(l,r);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Min>>
min (const ELeft& l, const ERight& r)
{
return BinaryExpression<ELeft,ERight,BinOp::Min>(l,r);
}

} // namespace ekat

#endif // EKAT_EXPRESSION_BINARY_OP_HPP
149 changes: 149 additions & 0 deletions src/expression/ekat_expression_compare.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#ifndef EKAT_EXPRESSION_COMPARE_HPP
#define EKAT_EXPRESSION_COMPARE_HPP

#include "ekat_expression_base.hpp"

#include "ekat_std_utils.hpp"
#include "ekat_kernel_assert.hpp"
#include "ekat_assert.hpp"

namespace ekat {

enum class Comparison : int {
EQ, // ==
NE, // !=
GT, // >
GE, // >=
LT, // <
LE // <=
};

template<typename ELeft, typename ERight>
class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> {
public:
using ret_t = int;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used int for some reason, but it could probably be changed into bool. I already have a follow up branch where I extend the expressions framework to work with Packs, so it can be changed there.


static constexpr bool expr_l = is_expr_v<ELeft>;
static constexpr bool expr_r = is_expr_v<ERight>;

static_assert(expr_l or expr_r,
"[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n");

CmpExpression (const ELeft& left, const ERight& right, Comparison CMP)
: m_left(left)
, m_right(right)
, m_cmp(CMP)
{
auto valid = {Comparison::EQ,Comparison::NE,Comparison::GT,
Comparison::GE,Comparison::LT,Comparison::LE};
EKAT_REQUIRE_MSG (ekat::contains(valid,CMP),
"[CmpExpression] Error! Unrecognized/unsupported Comparison value.\n");
}

int num_indices () const {
if constexpr (not expr_l) {
return m_right.num_indices();
} else if constexpr (not expr_r) {
return m_left.num_indices();
} else {
return std::max(m_left.num_indices(),m_right.num_indices());
}
}

template<typename... Args>
KOKKOS_INLINE_FUNCTION
ret_t eval(Args... args) const {
if constexpr (not expr_l) {
switch (m_cmp) {
case Comparison::EQ: return m_left == m_right.eval(args...);
case Comparison::NE: return m_left != m_right.eval(args...);
case Comparison::GT: return m_left > m_right.eval(args...);
case Comparison::GE: return m_left >= m_right.eval(args...);
case Comparison::LT: return m_left < m_right.eval(args...);
case Comparison::LE: return m_left <= m_right.eval(args...);
default:
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
}
} else if constexpr (not expr_r) {
switch (m_cmp) {
case Comparison::EQ: return m_left.eval(args...) == m_right;
case Comparison::NE: return m_left.eval(args...) != m_right;
case Comparison::GT: return m_left.eval(args...) > m_right;
case Comparison::GE: return m_left.eval(args...) >= m_right;
case Comparison::LT: return m_left.eval(args...) < m_right;
case Comparison::LE: return m_left.eval(args...) <= m_right;
default:
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
}
} else {
switch (m_cmp) {
case Comparison::EQ: return m_left.eval(args...) == m_right.eval(args...);
case Comparison::NE: return m_left.eval(args...) != m_right.eval(args...);
case Comparison::GT: return m_left.eval(args...) > m_right.eval(args...);
case Comparison::GE: return m_left.eval(args...) >= m_right.eval(args...);
case Comparison::LT: return m_left.eval(args...) < m_right.eval(args...);
case Comparison::LE: return m_left.eval(args...) <= m_right.eval(args...);
default:
EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n");
}
}
}

static int ret_type () { return 0; }
protected:

ELeft m_left;
ERight m_right;

Comparison m_cmp;
};

template<typename ELeft, typename ERight>
struct is_expr<CmpExpression<ELeft,ERight>> : std::true_type {};

// Overload cmp operators for Expression types
template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
operator== (const ELeft& l, const ERight& r)
{
return CmpExpression<ELeft,ERight>(l,r,Comparison::EQ);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
operator!= (const ELeft& l, const ERight& r)
{
return CmpExpression<ELeft,ERight>(l,r,Comparison::NE);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
operator> (const ELeft& l, const ERight& r)
{
return CmpExpression<ELeft,ERight>(l,r,Comparison::GT);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
operator>= (const ELeft& l, const ERight& r)
{
return CmpExpression<ELeft,ERight>(l,r,Comparison::GE);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
operator< (const ELeft& l, const ERight& r)
{
return CmpExpression<ELeft,ERight>(l,r,Comparison::LT);
}

template<typename ELeft, typename ERight>
std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>>
operator<= (const ELeft& l, const ERight& r)
{
return CmpExpression<ELeft,ERight>(l,r,Comparison::LE);
}

} // namespace ekat

#endif // EKAT_EXPRESSION_COMPARE_HPP
Loading