Skip to content

Commit 24b9d43

Browse files
committed
GTEST/UCP: Add mixed mem types test
1 parent 4c3114a commit 24b9d43

1 file changed

Lines changed: 69 additions & 24 deletions

File tree

test/gtest/ucp/test_ucp_rma.cc

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,39 @@ class test_ucp_rma_sgl : public test_ucp_rma {
731731
init_sgl_ctx(ctx, std::vector<size_t>(num_elems, buf_size));
732732
}
733733

734+
void init_sgl_ctx_mixed_mem_types(sgl_ctx &ctx) {
735+
static constexpr size_t buf_size = 64;
736+
static constexpr size_t num = 2;
737+
const ucs_memory_type_t types[] = {UCS_MEMORY_TYPE_HOST,
738+
UCS_MEMORY_TYPE_CUDA};
739+
740+
ctx.rkey_handles.resize(num);
741+
ctx.buffers.resize(num);
742+
ctx.remote_addrs.resize(num);
743+
ctx.lengths.resize(num);
744+
ctx.memhs.resize(num);
745+
ctx.rkeys.resize(num);
746+
ctx.src.reserve(num);
747+
ctx.dst.reserve(num);
748+
749+
for (size_t i = 0; i < num; i++) {
750+
ctx.src.emplace_back(buf_size, sender(), 0, types[i]);
751+
ctx.dst.emplace_back(buf_size, receiver(), 0, types[i]);
752+
}
753+
754+
for (size_t i = 0; i < num; i++) {
755+
ctx.src[i].memset(static_cast<uint8_t>(i + 1));
756+
ctx.dst[i].memset(0);
757+
ctx.dst[i].rkey(sender(), ctx.rkey_handles[i]);
758+
759+
ctx.buffers[i] = ctx.src[i].ptr();
760+
ctx.memhs[i] = ctx.src[i].memh();
761+
ctx.remote_addrs[i] = reinterpret_cast<uint64_t>(ctx.dst[i].ptr());
762+
ctx.lengths[i] = buf_size;
763+
ctx.rkeys[i] = ctx.rkey_handles[i];
764+
}
765+
}
766+
734767
static ucp_dt_local_sgl_t
735768
make_local_sgl(sgl_ctx &ctx, uint64_t field_mask) {
736769
ucp_dt_local_sgl_t sgl = {};
@@ -885,17 +918,13 @@ class test_ucp_rma_sgl : public test_ucp_rma {
885918
UCP_DT_REMOTE_SGL_FIELD_LENGTHS |
886919
UCP_DT_REMOTE_SGL_FIELD_RKEYS;
887920

888-
void expect_sgl_put_invalid_param(uint64_t local_mask,
889-
uint64_t remote_mask,
890-
uint64_t remote_addr = UCP_REMOTE_ADDR_INVALID,
891-
ucp_rkey_h rkey = UCP_RKEY_INVALID,
892-
uint32_t clear_param_mask = 0,
893-
bool null_remote = false,
894-
size_t count = 2,
895-
size_t remote_count = 0) {
896-
sgl_ctx ctx;
897-
init_sgl_ctx(ctx, count, 64);
898-
921+
void expect_sgl_put_invalid_param_ctx(
922+
sgl_ctx &ctx, uint64_t local_mask, uint64_t remote_mask,
923+
size_t count, uint64_t remote_addr = UCP_REMOTE_ADDR_INVALID,
924+
ucp_rkey_h rkey = UCP_RKEY_INVALID,
925+
uint32_t clear_param_mask = 0, bool null_remote = false,
926+
size_t remote_count = 0)
927+
{
899928
size_t effective_remote_count = remote_count ? remote_count : count;
900929
ucp_dt_local_sgl_t local = make_local_sgl(ctx, local_mask);
901930
ucp_dt_remote_sgl_t remote = make_remote_sgl(ctx, remote_mask);
@@ -912,6 +941,22 @@ class test_ucp_rma_sgl : public test_ucp_rma {
912941
remote_addr, rkey, &param);
913942
EXPECT_EQ(UCS_ERR_INVALID_PARAM, UCS_PTR_STATUS(sptr));
914943
}
944+
945+
void expect_sgl_put_invalid_param(uint64_t local_mask,
946+
uint64_t remote_mask,
947+
uint64_t remote_addr = UCP_REMOTE_ADDR_INVALID,
948+
ucp_rkey_h rkey = UCP_RKEY_INVALID,
949+
uint32_t clear_param_mask = 0,
950+
bool null_remote = false,
951+
size_t count = 2,
952+
size_t remote_count = 0)
953+
{
954+
sgl_ctx ctx;
955+
init_sgl_ctx(ctx, count, 64);
956+
expect_sgl_put_invalid_param_ctx(ctx, local_mask, remote_mask, count,
957+
remote_addr, rkey, clear_param_mask,
958+
null_remote, remote_count);
959+
}
915960
};
916961

917962
UCS_TEST_P(test_ucp_rma_sgl, put_various_counts) {
@@ -1001,19 +1046,9 @@ UCS_TEST_SKIP_COND_P(test_ucp_rma_sgl, put_invalid_rkey,
10011046
!ENABLE_PARAMS_CHECK) {
10021047
sgl_ctx ctx;
10031048
init_sgl_ctx(ctx, 2, 64);
1004-
1005-
uint64_t local_mask = LOCAL_MASK_DEFAULT;
1006-
uint64_t remote_mask = REMOTE_MASK_DEFAULT;
1007-
1008-
ucp_dt_local_sgl_t local = make_local_sgl(ctx, local_mask);
1009-
ucp_dt_remote_sgl_t remote = make_remote_sgl(ctx, remote_mask);
1010-
ucp_request_param_t param = make_sgl_param(&remote, 2);
1011-
1012-
scoped_log_handler wrap_err(wrap_errors_logger);
1013-
ucs_status_ptr_t sptr = ucp_put_nbx(sender().ep(), &local, 2,
1014-
UCP_REMOTE_ADDR_INVALID,
1015-
ctx.rkeys[0], &param);
1016-
EXPECT_EQ(UCS_ERR_INVALID_PARAM, UCS_PTR_STATUS(sptr));
1049+
expect_sgl_put_invalid_param_ctx(ctx, LOCAL_MASK_DEFAULT,
1050+
REMOTE_MASK_DEFAULT, 2,
1051+
UCP_REMOTE_ADDR_INVALID, ctx.rkeys[0]);
10171052
}
10181053

10191054
UCS_TEST_SKIP_COND_P(test_ucp_rma_sgl, put_missing_remote_field,
@@ -1063,6 +1098,16 @@ UCS_TEST_SKIP_COND_P(test_ucp_rma_sgl, put_count_mismatch,
10631098
0, false, 4, 3);
10641099
}
10651100

1101+
UCS_TEST_SKIP_COND_P(test_ucp_rma_sgl, put_mixed_mem_types,
1102+
!ENABLE_PARAMS_CHECK ||
1103+
!mem_buffer::is_mem_type_supported(
1104+
UCS_MEMORY_TYPE_CUDA)) {
1105+
sgl_ctx ctx;
1106+
init_sgl_ctx_mixed_mem_types(ctx);
1107+
expect_sgl_put_invalid_param_ctx(ctx, LOCAL_MASK_DEFAULT,
1108+
REMOTE_MASK_DEFAULT, 2);
1109+
}
1110+
10661111
UCS_TEST_P(test_ucp_rma_sgl, put_zero_count) {
10671112
test_put_sgl(0, 64, true, false, true, true);
10681113
}

0 commit comments

Comments
 (0)