Skip to content

Commit 5ed45d4

Browse files
authored
Merge pull request #37 from niehao100/public
Support For VLLM AFD Connector Beta.
2 parents 1f0fe81 + 22666d1 commit 5ed45d4

29 files changed

+1188
-72
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ message("MY PYTHON_EXECUTABLE ${Python_EXECUTABLE}")
1414
message("MY PYTORCH_CMAKE_PREFIX_PATH ${PYTORCH_CMAKE_PREFIX_PATH}")
1515

1616
list(APPEND CMAKE_PREFIX_PATH "${PYTORCH_CMAKE_PREFIX_PATH}/Torch")
17+
1718
find_package(Torch REQUIRED CONFIG)
1819
message("MY TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS}")
1920
message("MY CUDA_INCLUDE_DIRS ${CUDA_INCLUDE_DIRS}")
@@ -27,7 +28,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
2728
if("$ENV{USE_CUDA}" STREQUAL "0")
2829
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_ZMQ -DDMLC_USE_RDMA -DSTEPMESH_USE_GDR")
2930
else()
30-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_ZMQ -DDMLC_USE_CUDA -DSTEPMESH_USE_GDR -DDMLC_USE_RDMA")
31+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_ZMQ -DDMLC_USE_CUDA -DSTEPMESH_USE_GDR -DDMLC_USE_RDMA -DSTEPMESH_ENABLE_TRACE")
3132
endif()
3233

3334
link_directories("${PROJECT_SOURCE_DIR}/deps/lib")

fserver/csrc/public.hpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ void respond(std::vector<torch::Tensor>& tensors,
100100
PS_CHECK_EQ(tensors.size(), reqmeta.pull_tensors.size());
101101
std::vector<KeyTensor> result;
102102
for (size_t i = 0; i < tensors.size(); ++i) {
103-
result.push_back({reqmeta.pull_tensors[i].key, std::move(tensors[i])});
103+
result.push_back({reqmeta.pull_tensors[i].key, std::move(tensors[i].detach())});
104104
}
105105
fserver_->Response(reqmeta, result, need_event);
106106
}
@@ -130,19 +130,19 @@ int push_pull(std::vector<torch::Tensor>& push_tensors,
130130
auto pull_batch = KeyTensorBatch(pull_tensors.size());
131131
for (size_t i = 0; i < push_tensors.size(); i++) {
132132
push_batch[i] = KeyTensor{
133-
static_cast<uint64_t>(push_keys[i]), std::move(push_tensors[i])
133+
static_cast<uint64_t>(push_keys[i]), std::move(push_tensors[i].detach())
134134
};
135135
}
136136
for (size_t i = 0; i < pull_tensors.size(); i++) {
137137
pull_batch[i] = KeyTensor{
138-
static_cast<uint64_t>(pull_keys[i]), std::move(pull_tensors[i])
138+
static_cast<uint64_t>(pull_keys[i]), std::move(pull_tensors[i].detach())
139139
};
140140
}
141141
return fworker_->ZBatchPushPull(push_batch, pull_batch);
142142
}
143143

