Skip to content

Commit f25e7ac

Browse files
authored
connection_manager: replace usize counter with Semaphore RAII permit (TraceMachina#2350)
1 parent 0cf6af7 commit f25e7ac

2 files changed

Lines changed: 243 additions & 29 deletions

File tree

nativelink-util/src/connection_manager.rs

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use futures::Future;
2222
use futures::stream::{FuturesUnordered, StreamExt, unfold};
2323
use nativelink_config::stores::Retry;
2424
use nativelink_error::{Code, Error, make_err};
25-
use tokio::sync::{mpsc, oneshot};
25+
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
2626
use tonic::transport::{Channel, Endpoint, channel};
2727
use tracing::{debug, error, info, warn};
2828

@@ -95,8 +95,13 @@ struct ConnectionManagerWorker {
9595
endpoints: Vec<(ConnectionIndex, Endpoint)>,
9696
/// The channel used to communicate between a Connection and the worker.
9797
connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
98-
/// The number of connections that are currently allowed to be made.
99-
available_connections: usize,
98+
/// Gates the maximum number of in-flight `Connection` objects.
99+
/// Was an explicit `usize` counter; now an `Arc<Semaphore>` so the
100+
/// `OwnedSemaphorePermit` held by each `Connection` releases on
101+
/// drop (RAII), instead of relying on a `ConnectionRequest::Dropped`
102+
/// round-trip that could be lost on tonic transport errors or task
103+
/// aborts.
104+
available_connections: Arc<Semaphore>,
100105
/// Channels that are currently being connected.
101106
connecting_channels: FuturesUnordered<Pin<Box<dyn Future<Output = IndexedChannel> + Send>>>,
102107
/// Connected channels that are available for use.
@@ -136,14 +141,16 @@ impl ConnectionManager {
136141
.collect();
137142

138143
if max_concurrent_requests == 0 {
139-
max_concurrent_requests = usize::MAX;
144+
max_concurrent_requests = Semaphore::MAX_PERMITS;
145+
} else {
146+
max_concurrent_requests = max_concurrent_requests.min(Semaphore::MAX_PERMITS);
140147
}
141148
if connections_per_endpoint == 0 {
142149
connections_per_endpoint = 1;
143150
}
144151
let worker = ConnectionManagerWorker {
145152
endpoints,
146-
available_connections: max_concurrent_requests,
153+
available_connections: Arc::new(Semaphore::new(max_concurrent_requests)),
147154
connection_tx,
148155
connecting_channels: FuturesUnordered::new(),
149156
available_channels: VecDeque::new(),
@@ -309,15 +316,15 @@ impl ConnectionManagerWorker {
309316

310317
// This must never be made async otherwise the select may cancel it.
311318
fn handle_worker(&mut self, reason: String, tx: oneshot::Sender<Connection>) {
312-
if let Some(channel) = (self.available_connections > 0)
313-
.then_some(())
314-
.and_then(|()| self.available_channels.pop_front())
319+
let maybe_permit = self.available_connections.clone().try_acquire_owned().ok();
320+
if let Some(permit) = maybe_permit
321+
&& let Some(channel) = self.available_channels.pop_front()
315322
{
316323
debug!(reason, "ConnectionManager: request running");
317-
self.provide_channel(channel, tx);
324+
self.provide_channel(channel, tx, permit);
318325
} else {
319326
debug!(
320-
available_connections = self.available_connections,
327+
available_permits = self.available_connections.available_permits(),
321328
available_channels = self.available_channels.len(),
322329
waiting_connections = self.waiting_connections.len(),
323330
reason,
@@ -327,31 +334,36 @@ impl ConnectionManagerWorker {
327334
}
328335
}
329336

330-
fn provide_channel(&mut self, channel: EstablishedChannel, tx: oneshot::Sender<Connection>) {
331-
// We decrement here because we create Connection, this will signal when
332-
// it is Dropped and therefore increment this again.
333-
self.available_connections -= 1;
337+
fn provide_channel(
338+
&mut self,
339+
channel: EstablishedChannel,
340+
tx: oneshot::Sender<Connection>,
341+
permit: OwnedSemaphorePermit,
342+
) {
334343
drop(tx.send(Connection {
335344
tx: self.connection_tx.clone(),
336345
pending_channel: Some(channel.channel.clone()),
337346
channel,
347+
_permit: permit,
338348
}));
339349
}
340350

341351
fn maybe_available_connection(&mut self) {
342-
while self.available_connections > 0
343-
&& !self.waiting_connections.is_empty()
344-
&& !self.available_channels.is_empty()
345-
{
346-
if let Some(channel) = self.available_channels.pop_front() {
347-
if let Some((reason, tx)) = self.waiting_connections.pop_front() {
348-
debug!(reason, "ConnectionManager: channel available, running");
349-
self.provide_channel(channel, tx);
350-
} else {
351-
// This should never happen, but better than an unwrap.
352-
self.available_channels.push_front(channel);
353-
}
354-
}
352+
while !self.waiting_connections.is_empty() && !self.available_channels.is_empty() {
353+
let Some(permit) = self.available_connections.clone().try_acquire_owned().ok() else {
354+
break;
355+
};
356+
let Some(channel) = self.available_channels.pop_front() else {
357+
drop(permit);
358+
break;
359+
};
360+
let Some((reason, tx)) = self.waiting_connections.pop_front() else {
361+
self.available_channels.push_front(channel);
362+
drop(permit);
363+
break;
364+
};
365+
debug!(reason, "ConnectionManager: channel available, running");
366+
self.provide_channel(channel, tx, permit);
355367
}
356368
}
357369

@@ -362,7 +374,6 @@ impl ConnectionManagerWorker {
362374
if let Some(channel) = maybe_channel {
363375
self.available_channels.push_back(channel);
364376
}
365-
self.available_connections += 1;
366377
self.maybe_available_connection();
367378
}
368379
ConnectionRequest::Connected(channel) => {
@@ -394,7 +405,8 @@ impl ConnectionManagerWorker {
394405
/// re-connecting the underlying channel on error. It depends on users
395406
/// reporting all errors.
396407
/// NOTE: This should never be cloneable because its lifetime is linked to the
397-
/// `ConnectionManagerWorker::available_connections`.
408+
/// semaphore permit it carries — `_permit` is released exactly once,
409+
/// when the `Connection` drops.
398410
#[derive(Debug)]
399411
pub struct Connection {
400412
/// Communication with `ConnectionManagerWorker` to inform about transport
@@ -406,6 +418,7 @@ pub struct Connection {
406418
pending_channel: Option<Channel>,
407419
/// The identifier to send to `tx`.
408420
channel: EstablishedChannel,
421+
_permit: OwnedSemaphorePermit,
409422
}
410423

411424
impl Drop for Connection {
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
// Copyright 2026 The NativeLink Authors. All rights reserved.
2+
//
3+
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// See LICENSE file for details
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//! Tests for `ConnectionManager`'s permit accounting.
16+
//!
17+
//! The bug these tests exist to prevent: in production we observed
18+
//! `available_connections: 18446744073709551589` (`u64::MAX − 26`) while
19+
//! `waiting_connections` climbed unbounded, ultimately killing the worker
20+
//! process via `OOMKilled` (exit 137). Switching from a manual `usize`
21+
//! counter to `Arc<Semaphore>` with `OwnedSemaphorePermit` makes the leak
22+
//! structurally impossible — these tests pin that property by exercising
23+
//! the full request-acquire-release cycle through the public API many
24+
//! times over a tight permit budget. With a leak, the cycle eventually
25+
//! blocks forever; without one, every iteration completes inside the
26+
//! per-call timeout.
27+
28+
use core::pin::Pin;
29+
use core::time::Duration;
30+
use std::sync::Arc;
31+
32+
use nativelink_config::stores::Retry;
33+
use nativelink_error::Error;
34+
use nativelink_macro::nativelink_test;
35+
use nativelink_proto::google::bytestream::byte_stream_server::{ByteStream, ByteStreamServer};
36+
use nativelink_proto::google::bytestream::{
37+
QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest,
38+
WriteResponse,
39+
};
40+
use nativelink_util::background_spawn;
41+
use nativelink_util::connection_manager::ConnectionManager;
42+
use pretty_assertions::assert_eq;
43+
use tokio::time::timeout;
44+
use tokio_stream::Stream;
45+
use tonic::transport::server::TcpIncoming;
46+
use tonic::transport::{Endpoint, Server};
47+
use tonic::{Request, Response, Status, Streaming};
48+
49+
#[derive(Clone)]
50+
struct FakeByteStream;
51+
52+
#[tonic::async_trait]
53+
impl ByteStream for FakeByteStream {
54+
type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>;
55+
56+
async fn read(
57+
&self,
58+
_request: Request<ReadRequest>,
59+
) -> Result<Response<Self::ReadStream>, Status> {
60+
Err(Status::unimplemented("fake"))
61+
}
62+
63+
async fn write(
64+
&self,
65+
_request: Request<Streaming<WriteRequest>>,
66+
) -> Result<Response<WriteResponse>, Status> {
67+
Err(Status::unimplemented("fake"))
68+
}
69+
70+
async fn query_write_status(
71+
&self,
72+
_request: Request<QueryWriteStatusRequest>,
73+
) -> Result<Response<QueryWriteStatusResponse>, Status> {
74+
Err(Status::unimplemented("fake"))
75+
}
76+
}
77+
78+
async fn fake_grpc_server_endpoint() -> Endpoint {
79+
let listener = TcpIncoming::bind("127.0.0.1:0".parse().unwrap()).unwrap();
80+
let port = listener.local_addr().unwrap().port();
81+
background_spawn!("connection_manager_test_server", async move {
82+
Server::builder()
83+
.add_service(ByteStreamServer::new(FakeByteStream))
84+
.serve_with_incoming(listener)
85+
.await
86+
.unwrap();
87+
});
88+
Endpoint::from_shared(format!("http://127.0.0.1:{port}")).unwrap()
89+
}
90+
91+
/// Returns a jitter function that hands every delay back unchanged, so the
/// retry timing used in these tests is not randomized.
fn no_jitter() -> Arc<dyn Fn(Duration) -> Duration + Send + Sync> {
    let passthrough = |delay: Duration| delay;
    Arc::new(passthrough)
}
95+
96+
#[nativelink_test]
97+
async fn permits_released_on_drop_no_leak() -> Result<(), Error> {
98+
const MAX_CONCURRENT: usize = 2;
99+
const ITERATIONS: usize = 100;
100+
101+
let endpoint = fake_grpc_server_endpoint().await;
102+
let cm = ConnectionManager::new(
103+
vec![endpoint],
104+
/* connections_per_endpoint = */ MAX_CONCURRENT,
105+
MAX_CONCURRENT,
106+
Retry::default(),
107+
no_jitter(),
108+
);
109+
110+
for i in 0..ITERATIONS {
111+
let c1 = timeout(Duration::from_secs(5), cm.connection(format!("iter-{i}-a")))
112+
.await
113+
.unwrap_or_else(|_| panic!("iter {i}: first acquire blocked >5s — permit leak"))?;
114+
let c2 = timeout(Duration::from_secs(5), cm.connection(format!("iter-{i}-b")))
115+
.await
116+
.unwrap_or_else(|_| panic!("iter {i}: second acquire blocked >5s — permit leak"))?;
117+
drop(c1);
118+
drop(c2);
119+
}
120+
121+
Ok(())
122+
}
123+
124+
#[nativelink_test]
125+
async fn aborted_caller_future_does_not_leak_permits() -> Result<(), Error> {
126+
const MAX_CONCURRENT: usize = 2;
127+
128+
let endpoint = fake_grpc_server_endpoint().await;
129+
let cm = Arc::new(ConnectionManager::new(
130+
vec![endpoint],
131+
/* connections_per_endpoint = */ MAX_CONCURRENT,
132+
MAX_CONCURRENT,
133+
Retry::default(),
134+
no_jitter(),
135+
));
136+
137+
let mut handles = Vec::new();
138+
for i in 0..(MAX_CONCURRENT * 5) {
139+
let cm = Arc::clone(&cm);
140+
handles.push(tokio::spawn(async move {
141+
// Bind to `_conn` (not `_`) so the Connection lives until
142+
// task abort; bare `let _ = ...` would drop it immediately
143+
// and defeat the test.
144+
let _conn = cm.connection(format!("aborted-{i}")).await;
145+
futures::future::pending::<()>().await
146+
}));
147+
}
148+
tokio::time::sleep(Duration::from_millis(100)).await;
149+
for h in handles {
150+
h.abort();
151+
}
152+
tokio::time::sleep(Duration::from_millis(500)).await;
153+
let c1 = timeout(Duration::from_secs(5), cm.connection("post-abort-a".into()))
154+
.await
155+
.expect("post-abort acquire 1 blocked >5s — permit leak")?;
156+
let c2 = timeout(Duration::from_secs(5), cm.connection("post-abort-b".into()))
157+
.await
158+
.expect("post-abort acquire 2 blocked >5s — permit leak")?;
159+
drop(c1);
160+
drop(c2);
161+
Ok(())
162+
}
163+
164+
#[nativelink_test]
165+
async fn extra_request_above_max_blocks_until_a_release() -> Result<(), Error> {
166+
const MAX_CONCURRENT: usize = 2;
167+
168+
let endpoint = fake_grpc_server_endpoint().await;
169+
let cm = Arc::new(ConnectionManager::new(
170+
vec![endpoint],
171+
MAX_CONCURRENT + 1,
172+
MAX_CONCURRENT,
173+
Retry::default(),
174+
no_jitter(),
175+
));
176+
177+
let c1 = cm.connection("hold-1".into()).await?;
178+
let c2 = cm.connection("hold-2".into()).await?;
179+
180+
// Third request must be queued — racing it against a short timeout
181+
// proves it doesn't resolve while permits are exhausted.
182+
let cm_for_third = Arc::clone(&cm);
183+
let third = tokio::spawn(async move { cm_for_third.connection("queued-3".into()).await });
184+
tokio::time::sleep(Duration::from_millis(200)).await;
185+
assert_eq!(
186+
third.is_finished(),
187+
false,
188+
"third connection resolved while permits were exhausted",
189+
);
190+
191+
// Drop one held permit; the queued request should now resolve.
192+
drop(c1);
193+
let c3 = timeout(Duration::from_secs(5), third)
194+
.await
195+
.expect("queued request did not resolve within 5s of permit release")
196+
.unwrap()?;
197+
198+
drop(c2);
199+
drop(c3);
200+
Ok(())
201+
}

0 commit comments

Comments
 (0)