1+ #pragma once
2+
3+ #include < algorithm>
4+ #include < condition_variable>
5+ #include < functional>
6+ #include < memory>
7+ #include < queue>
8+ #include < shared_mutex>
9+ #include < stdio.h>
10+ #include < thread>
11+ #include < vector>
12+
13+ namespace kv_cache_manager {
14+
15+ template <typename ClientType>
16+ class ClientPool {
17+ public:
18+ using CreateClientCallback = std::function<std::shared_ptr<ClientType>()>;
19+
20+ explicit ClientPool (CreateClientCallback cb) : cb_(cb) {}
21+ virtual ~ClientPool () = default ;
22+
23+ class PoolState {
24+ public:
25+ std::shared_ptr<ClientType> AcquireClient (int64_t lock_timeout_ms) {
26+ std::chrono::milliseconds timeout (lock_timeout_ms);
27+ {
28+ std::unique_lock lock (mtx_);
29+ if (!cv_.wait_for (lock, timeout, [this ] { return !client_pool_.empty (); })) {
30+ return nullptr ;
31+ }
32+ std::shared_ptr<ClientType> client = std::move (client_pool_.front ());
33+ client_pool_.pop ();
34+ return client;
35+ }
36+ return nullptr ;
37+ }
38+
39+ void ReleaseClient (std::shared_ptr<ClientType> client, bool is_new = false ) {
40+ if (client) {
41+ std::unique_lock lock (mtx_);
42+ if (is_new) {
43+ client_ref_.push_back (client);
44+ }
45+ client_pool_.push (client);
46+ cv_.notify_one ();
47+ }
48+ }
49+
50+ size_t AllClientSize () const {
51+ std::shared_lock lock (mtx_);
52+ return client_ref_.size ();
53+ }
54+
55+ size_t FreeClientSize () const {
56+ std::shared_lock lock (mtx_);
57+ return client_pool_.size ();
58+ }
59+
60+ private:
61+ mutable std::shared_mutex mtx_;
62+ std::condition_variable_any cv_;
63+ std::queue<std::shared_ptr<ClientType>> client_pool_;
64+ std::vector<std::shared_ptr<ClientType>> client_ref_;
65+ };
66+
67+ class ClientHandle {
68+ public:
69+ ClientHandle (std::shared_ptr<PoolState> pool_state, std::shared_ptr<ClientType> client)
70+ : pool_state_(pool_state), client_(std::move(client)) {}
71+ ClientHandle (const ClientHandle &other) = delete ;
72+ ClientHandle (ClientHandle &&other)
73+ : pool_state_(std::move(other.pool_state_)), client_(std::move(other.client_)) {}
74+
75+ ~ClientHandle () {
76+ if (client_ && pool_state_) {
77+ pool_state_->ReleaseClient (client_);
78+ }
79+ }
80+
81+ ClientType *operator ->() { return client_.get (); }
82+ ClientType &operator *() { return *client_; }
83+ explicit operator bool () const { return client_ != nullptr ; }
84+
85+ private:
86+ std::shared_ptr<PoolState> pool_state_;
87+ std::shared_ptr<ClientType> client_;
88+ };
89+
90+ virtual bool Initialize () = 0;
91+ virtual ClientHandle AcquireClient (int64_t timeout_ms = 1000 ) = 0;
92+
93+ protected:
94+ bool InitializePoolStateWithSize (size_t pool_size) {
95+ pool_state_ = std::make_shared<PoolState>();
96+ for (size_t i = 0 ; i < pool_size; ++i) {
97+ auto client = cb_ ();
98+ if (client == nullptr ) {
99+ return false ;
100+ }
101+ pool_state_->ReleaseClient (client, true );
102+ }
103+ return true ;
104+ }
105+
106+ std::shared_ptr<PoolState> pool_state_;
107+ CreateClientCallback cb_;
108+ };
109+
110+ template <typename ClientType>
111+ class StaticClientPool : public ClientPool <ClientType> {
112+ using Base = ClientPool<ClientType>;
113+ static constexpr size_t kDefaultPoolSize = 4 ;
114+
115+ public:
116+ explicit StaticClientPool (typename Base::CreateClientCallback cb, size_t pool_size = kDefaultPoolSize )
117+ : ClientPool<ClientType>(cb), pool_size_(pool_size) {}
118+
119+ bool Initialize () override { return this ->InitializePoolStateWithSize (pool_size_); }
120+ typename Base::ClientHandle AcquireClient (int64_t timeout_ms = 1000 ) override {
121+ if (!this ->pool_state_ ) {
122+ return typename Base::ClientHandle (nullptr , nullptr );
123+ }
124+ auto client = this ->pool_state_ ->AcquireClient (timeout_ms);
125+ return typename Base::ClientHandle (this ->pool_state_ , std::move (client));
126+ }
127+
128+ private:
129+ size_t pool_size_;
130+ };
131+
132+ template <typename ClientType>
133+ class DynamicClientPool : public ClientPool <ClientType> {
134+ using Base = ClientPool<ClientType>;
135+
136+ public:
137+ explicit DynamicClientPool (typename Base::CreateClientCallback cb, int32_t min_pool_size, int32_t max_pool_size)
138+ : ClientPool<ClientType>(cb), min_pool_size_(min_pool_size), max_pool_size_(max_pool_size) {}
139+
140+ bool Initialize () override { return this ->InitializePoolStateWithSize (min_pool_size_); }
141+ typename Base::ClientHandle AcquireClient (int64_t timeout_ms = 1000 ) override {
142+ if (!this ->pool_state_ ) {
143+ return typename Base::ClientHandle (nullptr , nullptr );
144+ }
145+ std::shared_ptr<ClientType> client;
146+ if (this ->pool_state_ ->FreeClientSize () > 0 || this ->pool_state_ ->AllClientSize () >= max_pool_size_) {
147+ client = this ->pool_state_ ->AcquireClient (timeout_ms);
148+ }
149+ if (client == nullptr ) {
150+ if (static_cast <int32_t >(this ->pool_state_ ->AllClientSize ()) < max_pool_size_) {
151+ {
152+ std::unique_lock lock (acq_mux_);
153+ // double check
154+ if (static_cast <int32_t >(this ->pool_state_ ->AllClientSize ()) < max_pool_size_) {
155+ auto temp_client = this ->cb_ ();
156+ if (temp_client != nullptr ) {
157+ this ->pool_state_ ->ReleaseClient (temp_client, true );
158+ }
159+ }
160+ }
161+ client = this ->pool_state_ ->AcquireClient (timeout_ms);
162+ }
163+ }
164+
165+ return typename Base::ClientHandle (this ->pool_state_ , std::move (client));
166+ }
167+
168+ private:
169+ int32_t min_pool_size_;
170+ int32_t max_pool_size_;
171+ std::mutex acq_mux_;
172+ };
173+
174+ } // namespace kv_cache_manager
0 commit comments