NOTE(review): this patch was captured with its line breaks collapsed and its
angle-bracket template arguments stripped. The arguments of the ServerDataBatch
typedef (public.hpp and util.h hunks) and of meta_map_ could not be recovered
and are reproduced as found; restore them from the repository before applying.
The index hashes are likewise reproduced as found.

diff --git a/Makefile b/Makefile
index 124c569..1dfcbfe 100644
--- a/Makefile
+++ b/Makefile
@@ -104,6 +104,6 @@ test: $(TEST)
 
 af:
 	@mkdir -p cmake_build
-	@cd cmake_build; cmake .. -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=/usr/bin/python3 -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j
+	@cd cmake_build; cmake .. -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=$(shell which python3) -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j
 	@mkdir -p build
 
diff --git a/fserver/csrc/private.hpp b/fserver/csrc/private.hpp
index aa3e668..816563d 100644
--- a/fserver/csrc/private.hpp
+++ b/fserver/csrc/private.hpp
@@ -13,9 +13,113 @@
 using namespace ps;
 
 #ifdef DMLC_USE_CUDA
-void pybind_private(py::module &m){}
+// Couples a wait enqueued on the current CUDA stream with a host-side
+// background task: the stream stalls on a mapped host flag until the task
+// stores the expected ticket value into it.
+class SimpleNotify{
+private:
+  int notify_cnt = 1;         // ticket value the next stream wait expects
+  CUdeviceptr dflag = 0;      // device pointer obtained for hflag in init()
+  uint32_t* hflag = nullptr;  // mapped host flag allocated in init()
+  std::thread th_;
+  // NOTE(review): value type assumed to be std::vector<ServerDataBatch>, i.e.
+  // whatever get_batch() returns -- confirm against its declaration.
+  std::future<std::vector<ServerDataBatch>> fut;
+
+  // Stores the current ticket into the flag, then advances the ticket.
+  void release_stream() {
+    *hflag = notify_cnt;
+    ++notify_cnt;
+  }
+
+public:
+  // Finishes outstanding background work before releasing the flag: a
+  // joinable std::thread reaching its destructor calls std::terminate, and
+  // a running task would otherwise write to freed memory.
+  ~SimpleNotify() {
+    wait_event_done();
+    if (fut.valid()) {
+      fut.wait();
+    }
+    if (hflag != nullptr) {
+      cudaFreeHost(hflag);
+    }
+  }
+
+  // Allocates the mapped host flag and looks up its device address.
+  void init() {
+    if (hflag != nullptr) {
+      return;  // already initialised; allocating again would leak the flag
+    }
+    cudaError_t err = cudaHostAlloc(&hflag, sizeof(uint32_t), cudaHostAllocMapped);
+    if (err != cudaSuccess) {
+      hflag = nullptr;
+      throw std::runtime_error(std::string("cudaHostAlloc failed: ") + cudaGetErrorString(err));
+    }
+    *hflag = 0;  // tickets start at 1, so 0 can never satisfy a wait
+    err = cudaHostGetDevicePointer((void**)&dflag, (void*)hflag, 0);
+    if (err != cudaSuccess) {
+      throw std::runtime_error(std::string("cudaHostGetDevicePointer failed: ") + cudaGetErrorString(err));
+    }
+  }
+
+  // for worker
+  void wait_event_done(){
+    if (th_.joinable()) {
+      th_.join();
+    }
+  }
+
+  // for worker
+  void stream_wait_event(int handler) {
+    // Assigning to a still-joinable std::thread calls std::terminate, and the
+    // previous task may still be advancing notify_cnt, so finish it first.
+    wait_event_done();
+    block_now_stream();
+    th_ = std::thread([handler, this]{
+      fworker_->Wait(handler);
+      release_stream();
+    });
+  }
+
+  // Enqueues a wait on the current stream for the flag to equal notify_cnt.
+  void block_now_stream() {
+    auto stream = at::cuda::getCurrentCUDAStream();
+    cuStreamWaitValue32((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
+  }
+
+  // for server
+  void block_now_stream_and_get_batch() {
+    // Let a previous fetch finish so two tasks never touch the ticket at once.
+    if (fut.valid()) {
+      fut.wait();
+    }
+    block_now_stream();
+    fut = std::async(std::launch::async, [this]{
+      auto ret = get_batch();
+      release_stream();
+      return ret;
+    });
+  }
+
+  // for server
+  std::vector<ServerDataBatch> get_future_batch_data(){
+    return fut.get();
+  }
+};
+
+void pybind_private(py::module &m){
+  py::class_<SimpleNotify>(m, "SimpleNotify")
+    .def(py::init<>())
+    .def("init", &SimpleNotify::init)
+    .def("block_now_stream_and_get_batch", &SimpleNotify::block_now_stream_and_get_batch)
+    .def("get_future_batch_data", &SimpleNotify::get_future_batch_data)
+    .def("block_now_stream", &SimpleNotify::block_now_stream)
+    .def("wait_event_done", &SimpleNotify::wait_event_done)
+    .def("stream_wait_event", &SimpleNotify::stream_wait_event);
+}
 #else
 void pybind_private(py::module &m){}
 #endif //DMLC_USE_CUDA
 
-#endif //PRIVATE_OPS_
\ No newline at end of file
+#endif //PRIVATE_OPS_
diff --git a/fserver/csrc/public.hpp b/fserver/csrc/public.hpp
index 93a1f18..005ab61 100644
--- a/fserver/csrc/public.hpp
+++ b/fserver/csrc/public.hpp
@@ -38,9 +38,6 @@ int instance_id_ = 0;
 int num_worker_ = 1;
 uint64_t worker_mask_ = 0x1;
 
-typedef std::tuple, std::vector>
-    ServerDataBatch;
-
 std::mutex mu_;
 uint64_t handler_counter_ = 0;
 std::unordered_map meta_map_;
diff --git a/fserver/csrc/util.h b/fserver/csrc/util.h
index 9085c7f..7176d2d 100644
--- a/fserver/csrc/util.h
+++ b/fserver/csrc/util.h
@@ -20,4 +20,9 @@
 
 #ifndef UTIL_H_
 #define UTIL_H_
+#include <tuple>
+#include <vector>
+
+typedef std::tuple, std::vector>
+    ServerDataBatch;
 #endif // UTIL_H_