forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathin_memory_handler.cc
More file actions
294 lines (254 loc) · 9.73 KB
/
in_memory_handler.cc
File metadata and controls
294 lines (254 loc) · 9.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "in_memory_handler.h"
#include <algorithm>
#include <functional>
#include <stdexcept>
#include "comm.h"
namespace xgboost::collective {
/**
* @brief Functor for allgather.
*/
class AllgatherFunctor {
public:
std::string const name{"Allgather"};
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
: world_size_{world_size}, rank_{rank} {}
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
if (buffer->empty()) {
// Resize the buffer if this is the first request.
buffer->resize(bytes * world_size_);
}
// Splice the input into the common buffer.
buffer->replace(rank_ * bytes, bytes, input);
}
private:
std::int32_t world_size_;
std::int32_t rank_;
};
/**
 * @brief Functor for variable-length allgather.
 *
 * Inputs are staged as views in an externally owned map keyed by rank. Once the map holds one
 * entry per rank, the staged chunks are appended to the buffer in ascending rank order and the
 * map is emptied.
 */
class AllgatherVFunctor {
 public:
  std::string const name{"AllgatherV"};

  AllgatherVFunctor(std::int32_t world_size, std::int32_t rank,
                    std::map<std::size_t, std::string_view>* data)
      : world_size_{world_size}, rank_{rank}, data_{data} {}

  /**
   * @brief Stage this rank's input and, when it completes the set, build the output buffer.
   *
   * @param input  Pointer to this rank's data; only a view of it is stored in the map.
   * @param bytes  Size of the input in bytes.
   * @param buffer Buffer receiving the concatenation once all ranks have contributed.
   */
  void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
    data_->emplace(rank_, std::string_view{input, bytes});

    auto const expected_entries = static_cast<std::size_t>(world_size_);
    if (data_->size() != expected_entries) {
      // Other ranks have not contributed yet; nothing to assemble.
      return;
    }

    // std::map iterates in key order, i.e. by rank.
    for (auto const& entry : *data_) {
      buffer->append(entry.second);
    }
    data_->clear();
  }

 private:
  std::int32_t world_size_;
  std::int32_t rank_;
  std::map<std::size_t, std::string_view>* data_;
};
/**
* @brief Functor for allreduce.
*/
class AllreduceFunctor {
public:
std::string const name{"Allreduce"};
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
: data_type_{dataType}, operation_{operation} {}
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
if (buffer->empty()) {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
} else {
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
CHECK_EQ(bytes % n_bytes_type, 0) << "Input size is not a multiple of its element size.";
CHECK_EQ(buffer->size(), bytes) << "Input size differs across allreduce calls.";
// Apply the reduce_operation to the input and the buffer.
Accumulate(input, bytes, buffer);
}
}
private:
template <class T, std::enable_if_t<std::is_integral_v<T>>* = nullptr>
void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Op::kBitwiseAND:
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
break;
case Op::kBitwiseOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
break;
case Op::kBitwiseXOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
break;
default:
throw std::invalid_argument("Invalid reduce operation");
}
}
template <class T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr>
void AccumulateBitwise(T*, T const*, std::size_t, Op) const {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}
template <class T>
void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Op::kMax:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::max(a, b); });
break;
case Op::kMin:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::min(a, b); });
break;
case Op::kSum:
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
break;
case Op::kBitwiseAND:
case Op::kBitwiseOR:
case Op::kBitwiseXOR:
AccumulateBitwise(buffer, input, size, reduce_operation);
break;
default:
throw std::invalid_argument("Invalid reduce operation");
}
}
void Accumulate(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
using Type = ArrayInterfaceHandler::Type;
auto data = buffer->data();
auto size = bytes / DispatchDType(data_type_, [](auto t) { return sizeof(t); });
switch (data_type_) {
case Type::kI1:
Accumulate(reinterpret_cast<std::int8_t*>(data),
reinterpret_cast<std::int8_t const*>(input), size, operation_);
break;
case Type::kU1:
Accumulate(reinterpret_cast<std::uint8_t*>(data),
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
break;
case Type::kI4:
Accumulate(reinterpret_cast<std::int32_t*>(data),
reinterpret_cast<std::int32_t const*>(input), size, operation_);
break;
case Type::kU4:
Accumulate(reinterpret_cast<std::uint32_t*>(data),
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
break;
case Type::kI8:
Accumulate(reinterpret_cast<std::int64_t*>(data),
reinterpret_cast<std::int64_t const*>(input), size, operation_);
break;
case Type::kU8:
Accumulate(reinterpret_cast<std::uint64_t*>(data),
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
break;
case Type::kF4:
Accumulate(reinterpret_cast<float*>(data), reinterpret_cast<float const*>(input), size,
operation_);
break;
case Type::kF8:
Accumulate(reinterpret_cast<double*>(data), reinterpret_cast<double const*>(input), size,
operation_);
break;
default:
throw std::invalid_argument("Invalid data type");
}
}
private:
ArrayInterfaceHandler::Type data_type_;
Op operation_;
};
/**
 * @brief Functor for broadcast.
 *
 * Only the request coming from the root rank touches the buffer; requests from every other rank
 * leave it unchanged.
 */
class BroadcastFunctor {
 public:
  std::string const name{"Broadcast"};

  /**
   * @param rank Rank issuing this request.
   * @param root Rank whose input is broadcast.
   */
  BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}

  /**
   * @brief Store the root's input in the buffer.
   *
   * @param input  Pointer to this rank's data.
   * @param bytes  Size of the input in bytes.
   * @param buffer Buffer that receives the root's data.
   */
  void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
    if (rank_ != root_) {
      return;
    }
    // Copy the input if this is the root.
    buffer->assign(input, bytes);
  }

 private:
  std::int32_t rank_;
  std::int32_t root_;
};
/**
 * @brief Register the calling worker and block until `world_size` workers have registered.
 *
 * Each call increments `world_size_` by one and then waits on the condition variable until the
 * counter equals the requested world size.
 *
 * @param world_size Total number of workers expected to call Init.
 */
