@@ -26,6 +26,7 @@ use hyper::Response;
26
26
use hyper:: client:: ResponseFuture ;
27
27
use hyper:: client:: connect:: Connect ;
28
28
use tokio:: io:: AsyncReadExt ;
29
+ use tokio:: sync:: Semaphore ;
29
30
use tokio_util:: io:: StreamReader ;
30
31
31
32
use crate :: HttpError ;
@@ -49,6 +50,9 @@ pub struct HttpClient {
49
50
supports_vpnless : bool ,
50
51
http2 : bool ,
51
52
stats : HttpNetworkStats ,
53
+ // tokio::sync::Semaphore doesn't impl Allocative
54
+ #[ allocative( skip) ]
55
+ concurrent_requests_budget : Arc < Semaphore > ,
52
56
}
53
57
54
58
impl HttpClient {
@@ -124,6 +128,7 @@ impl HttpClient {
124
128
) ;
125
129
change_scheme_to_http ( & mut request) ?;
126
130
}
131
+ let semaphore_guard = self . concurrent_requests_budget . acquire ( ) . await . unwrap ( ) ;
127
132
let resp = self . inner . request ( request) . await . map_err ( |e| {
128
133
if is_hyper_error_due_to_timeout ( & e) {
129
134
HttpError :: Timeout {
@@ -134,11 +139,14 @@ impl HttpClient {
134
139
HttpError :: SendRequest { uri, source : e }
135
140
}
136
141
} ) ?;
137
- Ok (
138
- resp. map ( |body| {
139
- CountingStream :: new ( body, self . stats . downloaded_bytes ( ) . dupe ( ) ) . boxed ( )
140
- } ) ,
141
- )
142
+ Ok ( resp. map ( move |body| {
143
+ CountingStream :: new ( body, self . stats . downloaded_bytes ( ) . dupe ( ) )
144
+ . inspect ( move |_| {
145
+ // Ensure we keep a concurrent request permit alive until the stream is consumed
146
+ let _guard = & semaphore_guard;
147
+ } )
148
+ . boxed ( )
149
+ } ) )
142
150
}
143
151
144
152
/// Send a generic request.
@@ -768,6 +776,38 @@ mod tests {
768
776
769
777
Ok ( ( ) )
770
778
}
779
+
780
+ #[ tokio:: test]
781
+ async fn test_concurrency_limit ( ) -> buck2_error:: Result < ( ) > {
782
+ let test_server = httptest:: Server :: run ( ) ;
783
+ test_server. expect (
784
+ Expectation :: matching ( request:: method_path ( "GET" , "/foo" ) )
785
+ . times ( 3 )
786
+ . respond_with ( responders:: status_code ( 200 ) ) ,
787
+ ) ;
788
+
789
+ let client = HttpClientBuilder :: https_with_system_roots ( )
790
+ . await ?
791
+ . with_max_concurrent_requests ( 2 )
792
+ . build ( ) ;
793
+ let url = test_server. url_str ( "/foo" ) ;
794
+ let req1 = client. get ( & url) . await ?;
795
+ let req2 = client. get ( & url) . await ?;
796
+ assert_eq ! ( client. concurrent_requests_budget. available_permits( ) , 0 ) ;
797
+ let mut req3 = std:: pin:: pin!( client. get( & url) ) ;
798
+ assert ! (
799
+ tokio:: time:: timeout( tokio:: time:: Duration :: from_millis( 100 ) , & mut req3)
800
+ . await
801
+ . is_err( )
802
+ ) ;
803
+ drop ( req1) ;
804
+ req3. await ?;
805
+ assert_eq ! ( client. concurrent_requests_budget. available_permits( ) , 1 ) ;
806
+ drop ( req2) ;
807
+ assert_eq ! ( client. concurrent_requests_budget. available_permits( ) , 2 ) ;
808
+
809
+ Ok ( ( ) )
810
+ }
771
811
}
772
812
773
813
// TODO(skarlage, T160529958): Debug why these tests fail on CircleCI
0 commit comments