25
25
#include < numeric>
26
26
#include < concepts>
27
27
28
+ #include < mpi/mpi.hpp>
29
+
28
30
#include " address_space.hpp"
29
31
#include " ../macros.hpp"
30
32
@@ -47,6 +49,7 @@ namespace nda::mem {
47
49
struct blk_t {
48
50
char *ptr = nullptr ;
49
51
size_t s = 0 ;
52
+ void *userdata = nullptr ;
50
53
};
51
54
52
55
// ------------------------- Malloc allocator ----------------------------
@@ -335,4 +338,47 @@ namespace nda::mem {
335
338
auto const &histogram () const noexcept { return hist; }
336
339
};
337
340
341
+ // ------------------------- MPI shared memory allocator ----------------------------
342
+ //
343
+ // Allocates the same amount of memory on each shared memory island
344
+ //
345
+ class shared_allocator {
346
+ public:
347
+ shared_allocator () = default ;
348
+ shared_allocator (shared_allocator const &) = delete ;
349
+ shared_allocator (shared_allocator &&) = default ;
350
+ shared_allocator &operator =(shared_allocator const &) = delete ;
351
+ shared_allocator &operator =(shared_allocator &&) = default ;
352
+
353
+ static constexpr auto address_space = Host;
354
+
355
+ static blk_t allocate (size_t s) noexcept {
356
+ return allocate (s, mpi::communicator{}.split_shared ());
357
+ }
358
+
359
+ static blk_t allocate (MPI_Aint s, mpi::shared_communicator shm) noexcept {
360
+ auto *win = new mpi::shared_window<char >{shm, shm.rank () == 0 ? s : 0 };
361
+ return {(char *)win->base (0 ), (std::size_t )s, (void *)win}; // NOLINT
362
+ }
363
+
364
+ static blk_t allocate_zero (size_t s) noexcept {
365
+ return allocate_zero (s, mpi::communicator{}.split_shared ());
366
+ }
367
+
368
+ static blk_t allocate_zero (MPI_Aint s, mpi::shared_communicator shm) noexcept {
369
+ auto *win = new mpi::shared_window<char >{shm, shm.rank () == 0 ? s : 0 };
370
+ char *baseptr = win->base (0 );
371
+ win->fence ();
372
+ if (shm.rank () == 0 ) {
373
+ std::memset (baseptr, 0 , s);
374
+ }
375
+ win->fence ();
376
+ return {baseptr, (std::size_t )s, (void *)win}; // NOLINT
377
+ }
378
+
379
+ static void deallocate (blk_t b) noexcept {
380
+ delete static_cast <mpi::shared_window<char >*>(b.userdata );
381
+ }
382
+ };
383
+
338
384
} // namespace nda::mem
0 commit comments