Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions src/collective/in_memory_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include <algorithm>
#include <functional>
#include <stdexcept>

#include "comm.h"

namespace xgboost::collective {
Expand All @@ -18,14 +20,14 @@ class AllgatherFunctor {
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
: world_size_{world_size}, rank_{rank} {}

void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (buffer->empty()) {
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
if (buffer->Empty()) {
// Resize the buffer if this is the first request.
buffer->resize(bytes * world_size_);
buffer->Resize(bytes * world_size_);
}

// Splice the input into the common buffer.
buffer->replace(rank_ * bytes, bytes, input, bytes);
buffer->Replace(rank_ * bytes, bytes, input);
}

private:
Expand All @@ -44,11 +46,11 @@ class AllgatherVFunctor {
std::map<std::size_t, std::string_view>* data)
: world_size_{world_size}, rank_{rank}, data_{data} {}

void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
data_->emplace(rank_, std::string_view{input, bytes});
if (data_->size() == static_cast<std::size_t>(world_size_)) {
for (auto const& kv : *data_) {
buffer->append(kv.second);
buffer->Append(kv.second);
}
data_->clear();
}
Expand All @@ -70,14 +72,16 @@ class AllreduceFunctor {
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
: data_type_{dataType}, operation_{operation} {}

void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (buffer->empty()) {
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
if (buffer->Empty()) {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
buffer->Assign(input, bytes);
} else {
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
CHECK_EQ(bytes % n_bytes_type, 0) << "Input size is not a multiple of its element size.";
CHECK_EQ(buffer->Size(), bytes) << "Input size differs across allreduce calls.";
// Apply the reduce_operation to the input and the buffer.
Accumulate(input, bytes / n_bytes_type, &buffer->front());
Accumulate(input, bytes, buffer);
}
}

Expand Down Expand Up @@ -128,39 +132,41 @@ class AllreduceFunctor {
}
}

void Accumulate(char const* input, std::size_t size, char* buffer) const {
void Accumulate(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
using Type = ArrayInterfaceHandler::Type;
auto data = buffer->Data();
auto size = bytes / DispatchDType(data_type_, [](auto t) { return sizeof(t); });
switch (data_type_) {
case Type::kI1:
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
Accumulate(reinterpret_cast<std::int8_t*>(data),
reinterpret_cast<std::int8_t const*>(input), size, operation_);
break;
case Type::kU1:
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
Accumulate(reinterpret_cast<std::uint8_t*>(data),
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
break;
case Type::kI4:
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
Accumulate(reinterpret_cast<std::int32_t*>(data),
reinterpret_cast<std::int32_t const*>(input), size, operation_);
break;
case Type::kU4:
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
Accumulate(reinterpret_cast<std::uint32_t*>(data),
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
break;
case Type::kI8:
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
Accumulate(reinterpret_cast<std::int64_t*>(data),
reinterpret_cast<std::int64_t const*>(input), size, operation_);
break;
case Type::kU8:
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
Accumulate(reinterpret_cast<std::uint64_t*>(data),
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
break;
case Type::kF4:
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
Accumulate(reinterpret_cast<float*>(data), reinterpret_cast<float const*>(input), size,
operation_);
break;
case Type::kF8:
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
Accumulate(reinterpret_cast<double*>(data), reinterpret_cast<double const*>(input), size,
operation_);
break;
default:
Expand All @@ -182,10 +188,10 @@ class BroadcastFunctor {

BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}

void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
if (rank_ == root_) {
// Copy the input if this is the root.
buffer->assign(input, bytes);
buffer->Assign(input, bytes);
}
}

Expand Down Expand Up @@ -246,9 +252,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
HandlerFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
if (input != output->data()) {
output->assign(input, bytes);
}
output->assign(input, bytes);
return;
}

Expand All @@ -263,7 +267,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*

if (received_ == world_size_) {
LOG(DEBUG) << functor.name << " rank " << rank << ": all requests received";
output->assign(buffer_);
output->assign(buffer_.Data(), buffer_.Size());
sent_++;
lock.unlock();
cv_.notify_all();
Expand All @@ -274,14 +278,14 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
cv_.wait(lock, [this] { return received_ == world_size_; });

LOG(DEBUG) << functor.name << " rank " << rank << ": sending reply";
output->assign(buffer_);
output->assign(buffer_.Data(), buffer_.Size());
sent_++;

if (sent_ == world_size_) {
LOG(DEBUG) << functor.name << " rank " << rank << ": all replies sent";
sent_ = 0;
received_ = 0;
buffer_.clear();
buffer_.Clear();
sequence_number_++;
lock.unlock();
cv_.notify_all();
Expand Down
56 changes: 53 additions & 3 deletions src/collective/in_memory_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,63 @@
*/
#pragma once
#include <condition_variable>
#include <cstddef>
#include <cstring>
#include <map>
#include <string>
#include <vector>

#include "../data/array_interface.h"
#include "comm.h"

namespace xgboost::collective {
class AlignedByteBuffer {
using StorageT = std::max_align_t;

public:
[[nodiscard]] bool Empty() const { return size_ == 0; }
[[nodiscard]] std::size_t Size() const { return size_; }

[[nodiscard]] char* Data() { return reinterpret_cast<char*>(storage_.data()); }
[[nodiscard]] char const* Data() const { return reinterpret_cast<char const*>(storage_.data()); }

void Clear() {
storage_.clear();
size_ = 0;
}

void Resize(std::size_t n_bytes) {
storage_.resize((n_bytes + sizeof(StorageT) - 1) / sizeof(StorageT));
size_ = n_bytes;
}

void Assign(char const* input, std::size_t n_bytes) {
this->Resize(n_bytes);
if (n_bytes != 0) {
std::memcpy(this->Data(), input, n_bytes);
}
}

void Replace(std::size_t pos, std::size_t n_bytes, char const* input) {
CHECK_LE(pos + n_bytes, size_);
if (n_bytes != 0) {
std::memcpy(this->Data() + pos, input, n_bytes);
}
}

void Append(std::string_view data) {
auto old_size = size_;
this->Resize(size_ + data.size());
if (!data.empty()) {
std::memcpy(this->Data() + old_size, data.data(), data.size());
}
}

private:
std::vector<StorageT> storage_{};
std::size_t size_{0};
};

/**
* @brief Handles collective communication primitives in memory.
*
Expand Down Expand Up @@ -116,10 +166,10 @@ class InMemoryHandler {
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
std::int32_t rank, HandlerFunctor const& functor);

std::int32_t world_size_{}; /// Number of workers.
std::int32_t world_size_{}; /// Number of workers.
std::int64_t received_{}; /// Number of calls received with the current sequence.
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
std::string buffer_{}; /// A shared common buffer.
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
AlignedByteBuffer buffer_{}; /// A shared common buffer.
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
uint64_t sequence_number_{}; /// Call sequence number.
mutable std::mutex mutex_; /// Lock.
Expand Down
Loading