-
Notifications
You must be signed in to change notification settings - Fork 10
Framework for expression-templates evaluation based on views #395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
e8558f2
Start working on expression templates package
bartgol 561bf8f
Add unit tests for expression package
bartgol b6f5b57
Perform tolerance-based checks in expressions tests
bartgol 8e9c3dc
Use variadic templates to support arbitrary number of indices
bartgol 91f3975
Remove hard-coded return type of expression eval method
bartgol 0b499e3
Remove ScalarExpression and support scalars directly in BinaryExpression
bartgol 2c9baad
Remove operator overloads for views
bartgol 3774451
Fix small bug in BinaryExpression and add overloads
bartgol 5c8c945
Add cmp operator overloads when one of the terms is a scalar
bartgol 68a2c4e
Add pow overloads when one of the terms is a scalar
bartgol 929598e
Change check on expression template arg
bartgol 4843cbf
Simplify operator overloads using is_expr_v and enable_if_t
bartgol 0ae6b45
Allow ConditionalExpression with non-expr condition or left or right …
bartgol ba6f4b0
Add specializations of is_expr meta-util
bartgol 37d63f9
Fix includes in unit test
bartgol File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| add_library(ekat_expression INTERFACE) | ||
|
|
||
| target_link_libraries (ekat_expression INTERFACE | ||
| ekat::KokkosUtils) | ||
|
|
||
| target_include_directories(ekat_expression INTERFACE | ||
| $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}> | ||
| $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ekat>) | ||
|
|
||
| # Set the PUBLIC_HEADER property | ||
| set (HEADERS | ||
| ekat_expression_base.hpp | ||
| ekat_expression_binary_op.hpp | ||
| ekat_expression_compare.hpp | ||
| ekat_expression_conditional.hpp | ||
| ekat_expression_eval.hpp | ||
| ekat_expression_math.hpp | ||
| ekat_expression_view.hpp | ||
| ) | ||
| set_target_properties(ekat_expression PROPERTIES PUBLIC_HEADER "${HEADERS}") | ||
|
|
||
| # Set the name to be used when exportin the target | ||
| # This, together with the NAMESPACE property set in the main CMakeLists.txt | ||
| # install call, will force user to link ekat::Expression | ||
| set_target_properties(ekat_expression PROPERTIES | ||
| EXPORT_NAME Expression | ||
| PUBLIC_HEADER "${HEADERS}") | ||
|
|
||
| # Install the package | ||
| install (TARGETS ekat_expression | ||
| EXPORT EkatTargets | ||
| RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} | ||
| LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} | ||
| ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} | ||
| PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ekat) | ||
|
|
||
| # Regardless of whether we use INSTALLED ekat, or BUILT (via add_subdirectory), | ||
| # we want to be able to access ekat's targets via ekat::TARGET | ||
| add_library(ekat::Expression ALIAS ekat_expression) | ||
|
|
||
| # Link to the all libs target | ||
| target_link_libraries(ekat_all_libs INTERFACE ekat_expression) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| #ifndef EKAT_EXPRESSION_HPP | ||
| #define EKAT_EXPRESSION_HPP | ||
|
|
||
| #include <Kokkos_Core.hpp> | ||
|
|
||
| namespace ekat { | ||
|
|
||
| template<typename Derived> | ||
| class Expression { | ||
| public: | ||
|
|
||
| int num_indices () const { return cast().num_indices(); } | ||
|
|
||
| template<typename... Args> | ||
| KOKKOS_INLINE_FUNCTION | ||
| auto eval(Args... args) const { | ||
| static_assert(std::conjunction_v<std::is_integral<Args>...>, | ||
| "[Expression] All arguments must be integral types!"); | ||
| static_assert(sizeof...(Args) <= 7, | ||
| "[Expression] The number of arguments must be between 0 and 7."); | ||
| return cast().eval(args...); | ||
| } | ||
|
|
||
| KOKKOS_INLINE_FUNCTION | ||
| const Derived& cast () const { return static_cast<const Derived&>(*this); } | ||
| }; | ||
|
|
||
| // Some meta-utilities that will prove useful in derived classes | ||
|
|
||
| // Detect if a type is an Expression | ||
| template<typename T> | ||
| struct is_expr : std::false_type {}; | ||
| template<typename D> | ||
| struct is_expr<Expression<D>> : std::true_type {}; | ||
| template<typename T> | ||
| constexpr bool is_expr_v = is_expr<T>::value; | ||
|
|
||
| } // namespace ekat | ||
|
|
||
| #endif // EKAT_EXPRESSION_HPP | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| #ifndef EKAT_EXPRESSION_BINARY_OP_HPP | ||
| #define EKAT_EXPRESSION_BINARY_OP_HPP | ||
|
|
||
| #include "ekat_expression_base.hpp" | ||
|
|
||
| namespace ekat { | ||
|
|
||
| enum class BinOp { | ||
| Plus, | ||
| Minus, | ||
| Mult, | ||
| Div, | ||
| Max, | ||
| Min | ||
| }; | ||
|
|
||
| template<typename ELeft, typename ERight, BinOp OP> | ||
| class BinaryExpression : public Expression<BinaryExpression<ELeft,ERight,OP>>{ | ||
| public: | ||
| static constexpr bool expr_l = is_expr_v<ELeft>; | ||
| static constexpr bool expr_r = is_expr_v<ERight>; | ||
|
|
||
| static_assert (expr_l or expr_r, | ||
| "[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n"); | ||
|
|
||
| BinaryExpression (const ELeft& left, const ERight& right) | ||
| : m_left(left) | ||
| , m_right(right) | ||
| { | ||
| // Nothing to do here | ||
| } | ||
|
|
||
| int num_indices () const { | ||
| if constexpr (not expr_l) { | ||
| return m_right.num_indices(); | ||
| } else if constexpr (not expr_r) { | ||
| return m_left.num_indices(); | ||
| } else { | ||
| return std::max(m_left.num_indices(),m_right.num_indices()); | ||
| } | ||
| } | ||
|
|
||
| template<typename... Args> | ||
| KOKKOS_INLINE_FUNCTION | ||
| auto eval(Args... args) const { | ||
| if constexpr (not expr_l) { | ||
| return eval_impl(m_left,m_right.eval(args...)); | ||
| } else if constexpr (not expr_r) { | ||
| return eval_impl(m_left.eval(args...),m_right); | ||
| } else { | ||
| return eval_impl(m_left.eval(args...),m_right.eval(args...)); | ||
| } | ||
| } | ||
|
|
||
| static auto ret_type () { return ELeft::ret_type() + ERight::ret_type(); } | ||
| protected: | ||
|
|
||
| template<typename T1, typename T2> | ||
| KOKKOS_INLINE_FUNCTION | ||
| auto eval_impl(const T1 l, const T2 r) const { | ||
| if constexpr (OP==BinOp::Plus) { | ||
| return l+r; | ||
| } else if constexpr (OP==BinOp::Minus) { | ||
| return l-r; | ||
| } else if constexpr (OP==BinOp::Mult) { | ||
| return l*r; | ||
| } else if constexpr (OP==BinOp::Div) { | ||
| return l/r; | ||
| } else if constexpr (OP==BinOp::Max) { | ||
| return Kokkos::max(l,r); | ||
| } else if constexpr (OP==BinOp::Min) { | ||
| return Kokkos::min(l,r); | ||
| } | ||
| } | ||
|
|
||
| ELeft m_left; | ||
| ERight m_right; | ||
| }; | ||
|
|
||
| template<typename ELeft, typename ERight, BinOp OP> | ||
| struct is_expr<BinaryExpression<ELeft,ERight,OP>> : std::true_type {}; | ||
|
|
||
| // Unary minus implemented as -1*expr | ||
| template<typename ERight> | ||
| BinaryExpression<int,ERight,BinOp::Mult> | ||
| operator- (const Expression<ERight>& r) | ||
| { | ||
| return BinaryExpression<int,ERight,BinOp::Mult>(-1,r.cast()); | ||
| } | ||
|
|
||
| // Overload arithmetic operators | ||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Plus>> | ||
| operator+ (const ELeft& l, const ERight& r) | ||
| { | ||
| return BinaryExpression<ELeft,ERight,BinOp::Plus>(l,r); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Minus>> | ||
| operator- (const ELeft& l, const ERight& r) | ||
| { | ||
| return BinaryExpression<ELeft,ERight,BinOp::Minus>(l,r); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Mult>> | ||
| operator* (const ELeft& l, const ERight& r) | ||
| { | ||
| return BinaryExpression<ELeft,ERight,BinOp::Mult>(l,r); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Div>> | ||
| operator/ (const ELeft& l, const ERight& r) | ||
| { | ||
| return BinaryExpression<ELeft,ERight,BinOp::Div>(l,r); | ||
| } | ||
|
|
||
| // Overload max/min functions | ||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Max>> | ||
| max (const ELeft& l, const ERight& r) | ||
| { | ||
| return BinaryExpression<ELeft,ERight,BinOp::Max>(l,r); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,BinaryExpression<ELeft,ERight,BinOp::Min>> | ||
| min (const ELeft& l, const ERight& r) | ||
| { | ||
| return BinaryExpression<ELeft,ERight,BinOp::Min>(l,r); | ||
| } | ||
|
|
||
| } // namespace ekat | ||
|
|
||
| #endif // EKAT_EXPRESSION_BINARY_OP_HPP |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| #ifndef EKAT_EXPRESSION_COMPARE_HPP | ||
| #define EKAT_EXPRESSION_COMPARE_HPP | ||
|
|
||
| #include "ekat_expression_base.hpp" | ||
|
|
||
| #include "ekat_std_utils.hpp" | ||
| #include "ekat_kernel_assert.hpp" | ||
| #include "ekat_assert.hpp" | ||
|
|
||
| namespace ekat { | ||
|
|
||
| enum class Comparison : int { | ||
| EQ, // == | ||
| NE, // != | ||
| GT, // > | ||
| GE, // >= | ||
| LT, // < | ||
| LE // <= | ||
| }; | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| class CmpExpression : public Expression<CmpExpression<ELeft,ERight>> { | ||
| public: | ||
| using ret_t = int; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I used |
||
|
|
||
| static constexpr bool expr_l = is_expr_v<ELeft>; | ||
| static constexpr bool expr_r = is_expr_v<ERight>; | ||
|
|
||
| static_assert(expr_l or expr_r, | ||
| "[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n"); | ||
|
|
||
| CmpExpression (const ELeft& left, const ERight& right, Comparison CMP) | ||
| : m_left(left) | ||
| , m_right(right) | ||
| , m_cmp(CMP) | ||
| { | ||
| auto valid = {Comparison::EQ,Comparison::NE,Comparison::GT, | ||
| Comparison::GE,Comparison::LT,Comparison::LE}; | ||
| EKAT_REQUIRE_MSG (ekat::contains(valid,CMP), | ||
| "[CmpExpression] Error! Unrecognized/unsupported Comparison value.\n"); | ||
| } | ||
|
|
||
| int num_indices () const { | ||
| if constexpr (not expr_l) { | ||
| return m_right.num_indices(); | ||
| } else if constexpr (not expr_r) { | ||
| return m_left.num_indices(); | ||
| } else { | ||
| return std::max(m_left.num_indices(),m_right.num_indices()); | ||
| } | ||
| } | ||
|
|
||
| template<typename... Args> | ||
| KOKKOS_INLINE_FUNCTION | ||
| ret_t eval(Args... args) const { | ||
| if constexpr (not expr_l) { | ||
| switch (m_cmp) { | ||
| case Comparison::EQ: return m_left == m_right.eval(args...); | ||
| case Comparison::NE: return m_left != m_right.eval(args...); | ||
| case Comparison::GT: return m_left > m_right.eval(args...); | ||
| case Comparison::GE: return m_left >= m_right.eval(args...); | ||
| case Comparison::LT: return m_left < m_right.eval(args...); | ||
| case Comparison::LE: return m_left <= m_right.eval(args...); | ||
| default: | ||
| EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n"); | ||
| } | ||
| } else if constexpr (not expr_r) { | ||
| switch (m_cmp) { | ||
| case Comparison::EQ: return m_left.eval(args...) == m_right; | ||
| case Comparison::NE: return m_left.eval(args...) != m_right; | ||
| case Comparison::GT: return m_left.eval(args...) > m_right; | ||
| case Comparison::GE: return m_left.eval(args...) >= m_right; | ||
| case Comparison::LT: return m_left.eval(args...) < m_right; | ||
| case Comparison::LE: return m_left.eval(args...) <= m_right; | ||
| default: | ||
| EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n"); | ||
| } | ||
| } else { | ||
| switch (m_cmp) { | ||
| case Comparison::EQ: return m_left.eval(args...) == m_right.eval(args...); | ||
| case Comparison::NE: return m_left.eval(args...) != m_right.eval(args...); | ||
| case Comparison::GT: return m_left.eval(args...) > m_right.eval(args...); | ||
| case Comparison::GE: return m_left.eval(args...) >= m_right.eval(args...); | ||
| case Comparison::LT: return m_left.eval(args...) < m_right.eval(args...); | ||
| case Comparison::LE: return m_left.eval(args...) <= m_right.eval(args...); | ||
| default: | ||
| EKAT_KERNEL_ERROR_MSG ("Internal error! Unsupported cmp operator.\n"); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| static int ret_type () { return 0; } | ||
| protected: | ||
|
|
||
| ELeft m_left; | ||
| ERight m_right; | ||
|
|
||
| Comparison m_cmp; | ||
| }; | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| struct is_expr<CmpExpression<ELeft,ERight>> : std::true_type {}; | ||
|
|
||
| // Overload cmp operators for Expression types | ||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> | ||
| operator== (const ELeft& l, const ERight& r) | ||
| { | ||
| return CmpExpression<ELeft,ERight>(l,r,Comparison::EQ); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> | ||
| operator!= (const ELeft& l, const ERight& r) | ||
| { | ||
| return CmpExpression<ELeft,ERight>(l,r,Comparison::NE); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> | ||
| operator> (const ELeft& l, const ERight& r) | ||
| { | ||
| return CmpExpression<ELeft,ERight>(l,r,Comparison::GT); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> | ||
| operator>= (const ELeft& l, const ERight& r) | ||
| { | ||
| return CmpExpression<ELeft,ERight>(l,r,Comparison::GE); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> | ||
| operator< (const ELeft& l, const ERight& r) | ||
| { | ||
| return CmpExpression<ELeft,ERight>(l,r,Comparison::LT); | ||
| } | ||
|
|
||
| template<typename ELeft, typename ERight> | ||
| std::enable_if_t<is_expr_v<ELeft> or is_expr_v<ERight>,CmpExpression<ELeft,ERight>> | ||
| operator<= (const ELeft& l, const ERight& r) | ||
| { | ||
| return CmpExpression<ELeft,ERight>(l,r,Comparison::LE); | ||
| } | ||
|
|
||
| } // namespace ekat | ||
|
|
||
| #endif // EKAT_EXPRESSION_COMPARE_HPP | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this could be turned into a static fcn, as the number of indices used by an expression is fully determined by its type.
So far, this fcn is only used in the
evaluatecall, to make sure that the rank of the view passed to the evaluation fcn is compatible with the expression to evaluate. I may revisit this pattern in a follow-up pr...