diff --git a/CMakeLists.txt b/CMakeLists.txt index 1553d669..083eae20 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,7 @@ set(Omega_h_USE_pybind11_DEFAULT OFF) bob_public_dep(pybind11) if(Omega_h_USE_Kokkos) + bob_option(Omega_h_FORCE_KOKKOS_SORT "Force omega_h sort to use kokkos sortbykey." ON) if ("CUDA" IN_LIST Kokkos_DEVICES) bob_option(Omega_h_USE_CUDA "Whether to use CUDA" "ON") endif() @@ -151,6 +152,7 @@ set(Omega_h_KEY_BOOLS Omega_h_ENABLE_DEMANGLED_STACKTRACE Omega_h_DBG Omega_h_USE_Kokkos + Omega_h_FORCE_KOKKOS_SORT Omega_h_USE_OpenMP Omega_h_USE_CUDA Omega_h_USE_SYCL diff --git a/meshes/sorting/ab2b.dat b/meshes/sorting/ab2b.dat new file mode 100644 index 00000000..c1f5acba Binary files /dev/null and b/meshes/sorting/ab2b.dat differ diff --git a/meshes/sorting/goldSorted.dat b/meshes/sorting/goldSorted.dat new file mode 100644 index 00000000..4d65bbba Binary files /dev/null and b/meshes/sorting/goldSorted.dat differ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 56563590..2746b62e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -660,7 +660,7 @@ if(BUILD_TESTING) else() test_func(describe_serial 1 ./describe ${CMAKE_SOURCE_DIR}/meshes/box_3d.osh) endif() - + test_func(sort_test 1 ./sort_test ${CMAKE_SOURCE_DIR}/meshes/sorting/ab2b.dat ${CMAKE_SOURCE_DIR}/meshes/sorting/goldSorted.dat) if (Omega_h_USE_ADIOS2) osh_add_util(bp2osh) osh_add_util(osh2bp) diff --git a/src/Omega_h_sort.cpp b/src/Omega_h_sort.cpp index 3af90adc..6f9c90b4 100644 --- a/src/Omega_h_sort.cpp +++ b/src/Omega_h_sort.cpp @@ -3,7 +3,9 @@ #include #include #endif - +#if defined(OMEGA_H_FORCE_KOKKOS_SORT) +#include +#endif #include #include #include @@ -75,7 +77,7 @@ struct CompareKeySets { T y = keys_[b * N + i]; if (x != y) return x < y; } - return false; + return a < b; } }; @@ -87,8 +89,15 @@ static LOs sort_by_keys_tmpl(Read keys) { LO* begin = perm.data(); LO* end = perm.data() + n; T const* keyptr = keys.data(); +#if defined(OMEGA_H_FORCE_KOKKOS_SORT) + using ExecSpace = Kokkos::DefaultExecutionSpace; + ExecSpace space{}; + Write base(n, 0, 1); + Kokkos::Experimental::sort_by_key(space, base.view(), perm.view(), CompareKeySets(keyptr)); +#else parallel_sort>( begin, end, CompareKeySets(keyptr)); +#endif end_code(); return perm; } diff --git a/src/sort_test.cpp b/src/sort_test.cpp index c0b11bd4..eaa3c03e 100644 --- a/src/sort_test.cpp +++ b/src/sort_test.cpp @@ -1,4 +1,5 @@ #include "Omega_h_library.hpp" +#include "Omega_h_cmdline.hpp" #include "Omega_h_array_ops.hpp" #include "Omega_h_sort.hpp" #include "Omega_h_file.hpp" @@ -10,6 +11,10 @@ int main(int argc, char** argv) { using namespace Omega_h; auto lib = Library(&argc, &argv); auto world = lib.world(); + Omega_h::CmdLine cmdline; + cmdline.add_arg("array-in"); + cmdline.add_arg("gold-array-in"); + if (!cmdline.parse_final(world, &argc, argv)) return -1; { LOs a({0, 2, 0, 1}); LOs perm = sort_by_keys(a,1); @@ -37,13 +42,13 @@ int main(int argc, char** argv) { OMEGA_H_CHECK(perm == LOs({1, 0, 2})); } { - for(int i=0; i<3; i++) { - fprintf(stderr, "large test %d\n", i); Read keys, gold; - std::ifstream in("ab2b"+std::to_string(i)+".dat", std::ios::in); + auto array_in = cmdline.get("array-in"); + std::ifstream in(array_in, std::ios::in); assert(in.is_open()); binary::read_array(in, keys, false, false); - std::ifstream inGold("ba2ab"+std::to_string(i)+".dat", std::ios::in); + auto gold_array_in = cmdline.get("gold-array-in"); + std::ifstream inGold(gold_array_in, std::ios::in); assert(in.is_open()); binary::read_array(inGold, gold, false, false); in.close(); @@ -72,7 +77,6 @@ int main(int argc, char** argv) { auto permMatch = (perm == gold); fprintf(stderr, "perm matches (==) %s\n", (permMatch) ? "yes" : "no"); OMEGA_H_CHECK(permMatch); - } } return 0; }