@@ -100,7 +100,7 @@ void respond(std::vector<torch::Tensor>& tensors,
100100 PS_CHECK_EQ (tensors.size (), reqmeta.pull_tensors .size ());
101101 std::vector<KeyTensor> result;
102102 for (size_t i = 0 ; i < tensors.size (); ++i) {
103- result.push_back ({reqmeta.pull_tensors [i].key , std::move (tensors[i])});
103+ result.push_back ({reqmeta.pull_tensors [i].key , std::move (tensors[i]. detach () )});
104104 }
105105 fserver_->Response (reqmeta, result, need_event);
106106}
@@ -130,19 +130,19 @@ int push_pull(std::vector<torch::Tensor>& push_tensors,
130130 auto pull_batch = KeyTensorBatch (pull_tensors.size ());
131131 for (size_t i = 0 ; i < push_tensors.size (); i++) {
132132 push_batch[i] = KeyTensor{
133- static_cast <uint64_t >(push_keys[i]), std::move (push_tensors[i])
133+ static_cast <uint64_t >(push_keys[i]), std::move (push_tensors[i]. detach () )
134134 };
135135 }
136136 for (size_t i = 0 ; i < pull_tensors.size (); i++) {
137137 pull_batch[i] = KeyTensor{
138- static_cast <uint64_t >(pull_keys[i]), std::move (pull_tensors[i])
138+ static_cast <uint64_t >(pull_keys[i]), std::move (pull_tensors[i]. detach () )
139139 };
140140 }
141141 return fworker_->ZBatchPushPull (push_batch, pull_batch);
142142}
143143
144- void wait (int handler) {
145- fworker_->Wait (handler);
144+ void wait (int handler, uint64_t timeout_ms = 10000 ) {
145+ fworker_->Wait (handler, timeout_ms );
146146}
147147
148148void barrier (bool include_server, bool include_worker, bool instrance_barrier=true ) {
@@ -163,26 +163,29 @@ void barrier(bool include_server, bool include_worker, bool instrance_barrier=tr
163163 }
164164}
165165
166+
166167void init () {
168+
167169 std::string role_str = ps::GetEnv("DMLC_ROLE", "server");
170+ int offset = 0;
168171 role_ = ps::GetRole(role_str);
169172
170173 ps::Environment::Get()->find("STEPMESH_GPU", &gpu_, gpu_);
171174 ps::Environment::Get()->find("DMLC_GROUP_SIZE", &group_size_, group_size_);
172175 ps::Environment::Get()->find("DMLC_NODE_RANK", &node_rank_, node_rank_);
173- ps::Environment::Get()->find("DMLC_INSTANCE_ID", &instance_id_, gpu_);
176+ ps::Environment::Get()->find("DMLC_RANK_OFFSET", &offset, offset);
177+ ps::Environment::Get()->find("DMLC_INSTANCE_ID", &instance_id_, gpu_ + offset);
174178 ps::Environment::Get()->find("DMLC_NUM_WORKER", &num_worker_, num_worker_);
175-
179+
176180 worker_mask_ = (1 << num_worker_) - 1 ;
177181 q_.resize (num_worker_);
178182 q_signal_.store(0);
179-
180- ps::StartPS (0 , role_, group_size_ * node_rank_ + gpu_, true );
183+ ps::StartPS (0 , role_, group_size_ * node_rank_ + gpu_ + offset, true );
181184 if (role_ == Node::WORKER) {
182- fworker_ = new AFTensorWorker (instance_id_);
185+ fworker_ = new AFTensorWorker (instance_id_ );
183186 barrier (true , true );
184187 } else if (role_ == Node::SERVER) {
185- fserver_ = new AFTensorServer (instance_id_);
188+ fserver_ = new AFTensorServer (instance_id_ );
186189 fserver_->SetRequestHandle (RequestHandler);
187190 ps::RegisterExitCallback ([]() { delete fserver_; });
188191 barrier (true , true );
@@ -242,8 +245,16 @@ void pybind_public(py::module &m){
242245 py::call_guard<py::gil_scoped_release>());
243246
244247 // APIs for Attention Instances
245- m.def (" push_pull" , &push_pull, py::call_guard<py::none>());
246- m.def (" wait" , &wait, py::call_guard<py::none>());
248+ m.def (" push_pull" , &push_pull,
249+ py::arg (" push_tensors" ),
250+ py::arg (" push_keys" ),
251+ py::arg (" pull_tensors" ),
252+ py::arg (" pull_keys" ),
253+ py::call_guard<py::none>());
254+ m.def (" wait" , &wait,
255+ py::arg (" handler" ),
256+ py::arg (" timeout_ms" ) = 10000 ,
257+ py::call_guard<py::none>());
247258
248259 // APIs for FFN Instances
249260 m.def (" get_batch" , &get_batch, py::call_guard<py::none>());