@@ -47,24 +47,17 @@ inline MPI_Datatype __tmp_mpi_type_f64_3;
4747
4848#define __SYCL_TYPE_COMMIT_len2 (base_name, src_type ) \
4949 { \
50- base_name a; \
51- MPI_Aint offset_##base_name = ((size_t ) ((char *) &(a.x ()) - (char *) &(a))); \
52- MPICHECK (MPI_Type_create_struct ( \
53- 1 , &__len_vec2, &offset_##base_name, &mpi_type_##src_type, &mpi_type_##base_name)); \
50+ check_offset_validity<base_name>(); \
51+ MPICHECK (MPI_Type_contiguous (__len_vec2, mpi_type_##src_type, &mpi_type_##base_name)); \
5452 MPICHECK (MPI_Type_commit (&mpi_type_##base_name)); \
5553 shamlog_debug_mpi_ln (" SyclMpiTypes" , " init mpi type for : " #base_name); \
5654 }
5755
5856#define __SYCL_TYPE_COMMIT_len3 (base_name, src_type ) \
5957 { \
60- base_name a; \
61- MPI_Aint offset_##base_name = ((size_t ) ((char *) &(a.x ()) - (char *) &(a))); \
62- MPICHECK (MPI_Type_create_struct ( \
63- 1 , \
64- &__len_vec3, \
65- &offset_##base_name, \
66- &mpi_type_##src_type, \
67- &__tmp_mpi_type_##base_name)); \
58+ check_offset_validity<base_name>(); \
59+ MPICHECK ( \
60+ MPI_Type_contiguous (__len_vec3, mpi_type_##src_type, &__tmp_mpi_type_##base_name)); \
6861 MPICHECK (MPI_Type_create_resized ( \
6962 __tmp_mpi_type_##base_name, 0 , sizeof (base_name), &mpi_type_##base_name)); \
7063 MPICHECK (MPI_Type_commit (&mpi_type_##base_name)); \
@@ -73,35 +66,40 @@ inline MPI_Datatype __tmp_mpi_type_f64_3;
7366
7467#define __SYCL_TYPE_COMMIT_len4 (base_name, src_type ) \
7568 { \
76- base_name a; \
77- MPI_Aint offset_##base_name = ((size_t ) ((char *) &(a.x ()) - (char *) &(a))); \
78- MPICHECK (MPI_Type_create_struct ( \
79- 1 , &__len_vec4, &offset_##base_name, &mpi_type_##src_type, &mpi_type_##base_name)); \
69+ check_offset_validity<base_name>(); \
70+ MPICHECK (MPI_Type_contiguous (__len_vec4, mpi_type_##src_type, &mpi_type_##base_name)); \
8071 MPICHECK (MPI_Type_commit (&mpi_type_##base_name)); \
8172 shamlog_debug_mpi_ln (" SyclMpiTypes" , " init mpi type for : " #base_name); \
8273 }
8374
8475#define __SYCL_TYPE_COMMIT_len8 (base_name, src_type ) \
8576 { \
86- base_name a; \
87- MPI_Aint offset_##base_name = ((size_t ) ((char *) &(a.s0 ()) - (char *) &(a))); \
88- MPICHECK (MPI_Type_create_struct ( \
89- 1 , &__len_vec8, &offset_##base_name, &mpi_type_##src_type, &mpi_type_##base_name)); \
77+ check_offset_validity<base_name>(); \
78+ MPICHECK (MPI_Type_contiguous (__len_vec8, mpi_type_##src_type, &mpi_type_##base_name)); \
9079 MPICHECK (MPI_Type_commit (&mpi_type_##base_name)); \
9180 shamlog_debug_mpi_ln (" SyclMpiTypes" , " init mpi type for : " #base_name); \
9281 }
9382
9483#define __SYCL_TYPE_COMMIT_len16 (base_name, src_type ) \
9584 { \
96- base_name a; \
97- MPI_Aint offset_##base_name = ((size_t ) ((char *) &(a.s0 ()) - (char *) &(a))); \
98- MPICHECK (MPI_Type_create_struct ( \
99- 1 , &__len_vec16, &offset_##base_name, &mpi_type_##src_type, &mpi_type_##base_name)); \
85+ check_offset_validity<base_name>(); \
86+ MPICHECK (MPI_Type_contiguous (__len_vec16, mpi_type_##src_type, &mpi_type_##base_name)); \
10087 MPICHECK (MPI_Type_commit (&mpi_type_##base_name)); \
10188 shamlog_debug_mpi_ln (" SyclMpiTypes" , " init mpi type for : " #base_name); \
10289 }
10390
104- // TODO check mpi errors
91+ template <class T >
92+ void check_offset_validity () {
93+ T a{};
94+
95+ std::ptrdiff_t base = reinterpret_cast <std::ptrdiff_t >(&a);
96+ std::ptrdiff_t s0 = reinterpret_cast <std::ptrdiff_t >(&a.s0 ());
97+
98+ if (s0 - base != 0 ) {
99+ throw shambase::make_except_with_loc<std::runtime_error>(shambase::format (
100+ " Offset is not valid for type {}, base = {}, s0 = {}" , typeid (T).name (), base, s0));
101+ }
102+ }
105103
106104void create_sycl_mpi_types () {
107105
0 commit comments