// Reviewer notes. Everything from the first `diff --git` line onward is the
// patch text, left unmodified.
//
// Purpose: introduce a build option, Omega_h_FORCE_KOKKOS_SORT, that switches
// sort_by_keys_tmpl from parallel_sort to Kokkos::Experimental::sort_by_key,
// and rework the "large test" loop in src/sort_test.cpp.
//
// CMakeLists.txt
//   * Adds bob_option Omega_h_FORCE_KOKKOS_SORT with default ON.
//   * Adds Omega_h_FORCE_KOKKOS_SORT to the Omega_h_KEY_BOOLS list.
//
// src/Omega_h_sort.cpp
//   * Adds one #include inside `#if defined(OMEGA_H_FORCE_KOKKOS_SORT)`; the
//     header name is not legible in this copy of the patch.
//   * CompareKeySets::operator(): when every compared key component is equal,
//     the result changes from `false` to `a < b`, so ties are ordered by the
//     values of a and b themselves.
//   * sort_by_keys_tmpl: when OMEGA_H_FORCE_KOKKOS_SORT is defined, builds
//     `base` from (n, 0, 1) and calls
//     Kokkos::Experimental::sort_by_key(space, base.view(), perm.view(),
//     CompareKeySets(keyptr)) on a default-constructed
//     Kokkos::DefaultExecutionSpace; otherwise the existing parallel_sort call
//     is kept.
//   * NOTE(review): the new code is guarded only by OMEGA_H_FORCE_KOKKOS_SORT
//     and the option defaults to ON. Confirm the macro can never be defined in
//     a build where Omega_h_USE_Kokkos is off.
//   * NOTE(review): `base` presumably holds 0..n-1 (start 0, stride 1) --
//     TODO confirm against the Write constructor.
//
// src/sort_test.cpp
//   * Adds a file-local struct CompareKeySets that stores a pointer `keys_`
//     and an int `N`. Its operator() compares the N entries starting at a * N
//     with the N entries starting at b * N and returns false when all are
//     equal.
//   * Removes the block that read "ab2b<i>.dat" and "ba2ab<i>.dat", called
//     sort_by_keys, and verified the permutation with OMEGA_H_CHECK.
//   * Adds code that declares random_keys and gold_perm and runs
//     std::stable_sort(begin, end, CompareKeySets(&random_keys, i)).
//   * NOTE(review): `random_keys();` uses empty parentheses. In C++ that form
//     declares a function, not an object, so `&random_keys` would not be a
//     pointer to an array. Confirm the intended initialiser.
//   * NOTE(review): n is fixed at 1 and the divide_no_remainder line is
//     commented out. The loop starts at i = 0, so the first comparator is
//     built with N = 0 and its operator() always returns false.
//   * NOTE(review): none of the added lines compares gold_perm with the output
//     of sort_by_keys, and the OMEGA_H_CHECK is among the removed lines.
//     Confirm the test still asserts something.
//   * NOTE(review): std::stable_sort runs on gold_perm.data(); presumably this
//     must be host-accessible memory -- confirm for GPU builds.
//
// NOTE(review): this copy looks lossy. #include lines have no header names and
// template argument lists are missing (for example `Read keys`, `Write base`).
// Check these notes against the original patch before acting on them.
diff --git a/CMakeLists.txt b/CMakeLists.txt index 1553d669..03eabd1f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ bob_option(Omega_h_ENABLE_DEMANGLED_STACKTRACE "Add linker options to enable hum bob_option(Omega_h_DBG "Enable debug prints, stacktraces, etc." OFF) bob_option(Omega_h_GPU_CHECK "Run GPU check after each test" OFF) #the check command is hardcoded! bob_option(Omega_h_USE_CTAGS "Generate Ctags" OFF) +bob_option(Omega_h_FORCE_KOKKOS_SORT "Force omega_h sort to use kokkos sortbykey." ON) if (Omega_h_ENABLE_DEMANGLED_STACKTRACE) message(STATUS "CMAKE_BUILD_TYPE= ${CMAKE_BUILD_TYPE}") @@ -151,6 +152,7 @@ set(Omega_h_KEY_BOOLS Omega_h_ENABLE_DEMANGLED_STACKTRACE Omega_h_DBG Omega_h_USE_Kokkos + Omega_h_FORCE_KOKKOS_SORT Omega_h_USE_OpenMP Omega_h_USE_CUDA Omega_h_USE_SYCL diff --git a/src/Omega_h_sort.cpp b/src/Omega_h_sort.cpp index 3af90adc..6f9c90b4 100644 --- a/src/Omega_h_sort.cpp +++ b/src/Omega_h_sort.cpp @@ -3,7 +3,9 @@ #include #include #endif - +#if defined(OMEGA_H_FORCE_KOKKOS_SORT) +#include +#endif #include #include #include @@ -75,7 +77,7 @@ struct CompareKeySets { T y = keys_[b * N + i]; if (x != y) return x < y; } - return false; + return a < b; } }; @@ -87,8 +89,15 @@ static LOs sort_by_keys_tmpl(Read keys) { LO* begin = perm.data(); LO* end = perm.data() + n; T const* keyptr = keys.data(); +#if defined(OMEGA_H_FORCE_KOKKOS_SORT) + using ExecSpace = Kokkos::DefaultExecutionSpace; + ExecSpace space{}; + Write base(n, 0, 1); + Kokkos::Experimental::sort_by_key(space, base.view(), perm.view(), CompareKeySets(keyptr)); +#else parallel_sort>( begin, end, CompareKeySets(keyptr)); +#endif end_code(); return perm; } diff --git a/src/sort_test.cpp b/src/sort_test.cpp index c0b11bd4..a3c908b1 100644 --- a/src/sort_test.cpp +++ b/src/sort_test.cpp @@ -6,6 +6,23 @@ #include "Omega_h_for.hpp" #include +struct CompareKeySets { + Omega_h::Write const* keys_; + int N; + CompareKeySets(Omega_h::Write const* keys, int n) { + keys_ = 
keys; + N = n; +} + OMEGA_H_INLINE bool operator()(const Omega_h::LO& a, const Omega_h::LO& b) const { + for (int i = 0; i < N; ++i) { + Omega_h::LO x = (*keys_)[a * N + i]; + Omega_h::LO y = (*keys_)[b * N + i]; + if (x != y) return x < y; + } + return false; + } +}; + int main(int argc, char** argv) { using namespace Omega_h; auto lib = Library(&argc, &argv); @@ -39,39 +56,13 @@ int main(int argc, char** argv) { { for(int i=0; i<3; i++) { fprintf(stderr, "large test %d\n", i); - Read keys, gold; - std::ifstream in("ab2b"+std::to_string(i)+".dat", std::ios::in); - assert(in.is_open()); - binary::read_array(in, keys, false, false); - std::ifstream inGold("ba2ab"+std::to_string(i)+".dat", std::ios::in); - assert(in.is_open()); - binary::read_array(inGold, gold, false, false); - in.close(); - inGold.close(); - LOs perm = sort_by_keys(keys); - auto perm_hr = HostRead(perm); - auto gold_hr = HostRead(gold); - bool isSame = true; - assert(perm_hr.size() == gold_hr.size()); - for(int j=0; j cnt({0}); - auto countNEQ = OMEGA_H_LAMBDA(int i) { - if(perm[i] != gold[i]) { - atomic_increment(&cnt[0]); - } - }; - parallel_for(perm.size(), countNEQ); - auto cnt_hr = HostRead(cnt); - fprintf(stderr, "device matches %s\n", (cnt_hr[0] == 0) ? "yes" : "no"); - auto permMatch = (perm == gold); - fprintf(stderr, "perm matches (==) %s\n", (permMatch) ? "yes" : "no"); - OMEGA_H_CHECK(permMatch); + Write random_keys(); + auto n = 1; + //auto n = divide_no_remainder(random_keys.size(), i); + Write gold_perm(n, 0, 1); + LO* begin = gold_perm.data(); + LO* end = gold_perm.data() + n; + std::stable_sort(begin, end, CompareKeySets(&random_keys, i)); } } return 0;