77// modified, or distributed except according to those terms.
88
99use futures_util:: FutureExt ;
10+ use priority_queue:: PriorityQueue ;
1011use tokio:: sync:: mpsc;
1112
1213use std:: {
1314 cmp:: { Ordering , Reverse } ,
14- collections:: { BinaryHeap , VecDeque } ,
15+ collections:: VecDeque ,
1516 convert:: TryFrom ,
17+ hash:: { Hash , Hasher } ,
1618 pin:: Pin ,
1719 str:: FromStr ,
1820 sync:: { atomic, Arc , Mutex } ,
@@ -63,7 +65,7 @@ impl From<Conn> for IdlingConn {
6365/// This is fine as long as we never do expensive work while holding the lock!
6466#[ derive( Debug ) ]
6567struct Exchange {
66- waiting : BinaryHeap < QueuedWaker > ,
68+ waiting : Waitlist ,
6769 available : VecDeque < IdlingConn > ,
6870 exist : usize ,
6971 // only used to spawn the recycler the first time we're in async context
@@ -88,9 +90,45 @@ impl Exchange {
8890 }
8991}
9092
93+ #[ derive( Default , Debug ) ]
94+ struct Waitlist {
95+ queue : PriorityQueue < QueuedWaker , QueueId > ,
96+ }
97+
98+ impl Waitlist {
99+ fn push ( & mut self , w : Waker , queue_id : QueueId ) {
100+ self . queue . push (
101+ QueuedWaker {
102+ queue_id,
103+ waker : Some ( w) ,
104+ } ,
105+ queue_id,
106+ ) ;
107+ }
108+
109+ fn pop ( & mut self ) -> Option < Waker > {
110+ match self . queue . pop ( ) {
111+ Some ( ( qw, _) ) => Some ( qw. waker . unwrap ( ) ) ,
112+ None => None ,
113+ }
114+ }
115+
116+ fn remove ( & mut self , id : QueueId ) {
117+ let tmp = QueuedWaker {
118+ queue_id : id,
119+ waker : None ,
120+ } ;
121+ self . queue . remove ( & tmp) ;
122+ }
123+
124+ fn is_empty ( & self ) -> bool {
125+ self . queue . is_empty ( )
126+ }
127+ }
128+
91129const QUEUE_END_ID : QueueId = QueueId ( Reverse ( u64:: MAX ) ) ;
92130
93- #[ derive( Debug , Copy , Clone , Eq , PartialEq , Ord , PartialOrd ) ]
131+ #[ derive( Debug , Copy , Clone , Eq , PartialEq , Ord , PartialOrd , Hash ) ]
94132pub ( crate ) struct QueueId ( Reverse < u64 > ) ;
95133
96134impl QueueId {
@@ -104,13 +142,7 @@ impl QueueId {
104142#[ derive( Debug ) ]
105143struct QueuedWaker {
106144 queue_id : QueueId ,
107- waker : Waker ,
108- }
109-
110- impl QueuedWaker {
111- fn new ( queue_id : QueueId , waker : Waker ) -> Self {
112- QueuedWaker { queue_id, waker }
113- }
145+ waker : Option < Waker > ,
114146}
115147
116148impl Eq for QueuedWaker { }
@@ -133,6 +165,12 @@ impl PartialOrd for QueuedWaker {
133165 }
134166}
135167
168+ impl Hash for QueuedWaker {
169+ fn hash < H : Hasher > ( & self , state : & mut H ) {
170+ self . queue_id . hash ( state)
171+ }
172+ }
173+
136174/// Connection pool data.
137175#[ derive( Debug ) ]
138176pub struct Inner {
@@ -177,7 +215,7 @@ impl Pool {
177215 closed : false . into ( ) ,
178216 exchange : Mutex :: new ( Exchange {
179217 available : VecDeque :: with_capacity ( pool_opts. constraints ( ) . max ( ) ) ,
180- waiting : BinaryHeap :: new ( ) ,
218+ waiting : Waitlist :: default ( ) ,
181219 exist : 0 ,
182220 recycler : Some ( ( rx, pool_opts) ) ,
183221 } ) ,
@@ -227,8 +265,8 @@ impl Pool {
227265 let mut exchange = self . inner . exchange . lock ( ) . unwrap ( ) ;
228266 if exchange. available . len ( ) < self . opts . pool_opts ( ) . active_bound ( ) {
229267 exchange. available . push_back ( conn. into ( ) ) ;
230- if let Some ( qw ) = exchange. waiting . pop ( ) {
231- qw . waker . wake ( ) ;
268+ if let Some ( w ) = exchange. waiting . pop ( ) {
269+ w . wake ( ) ;
232270 }
233271 return ;
234272 }
@@ -262,8 +300,8 @@ impl Pool {
262300 let mut exchange = self . inner . exchange . lock ( ) . unwrap ( ) ;
263301 exchange. exist -= 1 ;
264302 // we just enabled the creation of a new connection!
265- if let Some ( qw ) = exchange. waiting . pop ( ) {
266- qw . waker . wake ( ) ;
303+ if let Some ( w ) = exchange. waiting . pop ( ) {
304+ w . wake ( ) ;
267305 }
268306 }
269307
@@ -296,9 +334,7 @@ impl Pool {
296334
297335 // Check if others are waiting and we're not queued.
298336 if !exchange. waiting . is_empty ( ) && !queued {
299- exchange
300- . waiting
301- . push ( QueuedWaker :: new ( queue_id, cx. waker ( ) . clone ( ) ) ) ;
337+ exchange. waiting . push ( cx. waker ( ) . clone ( ) , queue_id) ;
302338 return Poll :: Pending ;
303339 }
304340
@@ -328,11 +364,14 @@ impl Pool {
328364 }
329365
330366 // Polled, but no conn available? Back into the queue.
331- exchange
332- . waiting
333- . push ( QueuedWaker :: new ( queue_id, cx. waker ( ) . clone ( ) ) ) ;
367+ exchange. waiting . push ( cx. waker ( ) . clone ( ) , queue_id) ;
334368 Poll :: Pending
335369 }
370+
371+ fn unqueue ( & self , queue_id : QueueId ) {
372+ let mut exchange = self . inner . exchange . lock ( ) . unwrap ( ) ;
373+ exchange. waiting . remove ( queue_id) ;
374+ }
336375}
337376
338377impl Drop for Conn {
@@ -363,12 +402,20 @@ mod test {
363402 try_join, FutureExt ,
364403 } ;
365404 use mysql_common:: row:: Row ;
366- use tokio:: time:: sleep;
405+ use tokio:: time:: { sleep, timeout } ;
367406
368- use std:: time:: Duration ;
407+ use std:: {
408+ cmp:: Reverse ,
409+ task:: { RawWaker , RawWakerVTable , Waker } ,
410+ time:: Duration ,
411+ } ;
369412
370413 use crate :: {
371- conn:: pool:: Pool , opts:: PoolOpts , prelude:: * , test_misc:: get_opts, PoolConstraints , TxOpts ,
414+ conn:: pool:: { Pool , QueueId , Waitlist , QUEUE_END_ID } ,
415+ opts:: PoolOpts ,
416+ prelude:: * ,
417+ test_misc:: get_opts,
418+ PoolConstraints , TxOpts ,
372419 } ;
373420
374421 macro_rules! conn_ex_field {
@@ -824,6 +871,27 @@ mod test {
824871 Ok ( ( ) )
825872 }
826873
874+ #[ tokio:: test]
875+ async fn should_remove_waker_of_cancelled_task ( ) {
876+ let pool_constraints = PoolConstraints :: new ( 1 , 1 ) . unwrap ( ) ;
877+ let pool_opts = PoolOpts :: default ( ) . with_constraints ( pool_constraints) ;
878+
879+ let pool = Pool :: new ( get_opts ( ) . pool_opts ( pool_opts) ) ;
880+ let only_conn = pool. get_conn ( ) . await . unwrap ( ) ;
881+
882+ let join_handle = tokio:: spawn ( timeout ( Duration :: from_secs ( 1 ) , pool. get_conn ( ) ) ) ;
883+
884+ sleep ( Duration :: from_secs ( 2 ) ) . await ;
885+
886+ match join_handle. await . unwrap ( ) {
887+ Err ( _elapsed) => ( ) ,
888+ _ => panic ! ( "unexpected Ok()" ) ,
889+ }
890+ drop ( only_conn) ;
891+
892+ assert_eq ! ( 0 , pool. inner. exchange. lock( ) . unwrap( ) . waiting. queue. len( ) ) ;
893+ }
894+
827895 #[ tokio:: test]
828896 async fn should_work_if_pooled_connection_operation_is_cancelled ( ) -> super :: Result < ( ) > {
829897 let pool = Pool :: new ( get_opts ( ) ) ;
@@ -868,6 +936,40 @@ mod test {
868936 Ok ( ( ) )
869937 }
870938
939+ #[ test]
940+ fn waitlist_integrity ( ) {
941+ const DATA : * const ( ) = & ( ) ;
942+ const NOOP_CLONE_FN : unsafe fn ( * const ( ) ) -> RawWaker = |_| RawWaker :: new ( DATA , & RW_VTABLE ) ;
943+ const NOOP_FN : unsafe fn ( * const ( ) ) = |_| { } ;
944+ static RW_VTABLE : RawWakerVTable =
945+ RawWakerVTable :: new ( NOOP_CLONE_FN , NOOP_FN , NOOP_FN , NOOP_FN ) ;
946+ let w = unsafe { Waker :: from_raw ( RawWaker :: new ( DATA , & RW_VTABLE ) ) } ;
947+
948+ let mut waitlist = Waitlist :: default ( ) ;
949+ assert_eq ! ( 0 , waitlist. queue. len( ) ) ;
950+
951+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 4 ) ) ) ;
952+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 2 ) ) ) ;
953+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 8 ) ) ) ;
954+ waitlist. push ( w. clone ( ) , QUEUE_END_ID ) ;
955+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 10 ) ) ) ;
956+
957+ waitlist. remove ( QueueId ( Reverse ( 8 ) ) ) ;
958+
959+ assert_eq ! ( 4 , waitlist. queue. len( ) ) ;
960+
961+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
962+ assert_eq ! ( 2 , id. 0 . 0 ) ;
963+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
964+ assert_eq ! ( 4 , id. 0 . 0 ) ;
965+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
966+ assert_eq ! ( 10 , id. 0 . 0 ) ;
967+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
968+ assert_eq ! ( QUEUE_END_ID , id) ;
969+
970+ assert_eq ! ( 0 , waitlist. queue. len( ) ) ;
971+ }
972+
871973 #[ cfg( feature = "nightly" ) ]
872974 mod bench {
873975 use futures_util:: future:: { FutureExt , TryFutureExt } ;
0 commit comments