144-
void wait(int handler) {
145-
fworker_->Wait(handler);
144+
void wait(int handler, uint64_t timeout_ms = 1000) {
145+
fworker_->Wait(handler, timeout_ms);
146146
}
147147

148148
void barrier(bool include_server, bool include_worker, bool instrance_barrier=true) {
@@ -163,26 +163,29 @@ void barrier(bool include_server, bool include_worker, bool instrance_barrier=tr
163163
}
164164
}
165165

166+
166167
void init() {
168+
167169
std::string role_str = ps::GetEnv("DMLC_ROLE", "server");
170+
int offset = 0;
168171
role_ = ps::GetRole(role_str);
169172

170173
ps::Environment::Get()->find("STEPMESH_GPU", &gpu_, gpu_);
171174
ps::Environment::Get()->find("DMLC_GROUP_SIZE", &group_size_, group_size_);
172175
ps::Environment::Get()->find("DMLC_NODE_RANK", &node_rank_, node_rank_);
173-
ps::Environment::Get()->find("DMLC_INSTANCE_ID", &instance_id_, gpu_);
176+
ps::Environment::Get()->find("DMLC_RANK_OFFSET", &offset, offset);
177+
ps::Environment::Get()->find("DMLC_INSTANCE_ID", &instance_id_, gpu_ + offset);
174178
ps::Environment::Get()->find("DMLC_NUM_WORKER", &num_worker_, num_worker_);
175-
179+
176180
worker_mask_ = (1 << num_worker_) - 1;
177181
q_.resize(num_worker_);
178182
q_signal_.store(0);;
179-
180-
ps::StartPS(0, role_, group_size_ * node_rank_ + gpu_, true);
183+
ps::StartPS(0, role_, group_size_ * node_rank_ + gpu_ + offset, true);
181184
if (role_ == Node::WORKER) {
182-
fworker_ = new AFTensorWorker(instance_id_);
185+
fworker_ = new AFTensorWorker(instance_id_ );
183186
barrier(true, true);
184187
} else if (role_ == Node::SERVER) {
185-
fserver_ = new AFTensorServer(instance_id_);
188+
fserver_ = new AFTensorServer(instance_id_ );
186189
fserver_->SetRequestHandle(RequestHandler);
187190
ps::RegisterExitCallback([]() { delete fserver_; });
188191
barrier(true, true);
@@ -242,8 +245,16 @@ void pybind_public(py::module &m){
242245
py::call_guard<py::gil_scoped_release>());
243246

244247
// APIs for Attention Instances
245-
m.def("push_pull", &push_pull, py::call_guard<py::none>());
246-
m.def("wait", &wait, py::call_guard<py::none>());
248+
m.def("push_pull", &push_pull,
249+
py::arg("push_tensors"),
250+
py::arg("push_keys"),
251+
py::arg("pull_tensors"),
252+
py::arg("pull_keys"),
253+
py::call_guard<py::none>());
254+
m.def("wait", &wait,
255+
py::arg("handler"),
256+
py::arg("timeout_ms") = 10000,
257+
py::call_guard<py::none>());
247258

248259
// APIs for FFN Instances
249260
m.def("get_batch", &get_batch, py::call_guard<py::none>());

fserver/csrc/util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "ps/ps.h"
2020

2121
#ifndef UTIL_H_
22+
typedef std::tuple<uint64_t, std::vector<torch::Tensor>, std::vector<uint64_t>>
23+
ServerDataBatch;
2224
#define UTIL_H_
2325
typedef std::tuple<uint64_t, std::vector<torch::Tensor>, std::vector<uint64_t>>
2426
ServerDataBatch;

include/dmlc/logging.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class LogMessage {
190190
#endif
191191
{
192192
log_stream_ << "[" << pretty_date_.HumanDate() << "] "
193-
<< getenv("DMLC_ROLE") << " " << file << ":" << line << ": ";
193+
<< getenv("DMLC_ROLE") << " " << getenv("STEPMESH_GPU") << " " << file << ":" << std::dec << line << ": ";
194194
}
195195
~LogMessage() { log_stream_ << "\n"; }
196196
std::ostream &stream() { return log_stream_; }

include/ps/af_tensor_app.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ class AFTensorWorker {
133133
req.event = GetEvent();
134134
req.event->Record();
135135

136+
PS_VLOG(3) << "ts" << start_ts << " pushpull_queue_ push "
137+
<< pushpull_queue_.Size();
136138
pushpull_queue_.Push(std::move(req));
137139

138140
// std::unique_lock<std::mutex> timestamp_lock(timestamp_mu_);
@@ -144,13 +146,13 @@ class AFTensorWorker {
144146
* \brief Wait for the operation to complete
145147
* @param timestamp return by push, pull or push-pull operations
146148
*/
147-
void Wait(int timestamp) {
148-
kv_.Wait(timestamp);
149+
void Wait(int timestamp, uint64_t timeout_ms = 10000) {
150+
kv_.Wait(timestamp, timeout_ms);
149151
// std::unique_lock<std::mutex> lock(timestamp_mu_);
150152
auto itr = batch_timestamps_.find(timestamp);
151153
if (itr != batch_timestamps_.end()) {
152154
for (auto ts : itr->second) {
153-
kv_.Wait(ts);
155+
kv_.Wait(ts, timeout_ms);
154156
}
155157
batch_timestamps_.erase(itr);
156158
}
@@ -199,15 +201,15 @@ class AFTensorWorker {
199201
}
200202

201203
void PushPullWorker() {
202-
BindCpuCore(4, 1);
204+
BindCpuCore(3, 1);
203205
Backend::Get()->SetDevice(gpu_);
204-
while (!pushpull_stop_.load()) {
206+
while (true) {
207+
PS_VLOG(4) << "pushpull_queue_ Loop wait ";
205208
AFTensorRequest req;
206-
pushpull_queue_.WaitAndPop(&req);
207-
208209
if (pushpull_stop_.load()) {
209210
break;
210211
}
212+
pushpull_queue_.WaitAndPop(&req, true);
211213

212214
if (req.event != nullptr) {
213215
req.event->Sync();
@@ -216,6 +218,8 @@ class AFTensorWorker {
216218
}
217219
ZBatchPushPull_(req.push, req.push_timestamps, req.pull,
218220
req.pull_timestamps);
221+
PS_VLOG(4) << "pushpull_queue_ Loop done " << req.push_timestamps[0]
222+
<< " " << req.pull_timestamps[0];
219223
}
220224
PS_LOG(INFO) << "Stop PushPullWorker" << gpu_;
221225
}
@@ -233,6 +237,8 @@ class AFTensorWorker {
233237
msg.meta.timestamp = ts;
234238
msg.meta.addr = reinterpret_cast<uint64_t>(tensor.data_ptr());
235239
msg.meta.val_len = tensor.numel() * tensor.itemsize();
240+
PS_VLOG(2) << "ZPush_ addr: 0x" << std::hex << msg.meta.addr << std::dec
241+
<< " val_len: " << msg.meta.val_len;
236242
msg.meta.key = keys[0];
237243
msg.meta.is_tensor = 1;
238244
msg.meta.dtype = static_cast<int>(tensor.scalar_type());

include/ps/internal/customer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ struct CustomerTracker {
3232
std::atomic<int> response_count;
3333
struct Trace request;
3434
struct Trace response;
35+
uint64_t start_time;
3536
};
3637

3738
class Customer {
@@ -80,7 +81,7 @@ class Customer {
8081
* \brief wait until the request is finished. threadsafe
8182
* \param timestamp the timestamp of the request
8283
*/
83-
void WaitRequest(int timestamp);
84+
void WaitRequest(int timestamp, uint64_t timeout_ms = 10000);
8485

8586
/**
8687
* \brief return the number of responses received for the request. threadsafe

include/ps/internal/message.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ struct Meta {
293293
ss << " }";
294294
}
295295
if (!control.empty() || simple_app) ss << ". NOT DATA MSG!";
296-
ss << "Slave QP Count: " << slave_qp_num;
296+
ss << ", Slave QP Count: " << slave_qp_num;
297297
return ss.str();
298298
}
299299
/** \brief an int head */
@@ -384,6 +384,7 @@ struct Message {
384384
meta.dst_dev_id = val.dst_device_id_;
385385
}
386386
}
387+
387388
std::string DebugString() const {
388389
std::stringstream ss;
389390
ss << meta.DebugString();

include/ps/internal/threadsafe_queue.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <queue>
1111
#include <utility>
1212

13+
#include "dmlc/logging.h"
1314
#include "ps/base.h"
1415
#include "ps/internal/env.h"
1516
#include "ps/internal/spsc_queue.h"
@@ -34,10 +35,10 @@ class ThreadsafeQueue {
3435
* \brief push an value into the end. threadsafe.
3536
* \param new_value the value
3637
*/
37-
inline void Push(T new_value) {
38+
inline void Push(T new_value, bool print_log = false) {
3839
if (lockless_) {
3940
// PushLockless(std::move(new_value));
40-
PushAtomic(std::move(new_value));
41+
PushAtomic(std::move(new_value), print_log);
4142
return;
4243
}
4344
{
@@ -51,10 +52,10 @@ class ThreadsafeQueue {
5152
* \brief wait until pop an element from the beginning, threadsafe
5253
* \param value the poped value
5354
*/
54-
inline void WaitAndPop(T* value) {
55+
inline void WaitAndPop(T* value, bool print_log = false) {
5556
if (lockless_) {
5657
// WaitAndPopLockless(value);
57-
WaitAndPopAtomic(value);
58+
WaitAndPopAtomic(value, print_log);
5859
return;
5960
}
6061
std::unique_lock<std::mutex> lk(mu_);
@@ -96,7 +97,7 @@ class ThreadsafeQueue {
9697
}
9798
}
9899

99-
void PushAtomic(T new_value) {
100+
void PushAtomic(T new_value, bool print_log = false) {
100101
const size_t current_tail = tail_.load(std::memory_order_relaxed);
101102
const size_t next_tail = (current_tail + 1) % capacity_;
102103
while (next_tail == head_.load(std::memory_order_acquire)) {
@@ -112,19 +113,19 @@ class ThreadsafeQueue {
112113
return;
113114
}
114115

115-
void WaitAndPopAtomic(T* value) {
116-
const size_t current_head = head_.load(std::memory_order_relaxed);
117-
116+
void WaitAndPopAtomic(T* value, bool print_log = false) {
117+
size_t current_head = head_.load(std::memory_order_relaxed);
118118
// Check if the queue is empty
119119
// acquire: ensures writes preceding this load in other threads are
120120
// visible. Specifically, ensures the producer's writes to 'tail_' are
121121
// visible.
122-
int max_count = 1000;
122+
int max_count = 5000;
123123
int count = 0;
124124
while (current_head == tail_.load(std::memory_order_acquire)) {
125125
// Queue is empty, spin and yield
126126
count++;
127127
if (count > max_count) {
128+
current_head = head_.load(std::memory_order_relaxed);
128129
count = 0;
129130
// _mm_pause();
130131
}

include/ps/internal/utils.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#define PS_INTERNAL_UTILS_H_
77

88
#include <ctype.h>
9+
#include <execinfo.h>
910
#include <pthread.h>
1011
#include <sched.h>
1112
#include <stdio.h>
@@ -109,25 +110,37 @@ static uint64_t norm = CycleToNs();
109110
/*!
110111
* \brief Get the current nanocount.
111112
*/
112-
static inline uint64_t GetNanosecond() {
113+
static inline uint64_t GetNanosecond(bool return_zero = true) {
113114
#ifdef STEPMESH_ENABLE_TRACE
115+
return_zero = false;
116+
#endif
117+
if (return_zero) {
118+
return 0;
119+
}
114120
if (norm == 0) {
115121
norm = CycleToNs();
116122
}
117123
return static_cast<uint64_t>((_GetCurrentCycle() << 5) / norm);
118-
#else
119-
return 0;
120-
#endif
121124
}
122125

123126
static int PS_VERBOSE = ps::GetEnv("PS_VERBOSE", 0);
124127

128+
/**
129+
* @brief Rename Thread
130+
*
131+
*/
132+
133+
static inline void RenameThread(const std::string &name) {
134+
pthread_setname_np(pthread_self(), name.c_str());
135+
}
136+
125137
/**
126138
* \brief Bind current thread to a specific CPU core.
127139
* \param offset is the start of the core id
128140
* \param core_count is the number of cores the thread need.
129141
*/
130142
static inline void BindCpuCore(int offset, int core_count = 1) {
143+
RenameThread("StepMesh: BindCpuCore");
131144
int gpu = -1;
132145
Environment::Get()->find("STEPMESH_GPU", &gpu, gpu);
133146
int bind_enable = 0;

include/ps/kv_app.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class KVWorker : public SimpleApp {
9090
*/
9191
explicit KVWorker(int app_id, int customer_id, int instance_idx = 0)
9292
: SimpleApp() {
93+
printf("KVWorker instance_idx,%d\n", instance_idx);
9394
postoffice_ = Postoffice::GetWorker(instance_idx);
9495
PS_VLOG(3) << "KVWorker " << instance_idx << " po@"
9596
<< reinterpret_cast<uint64_t>(postoffice_);
@@ -207,7 +208,9 @@ class KVWorker : public SimpleApp {
207208
*
208209
* \param timestamp the timestamp returned by the push or pull
209210
*/
210-
void Wait(int timestamp) { obj_->WaitRequest(timestamp); }
211+
void Wait(int timestamp, uint64_t timeout_ms = 10000) {
212+
obj_->WaitRequest(timestamp, timeout_ms);
213+
}
211214

212215
/**
213216
* \brief zero-copy Push

0 commit comments

Comments
 (0)