From 945807d758600bb0ccc5fc4bc352ed9701c93df1 Mon Sep 17 00:00:00 2001 From: ckegel <57967583+CKegel@users.noreply.github.com> Date: Tue, 18 Feb 2025 15:01:02 -0500 Subject: [PATCH 1/4] Add kokkos sort_by_key stable sort implementation. --- CMakeLists.txt | 1 + src/Omega_h_sort.cpp | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1553d6698..641958b07 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ bob_option(Omega_h_ENABLE_DEMANGLED_STACKTRACE "Add linker options to enable hum bob_option(Omega_h_DBG "Enable debug prints, stacktraces, etc." OFF) bob_option(Omega_h_GPU_CHECK "Run GPU check after each test" OFF) #the check command is hardcoded! bob_option(Omega_h_USE_CTAGS "Generate Ctags" OFF) +bob_option(Omega_h_FORCE_KOKKOS_SORT "Force omega_h sort to use kokkos sortbykey." ON) if (Omega_h_ENABLE_DEMANGLED_STACKTRACE) message(STATUS "CMAKE_BUILD_TYPE= ${CMAKE_BUILD_TYPE}") diff --git a/src/Omega_h_sort.cpp b/src/Omega_h_sort.cpp index 3af90adc3..f2eb5fd96 100644 --- a/src/Omega_h_sort.cpp +++ b/src/Omega_h_sort.cpp @@ -3,7 +3,9 @@ #include #include #endif - +#if defined(OMEGA_H_FORCE_KOKKOS_SORT) +#include +#endif #include #include #include @@ -87,8 +89,15 @@ static LOs sort_by_keys_tmpl(Read keys) { LO* begin = perm.data(); LO* end = perm.data() + n; T const* keyptr = keys.data(); +#if defined(OMEGA_H_FORCE_KOKKOS_SORT) + using ExecSpace = Kokkos::DefaultExecutionSpace; + ExecSpace space{}; + Write base(n, 0, 1); + Kokkos::Experimental::sort_by_key(space, base.view(), perm.view(), CompareKeySets(keyptr)); +#else parallel_sort>( begin, end, CompareKeySets(keyptr)); +#endif end_code(); return perm; } From 40046df23764583dd1cd1e107f9aa76b2e7e4708 Mon Sep 17 00:00:00 2001 From: ckegel <57967583+CKegel@users.noreply.github.com> Date: Tue, 18 Feb 2025 15:27:27 -0500 Subject: [PATCH 2/4] Add FORCE_KOKKOS_SORT compiler directive to the Omega_h bools list. 
--- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 641958b07..03eabd1fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -152,6 +152,7 @@ set(Omega_h_KEY_BOOLS Omega_h_ENABLE_DEMANGLED_STACKTRACE Omega_h_DBG Omega_h_USE_Kokkos + Omega_h_FORCE_KOKKOS_SORT Omega_h_USE_OpenMP Omega_h_USE_CUDA Omega_h_USE_SYCL From 8476d0ae5cffb09fe2d5cd365dfe1b6ae426cca5 Mon Sep 17 00:00:00 2001 From: ckegel <57967583+CKegel@users.noreply.github.com> Date: Tue, 25 Feb 2025 14:47:15 -0500 Subject: [PATCH 3/4] Fall back to original tuple index in sort. --- src/Omega_h_sort.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Omega_h_sort.cpp b/src/Omega_h_sort.cpp index f2eb5fd96..6f9c90b47 100644 --- a/src/Omega_h_sort.cpp +++ b/src/Omega_h_sort.cpp @@ -77,7 +77,7 @@ struct CompareKeySets { T y = keys_[b * N + i]; if (x != y) return x < y; } - return false; + return a < b; } }; From a0345401f1080aff7993b03814a39f5c2043bdc6 Mon Sep 17 00:00:00 2001 From: ckegel <57967583+CKegel@users.noreply.github.com> Date: Fri, 21 Mar 2025 14:53:30 -0400 Subject: [PATCH 4/4] Add Omega_h stable sort comparator to sort_test. 
--- src/sort_test.cpp | 57 ++++++++++++++++++++--------------------------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/src/sort_test.cpp b/src/sort_test.cpp index c0b11bd44..a3c908b1b 100644 --- a/src/sort_test.cpp +++ b/src/sort_test.cpp @@ -6,6 +6,23 @@ #include "Omega_h_for.hpp" #include +struct CompareKeySets { + Omega_h::Write const* keys_; + int N; + CompareKeySets(Omega_h::Write const* keys, int n) { + keys_ = keys; + N = n; +} + OMEGA_H_INLINE bool operator()(const Omega_h::LO& a, const Omega_h::LO& b) const { + for (int i = 0; i < N; ++i) { + Omega_h::LO x = (*keys_)[a * N + i]; + Omega_h::LO y = (*keys_)[b * N + i]; + if (x != y) return x < y; + } + return false; + } +}; + int main(int argc, char** argv) { using namespace Omega_h; auto lib = Library(&argc, &argv); @@ -39,39 +56,13 @@ int main(int argc, char** argv) { { for(int i=0; i<3; i++) { fprintf(stderr, "large test %d\n", i); - Read keys, gold; - std::ifstream in("ab2b"+std::to_string(i)+".dat", std::ios::in); - assert(in.is_open()); - binary::read_array(in, keys, false, false); - std::ifstream inGold("ba2ab"+std::to_string(i)+".dat", std::ios::in); - assert(in.is_open()); - binary::read_array(inGold, gold, false, false); - in.close(); - inGold.close(); - LOs perm = sort_by_keys(keys); - auto perm_hr = HostRead(perm); - auto gold_hr = HostRead(gold); - bool isSame = true; - assert(perm_hr.size() == gold_hr.size()); - for(int j=0; j cnt({0}); - auto countNEQ = OMEGA_H_LAMBDA(int i) { - if(perm[i] != gold[i]) { - atomic_increment(&cnt[0]); - } - }; - parallel_for(perm.size(), countNEQ); - auto cnt_hr = HostRead(cnt); - fprintf(stderr, "device matches %s\n", (cnt_hr[0] == 0) ? "yes" : "no"); - auto permMatch = (perm == gold); - fprintf(stderr, "perm matches (==) %s\n", (permMatch) ? 
"yes" : "no"); - OMEGA_H_CHECK(permMatch); + Write random_keys(); + auto n = 1; + //auto n = divide_no_remainder(random_keys.size(), i); + Write gold_perm(n, 0, 1); + LO* begin = gold_perm.data(); + LO* end = gold_perm.data() + n; + std::stable_sort(begin, end, CompareKeySets(&random_keys, i)); } } return 0;