@@ -731,6 +731,39 @@ class test_ucp_rma_sgl : public test_ucp_rma {
731731 init_sgl_ctx (ctx, std::vector<size_t >(num_elems, buf_size));
732732 }
733733
734+ void init_sgl_ctx_mixed_mem_types (sgl_ctx &ctx) {
735+ static constexpr size_t buf_size = 64 ;
736+ static constexpr size_t num = 2 ;
737+ const ucs_memory_type_t types[] = {UCS_MEMORY_TYPE_HOST ,
738+ UCS_MEMORY_TYPE_CUDA };
739+
740+ ctx.rkey_handles .resize (num);
741+ ctx.buffers .resize (num);
742+ ctx.remote_addrs .resize (num);
743+ ctx.lengths .resize (num);
744+ ctx.memhs .resize (num);
745+ ctx.rkeys .resize (num);
746+ ctx.src .reserve (num);
747+ ctx.dst .reserve (num);
748+
749+ for (size_t i = 0 ; i < num; i++) {
750+ ctx.src .emplace_back (buf_size, sender (), 0 , types[i]);
751+ ctx.dst .emplace_back (buf_size, receiver (), 0 , types[i]);
752+ }
753+
754+ for (size_t i = 0 ; i < num; i++) {
755+ ctx.src [i].memset (static_cast <uint8_t >(i + 1 ));
756+ ctx.dst [i].memset (0 );
757+ ctx.dst [i].rkey (sender (), ctx.rkey_handles [i]);
758+
759+ ctx.buffers [i] = ctx.src [i].ptr ();
760+ ctx.memhs [i] = ctx.src [i].memh ();
761+ ctx.remote_addrs [i] = reinterpret_cast <uint64_t >(ctx.dst [i].ptr ());
762+ ctx.lengths [i] = buf_size;
763+ ctx.rkeys [i] = ctx.rkey_handles [i];
764+ }
765+ }
766+
734767 static ucp_dt_local_sgl_t
735768 make_local_sgl (sgl_ctx &ctx, uint64_t field_mask) {
736769 ucp_dt_local_sgl_t sgl = {};
@@ -885,17 +918,13 @@ class test_ucp_rma_sgl : public test_ucp_rma {
885918 UCP_DT_REMOTE_SGL_FIELD_LENGTHS |
886919 UCP_DT_REMOTE_SGL_FIELD_RKEYS ;
887920
888- void expect_sgl_put_invalid_param (uint64_t local_mask,
889- uint64_t remote_mask,
890- uint64_t remote_addr = UCP_REMOTE_ADDR_INVALID ,
891- ucp_rkey_h rkey = UCP_RKEY_INVALID ,
892- uint32_t clear_param_mask = 0 ,
893- bool null_remote = false ,
894- size_t count = 2 ,
895- size_t remote_count = 0 ) {
896- sgl_ctx ctx;
897- init_sgl_ctx (ctx, count, 64 );
898-
921+ void expect_sgl_put_invalid_param_ctx (
922+ sgl_ctx &ctx, uint64_t local_mask, uint64_t remote_mask,
923+ size_t count, uint64_t remote_addr = UCP_REMOTE_ADDR_INVALID ,
924+ ucp_rkey_h rkey = UCP_RKEY_INVALID ,
925+ uint32_t clear_param_mask = 0 , bool null_remote = false ,
926+ size_t remote_count = 0 )
927+ {
899928 size_t effective_remote_count = remote_count ? remote_count : count;
900929 ucp_dt_local_sgl_t local = make_local_sgl (ctx, local_mask);
901930 ucp_dt_remote_sgl_t remote = make_remote_sgl (ctx, remote_mask);
@@ -912,6 +941,22 @@ class test_ucp_rma_sgl : public test_ucp_rma {
912941 remote_addr, rkey, ¶m);
913942 EXPECT_EQ (UCS_ERR_INVALID_PARAM , UCS_PTR_STATUS (sptr));
914943 }
944+
945+ void expect_sgl_put_invalid_param (uint64_t local_mask,
946+ uint64_t remote_mask,
947+ uint64_t remote_addr = UCP_REMOTE_ADDR_INVALID ,
948+ ucp_rkey_h rkey = UCP_RKEY_INVALID ,
949+ uint32_t clear_param_mask = 0 ,
950+ bool null_remote = false ,
951+ size_t count = 2 ,
952+ size_t remote_count = 0 )
953+ {
954+ sgl_ctx ctx;
955+ init_sgl_ctx (ctx, count, 64 );
956+ expect_sgl_put_invalid_param_ctx (ctx, local_mask, remote_mask, count,
957+ remote_addr, rkey, clear_param_mask,
958+ null_remote, remote_count);
959+ }
915960};
916961
917962UCS_TEST_P (test_ucp_rma_sgl, put_various_counts) {
@@ -1001,19 +1046,9 @@ UCS_TEST_SKIP_COND_P(test_ucp_rma_sgl, put_invalid_rkey,
10011046 !ENABLE_PARAMS_CHECK ) {
10021047 sgl_ctx ctx;
10031048 init_sgl_ctx (ctx, 2 , 64 );
1004-
1005- uint64_t local_mask = LOCAL_MASK_DEFAULT ;
1006- uint64_t remote_mask = REMOTE_MASK_DEFAULT ;
1007-
1008- ucp_dt_local_sgl_t local = make_local_sgl (ctx, local_mask);
1009- ucp_dt_remote_sgl_t remote = make_remote_sgl (ctx, remote_mask);
1010- ucp_request_param_t param = make_sgl_param (&remote, 2 );
1011-
1012- scoped_log_handler wrap_err (wrap_errors_logger);
1013- ucs_status_ptr_t sptr = ucp_put_nbx (sender ().ep (), &local, 2 ,
1014- UCP_REMOTE_ADDR_INVALID ,
1015- ctx.rkeys [0 ], ¶m);
1016- EXPECT_EQ (UCS_ERR_INVALID_PARAM , UCS_PTR_STATUS (sptr));
1049+ expect_sgl_put_invalid_param_ctx (ctx, LOCAL_MASK_DEFAULT ,
1050+ REMOTE_MASK_DEFAULT , 2 ,
1051+ UCP_REMOTE_ADDR_INVALID , ctx.rkeys [0 ]);
10171052}
10181053
10191054UCS_TEST_SKIP_COND_P (test_ucp_rma_sgl, put_missing_remote_field,
@@ -1063,6 +1098,16 @@ UCS_TEST_SKIP_COND_P(test_ucp_rma_sgl, put_count_mismatch,
10631098 0 , false , 4 , 3 );
10641099}
10651100
1101+ UCS_TEST_SKIP_COND_P (test_ucp_rma_sgl, put_mixed_mem_types,
1102+ !ENABLE_PARAMS_CHECK ||
1103+ !mem_buffer::is_mem_type_supported(
1104+ UCS_MEMORY_TYPE_CUDA )) {
1105+ sgl_ctx ctx;
1106+ init_sgl_ctx_mixed_mem_types (ctx);
1107+ expect_sgl_put_invalid_param_ctx (ctx, LOCAL_MASK_DEFAULT ,
1108+ REMOTE_MASK_DEFAULT , 2 );
1109+ }
1110+
10661111UCS_TEST_P (test_ucp_rma_sgl, put_zero_count) {
10671112 test_put_sgl (0 , 64 , true , false , true , true );
10681113}
0 commit comments