forked from stepfun-ai/StepMesh
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpostoffice.h
More file actions
319 lines (296 loc) · 10 KB
/
postoffice.h
File metadata and controls
319 lines (296 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
/**
* Copyright (c) 2015 by Contributors
* Modifications Copyright (C) by StepAI Contributors. 2025.
*/
#ifndef PS_INTERNAL_POSTOFFICE_H_
#define PS_INTERNAL_POSTOFFICE_H_
#include <algorithm>
#include <condition_variable>
#include <ctime>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "ps/internal/customer.h"
#include "ps/internal/env.h"
#include "ps/internal/van.h"
#include "ps/range.h"
namespace ps {
/**
 * \brief the center of the system
 *
 * A process holds at most one scheduler instance plus one group of worker
 * instances and one group of server instances; they are reached through the
 * static accessors Get(), GetScheduler(), GetServer() and GetWorker().
 */
class Postoffice {
 public:
  /**
   * \brief return the first valid Postoffice instance in the following order:
   * scheduler, server, worker.
   *
   * Fails the PS_CHECK if the system has not been initialized yet.
   */
  static Postoffice* Get() {
    PS_CHECK(initialized_) << "Please call ps::StartPS() first";
    if (po_scheduler_) return po_scheduler_;
    if (po_server_group_.size()) return po_server_group_.at(0);
    return po_worker_group_.at(0);
  }
  /**
   * \brief return the Postoffice instance for scheduler if it exists.
   * Otherwise, return the one for the server.
   * \param index the instance offset inside the server group.
   * it should be less than DMLC_GROUP_SIZE. Ignored when a scheduler
   * instance exists.
   */
  static Postoffice* GetServer(int index = 0) {
    PS_CHECK(initialized_) << "Please call ps::StartPS() first";
    if (po_scheduler_) return po_scheduler_;
    return po_server_group_.at(index);
  }
  /**
   * \brief return the Postoffice instance for scheduler.
   * \return the stored scheduler pointer; it is returned without a null check
   */
  static Postoffice* GetScheduler() {
    PS_CHECK(initialized_) << "Please call ps::StartPS() first";
    return po_scheduler_;
  }
  /**
   * \brief return the Postoffice instance for worker.
   * \param index the instance offset inside the worker group.
   * it should be less than DMLC_GROUP_SIZE.
   */
  static Postoffice* GetWorker(int index = 0) {
    PS_CHECK(initialized_) << "Please call ps::StartPS() first";
    return po_worker_group_.at(index);
  }
  /** \brief get the van (nullptr until one has been assigned) */
  Van* van() { return van_; }
  /**
   * \brief start the system
   *
   * This function will block until every node is started if do_barrier is
   * true.
   * \param customer_id the customer id
   * \param role the role of the postoffice
   * \param rank the rank. -1 means no preference and the rank will be assigned
   * by the scheduler.
   * \param do_barrier whether to block until every node is started.
   * \param argv0 the program name, used for logging.
   */
  void Start(int customer_id, const Node::Role role, int rank,
             const bool do_barrier, const char* argv0);
  /**
   * \brief terminate the system
   *
   * All nodes should call this function before exiting.
   * \param customer_id the customer id
   * \param do_barrier whether to block until every node is finalized,
   * default true.
   */
  void Finalize(const int customer_id, const bool do_barrier = true);
  /**
   * \brief add a customer to the system. threadsafe
   */
  void AddCustomer(Customer* customer);
  /**
   * \brief remove a customer from the system. threadsafe
   */
  void RemoveCustomer(Customer* customer);
  /**
   * \brief get the customer by id, threadsafe
   * \param app_id the application id
   * \param customer_id the customer id
   * \param timeout timeout in sec
   * \return return nullptr if doesn't exist and timeout
   */
  Customer* GetCustomer(int app_id, int customer_id, int timeout = 0) const;
  /**
   * \brief get the ids of a role group
   *
   * if it is a node group, return the list of postoffice INSTANCE ids in this
   * role group. otherwise, return {node_id}
   *
   * Fails the PS_CHECK when node_id is unknown.
   * NOTE(review): no lock is taken here, so this is only safe while
   * node_ids_ is not being modified concurrently -- confirm against the
   * code that populates it.
   */
  const std::vector<int>& GetNodeIDs(int node_id) const {
    const auto it = node_ids_.find(node_id);
    PS_CHECK(it != node_ids_.cend()) << "node " << node_id << " doesn't exist";
    return it->second;
  }
  /**
   * \brief return the key ranges of all server GROUP nodes
   */
  const std::vector<Range>& GetServerKeyRanges();
  /**
   * \brief the template of a callback
   */
  using Callback = std::function<void()>;
  /**
   * \brief Register a callback to the system which is called after Finalize()
   *
   * The following codes are equal
   * \code {cpp}
   * RegisterExitCallback(cb);
   * Finalize();
   * \endcode
   *
   * \code {cpp}
   * Finalize();
   * cb();
   * \endcode
   *
   * Only one callback is kept; registering again replaces the previous one.
   * \param cb the callback function
   */
  void RegisterExitCallback(const Callback& cb) { exit_callback_ = cb; }
  /**
   * \brief convert a worker group's rank into an instance id with the
   * provided instance offset from that group
   * \param rank the worker group rank
   * \param instance_idx the offset of the instance in the group
   */
  inline int GroupWorkerRankToInstanceID(int rank, int instance_idx) const {
    const int instance_rank = rank * group_size_ + instance_idx;
    return WorkerRankToID(instance_rank);
  }
  /**
   * \brief convert a server group's rank into an instance id with the
   * provided instance offset from that group
   * \param rank the server group rank
   * \param instance_idx the offset of the instance in the group
   */
  inline int GroupServerRankToInstanceID(int rank, int instance_idx) const {
    const int instance_rank = rank * group_size_ + instance_idx;
    return ServerRankToID(instance_rank);
  }
  /**
   * \brief convert an instance id into a server group or worker group rank
   * \param id the instance id
   */
  inline int InstanceIDtoGroupRank(int id) const {
    const int instance_rank = IDtoRank(id);
    return instance_rank / group_size_;
  }
  /**
   * \brief convert from a worker rank into a node id
   *
   * Worker ids are the odd numbers starting at 9.
   * \param rank the worker rank
   */
  static inline int WorkerRankToID(int rank) { return rank * 2 + 9; }
  /**
   * \brief convert from a server rank into a node id
   *
   * Server ids are the even numbers starting at 8.
   * \param rank the server rank
   */
  static inline int ServerRankToID(int rank) { return rank * 2 + 8; }
  /**
   * \brief convert from a node id into a server or worker rank
   *
   * Inverse of WorkerRankToID() / ServerRankToID(); ids below 8 map to
   * rank 0.
   * \param id the node id
   */
  static inline int IDtoRank(int id) {
    // Windows headers may define a max() macro that would break std::max.
#ifdef _MSC_VER
#undef max
#endif
    return std::max((id - 8) / 2, 0);
  }
  /** \brief Returns the size of a worker/server group */
  int group_size() const { return group_size_; }
  /** \brief Returns the number of worker groups */
  int num_workers() const { return num_workers_; }
  /** \brief Returns the number of server groups */
  int num_servers() const { return num_servers_; }
  /** \brief Returns the number of worker instances */
  int num_worker_instances() const { return num_workers_ * group_size_; }
  /** \brief Returns the number of server instances */
  int num_server_instances() const { return num_servers_ * group_size_; }
  /** \brief Returns the rank of this node in its role group
   *
   * Each worker will have a unique rank within [0, NumWorkers()). So are
   * servers. This function is available only after \ref Start has been called.
   */
  int my_rank() const { return IDtoRank(van_->my_node().id); }
  /** \brief Returns the preferred-rank hint stored for this instance */
  int preferred_rank() const { return preferred_rank_; }
  /** \brief Returns non-zero if this node is a worker node */
  int is_worker() const { return is_worker_; }
  /** \brief Returns non-zero if this node is a server node. */
  int is_server() const { return is_server_; }
  /** \brief Returns non-zero if this node is a scheduler node. */
  int is_scheduler() const { return is_scheduler_; }
  /**
   * \brief Returns the role name: "worker", "scheduler" or "server".
   *
   * Returns an empty string when no role flag is set. If several flags are
   * set, the last matching assignment below wins.
   */
  std::string role_str() const {
    std::string str;
    if (is_worker_) str = "worker";
    if (is_scheduler_) str = "scheduler";
    if (is_server_) str = "server";
    return str;
  }
  /** \brief Returns the verbose level. */
  int verbose() const { return verbose_; }
  /** \brief Return whether this node is a recovery node */
  bool is_recovery() const { return van_->my_node().is_recovery; }
  /**
   * \brief barrier
   * \param customer_id the id of the customer
   * \param node_group the barrier group id
   */
  void Barrier(int customer_id, int node_group);
  /**
   * \brief barrier for all postoffice instances or groups
   * \param customer_id the id of the customer
   * \param node_group the barrier group id
   * \param instance_barrier whether it's for all postoffice instances or groups
   */
  void DoBarrier(int customer_id, int node_group, bool instance_barrier);
  /**
   * \brief process a control message, called by van
   * \param recv the received message
   */
  void Manage(const Message& recv);
  /**
   * \brief update the heartbeat record map
   *
   * Guarded by heartbeat_mu_.
   * \param node_id the \ref Node id
   * \param t the last received heartbeat time
   */
  void UpdateHeartbeat(int node_id, time_t t) {
    std::lock_guard<std::mutex> lk(heartbeat_mu_);
    heartbeats_[node_id] = t;
  }
  /**
   * \brief get node ids that haven't reported heartbeats for over t seconds
   * \param t timeout in sec
   */
  std::vector<int> GetDeadNodes(int t = 60);
  /** \brief initialize all instances in the group for this role */
  static void Init(ps::Node::Role role);

 private:
  /**
   * \param instance_idx the offset of the instance inside the group.
   * It should be less than DMLC_GROUP_SIZE
   */
  explicit Postoffice(int instance_idx);
  /** \brief releases the van owned by this instance */
  ~Postoffice() { delete van_; }
  static Postoffice* po_scheduler_;
  static std::mutex init_mu_;
  // the group of postoffices for workers
  static std::vector<Postoffice*> po_worker_group_;
  // the group of postoffices for servers
  static std::vector<Postoffice*> po_server_group_;
  // initialization
  static bool initialized_;
  void InitEnvironment();
  // Owned (deleted in the destructor). Starts as nullptr so that destroying
  // an instance that never received a van is well-defined.
  Van* van_ = nullptr;
  mutable std::mutex mu_;
  // app_id -> (customer_id -> customer pointer)
  std::unordered_map<int, std::unordered_map<int, Customer*>> customers_;
  std::unordered_map<int, std::vector<int>> node_ids_;
  std::mutex server_key_ranges_mu_;
  std::vector<Range> server_key_ranges_;
  // Role flags and topology sizes get deterministic defaults here; the real
  // values are expected to be filled in by code outside this header.
  bool is_worker_ = false;
  bool is_server_ = false;
  bool is_scheduler_ = false;
  int num_servers_ = 0;
  int num_workers_ = 0;
  // 1 keeps the rank conversions (which divide by group_size_) well-defined
  // until the configured value is stored.
  int group_size_ = 1;
  // a hint for preferred rank; -1 follows the "no preference" convention
  // documented for the rank argument of Start()
  int preferred_rank_ = -1;
  std::unordered_map<int, std::unordered_map<int, bool>> barrier_done_;
  int verbose_ = 0;
  std::mutex barrier_mu_;
  std::condition_variable barrier_cond_;
  std::mutex heartbeat_mu_;
  std::mutex start_mu_;
  int init_stage_ = 0;
  int instance_idx_ = 0;
  // node id -> time of the last heartbeat; guarded by heartbeat_mu_
  std::unordered_map<int, time_t> heartbeats_;
  Callback exit_callback_;
  /** \brief Holding a shared_ptr to prevent it from being destructed too early
   */
  std::shared_ptr<Environment> env_ref_;
  time_t start_time_ = 0;
  bool started_ = false;
  DISALLOW_COPY_AND_ASSIGN(Postoffice);
};
} // namespace ps
#endif // PS_INTERNAL_POSTOFFICE_H_