diff --git a/src/collective/in_memory_handler.cc b/src/collective/in_memory_handler.cc index 37be3f9c7127..a26334049cc8 100644 --- a/src/collective/in_memory_handler.cc +++ b/src/collective/in_memory_handler.cc @@ -5,6 +5,8 @@ #include #include +#include + #include "comm.h" namespace xgboost::collective { @@ -18,14 +20,14 @@ class AllgatherFunctor { AllgatherFunctor(std::int32_t world_size, std::int32_t rank) : world_size_{world_size}, rank_{rank} {} - void operator()(char const* input, std::size_t bytes, std::string* buffer) const { - if (buffer->empty()) { + void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const { + if (buffer->Empty()) { // Resize the buffer if this is the first request. - buffer->resize(bytes * world_size_); + buffer->Resize(bytes * world_size_); } // Splice the input into the common buffer. - buffer->replace(rank_ * bytes, bytes, input, bytes); + buffer->Replace(rank_ * bytes, bytes, input); } private: @@ -44,11 +46,11 @@ class AllgatherVFunctor { std::map* data) : world_size_{world_size}, rank_{rank}, data_{data} {} - void operator()(char const* input, std::size_t bytes, std::string* buffer) const { + void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const { data_->emplace(rank_, std::string_view{input, bytes}); if (data_->size() == static_cast(world_size_)) { for (auto const& kv : *data_) { - buffer->append(kv.second); + buffer->Append(kv.second); } data_->clear(); } @@ -70,14 +72,16 @@ class AllreduceFunctor { AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation) : data_type_{dataType}, operation_{operation} {} - void operator()(char const* input, std::size_t bytes, std::string* buffer) const { - if (buffer->empty()) { + void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const { + if (buffer->Empty()) { // Copy the input if this is the first request. - buffer->assign(input, bytes); + buffer->Assign(input, bytes); } else { auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); }); + CHECK_EQ(bytes % n_bytes_type, 0) << "Input size is not a multiple of its element size."; + CHECK_EQ(buffer->Size(), bytes) << "Input size differs across allreduce calls."; // Apply the reduce_operation to the input and the buffer. - Accumulate(input, bytes / n_bytes_type, &buffer->front()); + Accumulate(input, bytes, buffer); } } @@ -128,39 +132,41 @@ class AllreduceFunctor { } } - void Accumulate(char const* input, std::size_t size, char* buffer) const { + void Accumulate(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const { using Type = ArrayInterfaceHandler::Type; + auto data = buffer->Data(); + auto size = bytes / DispatchDType(data_type_, [](auto t) { return sizeof(t); }); switch (data_type_) { case Type::kI1: - Accumulate(reinterpret_cast(buffer), + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; case Type::kU1: - Accumulate(reinterpret_cast(buffer), + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; case Type::kI4: - Accumulate(reinterpret_cast(buffer), + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; case Type::kU4: - Accumulate(reinterpret_cast(buffer), + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; case Type::kI8: - Accumulate(reinterpret_cast(buffer), + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; case Type::kU8: - Accumulate(reinterpret_cast(buffer), + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; case Type::kF4: - Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; case Type::kF8: - Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, + Accumulate(reinterpret_cast(data), reinterpret_cast(input), size, operation_); break; default: @@ -182,10 +188,10 @@ class BroadcastFunctor { BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {} - void operator()(char const* input, std::size_t bytes, std::string* buffer) const { + void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const { if (rank_ == root_) { // Copy the input if this is the root. - buffer->assign(input, bytes); + buffer->Assign(input, bytes); } } @@ -246,9 +252,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* HandlerFunctor const& functor) { // Pass through if there is only 1 client. if (world_size_ == 1) { - if (input != output->data()) { - output->assign(input, bytes); - } + output->assign(input, bytes); return; } @@ -263,7 +267,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* if (received_ == world_size_) { LOG(DEBUG) << functor.name << " rank " << rank << ": all requests received"; - output->assign(buffer_); + output->assign(buffer_.Data(), buffer_.Size()); sent_++; lock.unlock(); cv_.notify_all(); @@ -274,14 +278,14 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* cv_.wait(lock, [this] { return received_ == world_size_; }); LOG(DEBUG) << functor.name << " rank " << rank << ": sending reply"; - output->assign(buffer_); + output->assign(buffer_.Data(), buffer_.Size()); sent_++; if (sent_ == world_size_) { LOG(DEBUG) << functor.name << " rank " << rank << ": all replies sent"; sent_ = 0; received_ = 0; - buffer_.clear(); + buffer_.Clear(); sequence_number_++; lock.unlock(); cv_.notify_all(); diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h index 7c3465d08b8b..ebd6fe6a763e 100644 --- a/src/collective/in_memory_handler.h +++ b/src/collective/in_memory_handler.h @@ -3,13 +3,63 @@ */ #pragma once #include +#include +#include #include #include +#include #include "../data/array_interface.h" #include "comm.h" namespace xgboost::collective { +class AlignedByteBuffer { + using StorageT = std::max_align_t; + + public: + [[nodiscard]] bool Empty() const { return size_ == 0; } + [[nodiscard]] std::size_t Size() const { return size_; } + + [[nodiscard]] char* Data() { return reinterpret_cast(storage_.data()); } + [[nodiscard]] char const* Data() const { return reinterpret_cast(storage_.data()); } + + void Clear() { + storage_.clear(); + size_ = 0; + } + + void Resize(std::size_t n_bytes) { + storage_.resize((n_bytes + sizeof(StorageT) - 1) / sizeof(StorageT)); + size_ = n_bytes; + } + + void Assign(char const* input, std::size_t n_bytes) { + this->Resize(n_bytes); + if (n_bytes != 0) { + std::memcpy(this->Data(), input, n_bytes); + } + } + + void Replace(std::size_t pos, std::size_t n_bytes, char const* input) { + CHECK_LE(pos + n_bytes, size_); + if (n_bytes != 0) { + std::memcpy(this->Data() + pos, input, n_bytes); + } + } + + void Append(std::string_view data) { + auto old_size = size_; + this->Resize(size_ + data.size()); + if (!data.empty()) { + std::memcpy(this->Data() + old_size, data.data(), data.size()); + } + } + + private: + std::vector storage_{}; + std::size_t size_{0}; +}; + /** * @brief Handles collective communication primitives in memory. * @@ -116,10 +166,10 @@ class InMemoryHandler { void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number, std::int32_t rank, HandlerFunctor const& functor); - std::int32_t world_size_{}; /// Number of workers. + std::int32_t world_size_{}; /// Number of workers. std::int64_t received_{}; /// Number of calls received with the current sequence. - std::int64_t sent_{}; /// Number of calls completed with the current sequence. - std::string buffer_{}; /// A shared common buffer. + std::int64_t sent_{}; /// Number of calls completed with the current sequence. + AlignedByteBuffer buffer_{}; /// A shared common buffer. std::map aux_{}; /// A shared auxiliary map. uint64_t sequence_number_{}; /// Call sequence number. mutable std::mutex mutex_; /// Lock.