Skip to content

Commit bea733d

Browse files
authored
Merge pull request #12 from fengidri/main
Adjust the barrier api, change default mode to all instance barrier; Add barrier in server side in bmk.
2 parents 4e76502 + 0e8ad0b commit bea733d

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

fserver/csrc/public.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ void wait(int handler) {
148148
fworker_->Wait(handler);
149149
}
150150

151-
void barrier(bool include_server, bool include_worker) {
151+
void barrier(bool include_server, bool include_worker, bool instrance_barrier=true) {
152152
int node_group = 0;
153153
if (include_server) {
154154
node_group += ps::kServerGroup;
@@ -158,9 +158,9 @@ void barrier(bool include_server, bool include_worker) {
158158
}
159159

160160
if (role_ == Node::WORKER && include_worker) {
161-
ps::Postoffice::GetWorker(instance_id_)->Barrier(0, node_group);
161+
ps::Postoffice::GetWorker(instance_id_)->DoBarrier(0, node_group, instrance_barrier);
162162
} else if (role_ == Node::SERVER && include_server) {
163-
ps::Postoffice::GetServer(instance_id_)->Barrier(0, node_group);
163+
ps::Postoffice::GetServer(instance_id_)->DoBarrier(0, node_group, instrance_barrier);
164164
} else {
165165
ps::Postoffice::Get()->Barrier(0, node_group);
166166
}
@@ -261,7 +261,11 @@ void pybind_public(py::module &m){
261261
// fetch_trace needs gil_scoped_release
262262
m.def("fetch_trace", &fetch_trace, py::call_guard<py::gil_scoped_release>());
263263
m.def("get_all_handlers", &get_all_handlers, py::call_guard<py::none>());
264-
m.def("barrier", &barrier, py::call_guard<py::none>());
264+
m.def("barrier", &barrier,
265+
py::arg("include_server"),
266+
py::arg("include_client"),
267+
py::arg("instance_barrier") = true,
268+
py::call_guard<py::none>());
265269
m.def("get_nanosecond", &get_nanosecond, py::call_guard<py::none>());
266270
}
267271

include/ps/internal/postoffice.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,13 @@ class Postoffice {
236236
* \param node_id the barrier group id
237237
*/
238238
void Barrier(int customer_id, int node_group);
239+
/**
240+
* \brief barrier for all postoffice instances or groups
241+
* \param customer_id the id of the customer
242+
* \param node_id the barrier group id
243+
* \param instance_barrier whether it's for all postoffice instances or groups
244+
*/
245+
void DoBarrier(int customer_id, int node_group, bool instance_barrier);
239246
/**
240247
* \brief process a control message, called by van
241248
* \param the received message
@@ -267,14 +274,6 @@ class Postoffice {
267274
explicit Postoffice(int instance_idx);
268275
~Postoffice() { delete van_; }
269276

270-
/**
271-
* \brief barrier for all postoffice instances or groups
272-
* \param customer_id the id of the customer
273-
* \param node_id the barrier group id
274-
* \param instance_barrier whether it's for all postoffice instances or groups
275-
*/
276-
void DoBarrier(int customer_id, int node_group, bool instance_barrier);
277-
278277
static Postoffice* po_scheduler_;
279278
static std::mutex init_mu_;
280279
// the group of postoffices for workers

tests/benchmark/bmk_comm_latency_multiserver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _str_formater(data):
171171
if is_worker:
172172
th = threading.Thread(target=print_thread)
173173
th.start()
174-
f.barrier(True, True)
174+
f.barrier(False, True)
175175
q = Queue()
176176
time_list = []
177177
net_cost_list = [[] for _ in range(bsz + 1 + bsz)]
@@ -202,7 +202,7 @@ def worker():
202202
elif is_server:
203203
ret_buffer = torch.rand([65535, dim], dtype=torch.bfloat16, device='cuda')
204204
count = 0
205-
205+
f.barrier(True, False)
206206
def server():
207207
global count
208208
iter_count = 0

0 commit comments

Comments
 (0)