void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) {
  std::unique_lock<std::mutex> lock(mutex_);
  // Other workers running Init concurrently increment world_size_ under this mutex, so the
  // precondition has to be evaluated while holding it as well.
  CHECK(world_size_ < world_size) << "In memory handler already initialized.";

  world_size_++;
  cv_.wait(lock, [this, world_size] { return world_size_ == world_size; });
  // Release the lock before notifying so woken workers can acquire it immediately.
  lock.unlock();
  cv_.notify_all();
}
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) {
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
received_++;
cv_.wait(lock, [this] { return received_ == world_size_; });
received_ = 0;
world_size_ = 0;
sequence_number_ = 0;
lock.unlock();
cv_.notify_all();
}
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
}
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
}
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::int32_t rank,
ArrayInterfaceHandler::Type data_type, Op op) {
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
}
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::int32_t rank, std::int32_t root) {
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
}
/**
 * @brief Run one collective operation across all workers through the common buffer.
 *
 * With a single worker the input is copied straight to the output. Otherwise each caller:
 *  1. waits until the handler's sequence number equals `sequence_number`;
 *  2. applies `functor` to fold its input into `buffer_` and counts itself as received;
 *  3. waits until the received count equals the world size (the caller that completes the
 *     count skips the wait and wakes the others);
 *  4. copies `buffer_` into `output` and counts itself as sent.
 * The caller whose reply brings the sent count up to the world size resets both counters, clears
 * the buffer, advances the sequence number and wakes all waiters.
 *
 * @param input           Pointer to the caller's input data.
 * @param bytes           Size of the input in bytes.
 * @param output          Receives a copy of the buffer (or of the input when there is 1 client).
 * @param sequence_number Sequence number of this operation.
 * @param rank            Rank of the caller; used in this function only for log messages.
 * @param functor         Callable invoked as `functor(input, bytes, &buffer_)`; its `name`
 *                        member is used in log messages.
 */
template <class HandlerFunctor>
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::int32_t rank,
HandlerFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
output->assign(input, bytes);
return;
}
// Everything below reads and writes the handler state while holding mutex_.
std::unique_lock<std::mutex> lock(mutex_);
LOG(DEBUG) << functor.name << " rank " << rank << ": waiting for current sequence number";
// Block until the handler has reached the operation this request belongs to.
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
LOG(DEBUG) << functor.name << " rank " << rank << ": handling request";
// Fold this caller's input into the common buffer.
functor(input, bytes, &buffer_);
received_++;
if (received_ == world_size_) {
// This caller completed the set: reply immediately, then wake the callers blocked in the
// wait below so they can send their replies.
LOG(DEBUG) << functor.name << " rank " << rank << ": all requests received";
output->assign(buffer_.data(), buffer_.size());
sent_++;
// NOTE(review): no reset happens on this path; presumably with world_size_ > 1 this caller is
// never the last to reply, so the reset is left to the branch below -- confirm.
lock.unlock();
cv_.notify_all();
return;
}
LOG(DEBUG) << functor.name << " rank " << rank << ": waiting for all clients";
// Block until every caller has contributed to the buffer.
cv_.wait(lock, [this] { return received_ == world_size_; });
LOG(DEBUG) << functor.name << " rank " << rank << ": sending reply";
output->assign(buffer_.data(), buffer_.size());
sent_++;
if (sent_ == world_size_) {
// Last reply: reset the counters and the buffer for the next operation, advance the sequence
// number, and wake callers waiting for that sequence number.
LOG(DEBUG) << functor.name << " rank " << rank << ": all replies sent";
sent_ = 0;
received_ = 0;
buffer_.clear();
sequence_number_++;
lock.unlock();
cv_.notify_all();
}
}
} // namespace xgboost::collective