
Commit c9da285

Merge pull request #169 from nicoleavans/hip-test
2 parents e1be009 + 89bf530

2 files changed: +174 −0

unit_tests/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
@@ -95,6 +95,20 @@ if(KOKKOSCOMM_ENABLE_MPI)
       ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} 2 ./test-mpi-cuda-sendrecv
     )
   endif()
+
+  if (Kokkos_ENABLE_HIP)
+    add_executable(test-mpi-hip-sendrecv)
+    target_sources(test-mpi-hip-sendrecv PRIVATE mpi/test_mpi_hip_sendrecv.cpp)
+    target_link_libraries(test-mpi-hip-sendrecv MPI::MPI_CXX)
+    find_package(hip REQUIRED)
+    target_link_libraries(test-mpi-hip-sendrecv hip::host)
+    target_compile_options(test-mpi-hip-sendrecv PUBLIC -x hip)
+    add_test(
+      NAME test-mpi-hip-sendrecv
+      COMMAND
+        ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} 2 ./test-mpi-hip-sendrecv
+    )
+  endif()
 endif()

 # Tests using the MPI communication space, but not linking with MPI itself
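A note on the build setup above: passing -x hip via target_compile_options makes the clang-based C++ compiler treat test_mpi_hip_sendrecv.cpp as HIP source, and hip::host from find_package(hip) supplies the HIP runtime headers and library. As a rough alternative sketch (not part of this commit; it assumes CMake >= 3.21 with HIP enabled as a project language), the same target could lean on CMake's first-class HIP support instead:

  # Hypothetical alternative, not in this commit: compile the source as HIP
  # through CMake's HIP language support instead of forcing -x hip.
  # Assumes enable_language(HIP) has been called and CMake >= 3.21.
  if (Kokkos_ENABLE_HIP)
    add_executable(test-mpi-hip-sendrecv mpi/test_mpi_hip_sendrecv.cpp)
    set_source_files_properties(mpi/test_mpi_hip_sendrecv.cpp PROPERTIES LANGUAGE HIP)
    target_link_libraries(test-mpi-hip-sendrecv MPI::MPI_CXX)
    add_test(
      NAME test-mpi-hip-sendrecv
      COMMAND ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} 2 ./test-mpi-hip-sendrecv
    )
  endif()

Either way, the test is only registered when both KOKKOSCOMM_ENABLE_MPI and Kokkos_ENABLE_HIP are set, and ctest launches it on two ranks through MPIEXEC_EXECUTABLE.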
unit_tests/mpi/test_mpi_hip_sendrecv.cpp

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
//@HEADER
// ************************************************************************
//
//                        Kokkos v. 4.0
//       Copyright (2025) National Technology & Engineering
//               Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

/*
This test verifies that GPU-aware MPI is operating as expected when HIP is enabled.
If a HIP call fails, the error is reported using hipGetErrorString().
*/

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <type_traits>

#include <mpi.h>
#include <hip/hip_runtime.h>

// Macro to check for HIP errors
#define HIP(call)                                                                                                   \
  do {                                                                                                              \
    hipError_t err = call;                                                                                          \
    if (err != hipSuccess) {                                                                                        \
      std::cerr << "HIP error in file '" << __FILE__ << "' in line " << __LINE__ << ": " << hipGetErrorString(err)  \
                << " (" << err << ")" << std::endl;                                                                 \
      exit(EXIT_FAILURE);                                                                                           \
    }                                                                                                               \
  } while (0)

namespace {

template <typename Scalar>
__global__ void init_array(Scalar* a, int sz) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < sz) {
    a[i] = Scalar(i);
  }
}

template <typename Scalar>
__global__ void check_array(const Scalar* a, int sz, int* errs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < sz && a[i] != Scalar(i)) {
    atomicAdd(errs, 1);
    printf("ERROR: a[%d](%p) = %f != %f\n", int(i), a + i, double(a[i]), double(i));
  }
}

// get the built-in MPI Datatype for int32_t, int64_t, or float
template <typename Scalar>
MPI_Datatype mpi_type() {
  if constexpr (std::is_same_v<Scalar, int32_t>) {
    return MPI_INT;
  } else if constexpr (std::is_same_v<Scalar, int64_t>) {
    return MPI_LONG_LONG;
  } else if constexpr (std::is_same_v<Scalar, float>) {
    return MPI_FLOAT;
  } else {
    static_assert(std::is_void_v<Scalar>, "unsupported type");
  }
}

// return ptr + offset (in bytes)
void* byte_offset(void* ptr, std::size_t offset) {
  return reinterpret_cast<void*>(reinterpret_cast<std::uintptr_t>(ptr) + offset);
}

template <typename Scalar>
void run_test(int num_elements, int alignment) {
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  if (0 == rank) {
    // get a string name of the Scalar type
    const char* name;
    if constexpr (std::is_same_v<Scalar, int32_t>) {
      name = "int32_t";
    } else if constexpr (std::is_same_v<Scalar, float>) {
      name = "float";
    } else if constexpr (std::is_same_v<Scalar, int64_t>) {
      name = "int64_t";
    } else {
      static_assert(std::is_void_v<Scalar>, "unsupported type");
    }

    std::cerr << __FILE__ << ":" << __LINE__ << " test: " << num_elements << " " << name << " " << alignment << "\n";
  }

  if (2 != size) {
    std::cerr << "test requires 2 processes, got " << size << "\n";
    MPI_Abort(MPI_COMM_WORLD, 1);
  }

  Scalar* d_recv_buf;
  int* d_errs;
  int h_errs = 0;

  size_t buffer_size = num_elements * sizeof(Scalar) + alignment;

  HIP(hipMalloc(&d_recv_buf, buffer_size));
  HIP(hipMalloc(&d_errs, sizeof(int)));
  HIP(hipMemset(d_errs, 0, sizeof(int)));
  Scalar* recv_buf = reinterpret_cast<Scalar*>(byte_offset(d_recv_buf, alignment));

  if (rank == 0) {
    Scalar* d_send_buf;
    HIP(hipMalloc(&d_send_buf, buffer_size));
    Scalar* send_buf = reinterpret_cast<Scalar*>(byte_offset(d_send_buf, alignment));
    init_array<<<(num_elements + 255) / 256, 256>>>(send_buf, num_elements);
    HIP(hipDeviceSynchronize());
    MPI_Send(send_buf, num_elements, mpi_type<Scalar>(), 1, 0, MPI_COMM_WORLD);
    HIP(hipFree(d_send_buf));
  } else if (rank == 1) {
    MPI_Recv(recv_buf, num_elements, mpi_type<Scalar>(), 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    check_array<<<(num_elements + 255) / 256, 256>>>(recv_buf, num_elements, d_errs);
    HIP(hipDeviceSynchronize());
  }

  HIP(hipMemcpy(&h_errs, d_errs, sizeof(int), hipMemcpyDeviceToHost));

  if (h_errs > 0) {
    std::cerr << "[" << rank << "] " << __FILE__ << ":" << __LINE__ << " h_errs=" << h_errs << "\n";
    MPI_Abort(MPI_COMM_WORLD, 1);
  }

  HIP(hipFree(d_recv_buf));
  HIP(hipFree(d_errs));
}

template <typename Scalar>
void run_test() {
  int offset = 128;
  for (size_t _ : {0, 1, 2}) {  // run a few times
    for (size_t n : {113, 16, 8, 4, 2, 1}) {
      MPI_Barrier(MPI_COMM_WORLD);
      run_test<Scalar>(n, offset);
      MPI_Barrier(MPI_COMM_WORLD);
    }
  }
}

}  // namespace

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  run_test<int32_t>();
  run_test<int64_t>();
  run_test<float>();
  MPI_Finalize();
  return 0;
}
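The whole point of the test above is that device pointers obtained from hipMalloc are handed directly to MPI_Send and MPI_Recv, which only works when the MPI library is GPU-aware. For contrast, here is a minimal hypothetical sketch (not part of this commit) of the same two-rank exchange staged through host buffers, which is what would be required if the MPI library were not GPU-aware:

// staged_sendrecv.cpp -- hypothetical contrast sketch, not part of this commit.
// Same two-rank exchange as test_mpi_hip_sendrecv.cpp, but MPI only ever sees
// host memory, so it works even without GPU-aware MPI (at the cost of copies).
// Run with two ranks, e.g. mpiexec -n 2 ./staged_sendrecv
#include <cstdlib>
#include <vector>

#include <mpi.h>
#include <hip/hip_runtime.h>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  constexpr int n = 113;
  float* d_buf = nullptr;
  if (hipMalloc(&d_buf, n * sizeof(float)) != hipSuccess) MPI_Abort(MPI_COMM_WORLD, 1);
  if (hipMemset(d_buf, 0, n * sizeof(float)) != hipSuccess) MPI_Abort(MPI_COMM_WORLD, 1);
  std::vector<float> h_buf(n);  // host staging buffer

  if (rank == 0) {
    // pretend d_buf was filled by a kernel; copy device -> host before sending
    hipMemcpy(h_buf.data(), d_buf, n * sizeof(float), hipMemcpyDeviceToHost);
    MPI_Send(h_buf.data(), n, MPI_FLOAT, 1, 0, MPI_COMM_WORLD);
  } else if (rank == 1) {
    // receive into host memory, then copy host -> device for later kernel use
    MPI_Recv(h_buf.data(), n, MPI_FLOAT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    hipMemcpy(d_buf, h_buf.data(), n * sizeof(float), hipMemcpyHostToDevice);
  }

  hipFree(d_buf);
  MPI_Finalize();
  return 0;
}

Built the same way as the committed test (for example with -x hip and linked against MPI::MPI_CXX), this staged variant would pass with any MPI implementation; the committed test deliberately skips the staging copies so that it exercises, and detects problems with, direct device-to-device transfers.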
