Skip to content

Commit 774eb6a

Browse files
committed
refactor(server): extract shared relay-await and sandbox-scan helpers
Replace three copy-pasted relay stream await match blocks in sandbox.rs with a single await_relay_stream() helper that handles timeout, drop, and error cases uniformly, varying only the log context string. Replace three copy-pasted paginated sandbox scan loops in provider.rs with a generic scan_sandboxes() helper that accepts a closure for the per-sandbox predicate and projection, eliminating the duplicated pagination arithmetic and decode boilerplate. No behavior changes; all existing tests pass.
1 parent 77e6c7a commit 774eb6a

2 files changed

Lines changed: 118 additions & 145 deletions

File tree

crates/openshell-server/src/grpc/provider.rs

Lines changed: 64 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,17 @@ pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result<
299299
.map_err(|e| Status::internal(format!("delete provider failed: {e}")))
300300
}
301301

302-
async fn sandboxes_using_provider(
303-
store: &Store,
304-
provider_name: &str,
305-
) -> Result<Vec<String>, Status> {
306-
let mut blocking = Vec::new();
307-
let mut offset = 0;
302+
/// Iterate over every `Sandbox` in the store and collect items produced by
303+
/// `f`. `f` receives each decoded sandbox; returning `Some(T)` includes the
304+
/// value in the output, `None` skips it.
305+
///
306+
/// This is the shared pagination kernel used by all sandbox-scan helpers.
307+
async fn scan_sandboxes<T, F>(store: &Store, mut f: F) -> Result<Vec<T>, Status>
308+
where
309+
F: FnMut(Sandbox) -> Option<T>,
310+
{
311+
let mut out = Vec::new();
312+
let mut offset = 0u32;
308313
loop {
309314
let records = store
310315
.list(Sandbox::object_type(), 1000, offset)
@@ -319,56 +324,50 @@ async fn sandboxes_using_provider(
319324
.map_err(|_| Status::internal("sandbox page size exceeded u32"))?,
320325
)
321326
.ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?;
322-
323327
for record in records {
324328
let sandbox = Sandbox::decode(record.payload.as_slice())
325329
.map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?;
326-
let Some(spec) = sandbox.spec.as_ref() else {
327-
continue;
328-
};
329-
if spec.providers.iter().any(|name| name == provider_name) {
330-
blocking.push(sandbox.object_name().to_string());
330+
if let Some(item) = f(sandbox) {
331+
out.push(item);
331332
}
332333
}
333334
}
334-
blocking.sort();
335-
blocking.dedup();
336-
Ok(blocking)
335+
Ok(out)
337336
}
338337

339-
async fn sandboxes_using_provider_records(
338+
async fn sandboxes_using_provider(
340339
store: &Store,
341340
provider_name: &str,
342-
) -> Result<Vec<Sandbox>, Status> {
343-
let mut sandboxes = Vec::new();
344-
let mut offset = 0;
345-
loop {
346-
let records = store
347-
.list(Sandbox::object_type(), 1000, offset)
348-
.await
349-
.map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?;
350-
if records.is_empty() {
351-
break;
341+
) -> Result<Vec<String>, Status> {
342+
let provider_name = provider_name.to_string();
343+
let mut names = scan_sandboxes(store, |sandbox| {
344+
let spec = sandbox.spec.as_ref()?;
345+
if spec.providers.iter().any(|n| n == &provider_name) {
346+
Some(sandbox.object_name().to_string())
347+
} else {
348+
None
352349
}
353-
offset = offset
354-
.checked_add(
355-
u32::try_from(records.len())
356-
.map_err(|_| Status::internal("sandbox page size exceeded u32"))?,
357-
)
358-
.ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?;
350+
})
351+
.await?;
352+
names.sort();
353+
names.dedup();
354+
Ok(names)
355+
}
359356

360-
for record in records {
361-
let sandbox = Sandbox::decode(record.payload.as_slice())
362-
.map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?;
363-
let Some(spec) = sandbox.spec.as_ref() else {
364-
continue;
365-
};
366-
if spec.providers.iter().any(|name| name == provider_name) {
367-
sandboxes.push(sandbox);
368-
}
357+
async fn sandboxes_using_provider_records(
358+
store: &Store,
359+
provider_name: &str,
360+
) -> Result<Vec<Sandbox>, Status> {
361+
let provider_name = provider_name.to_string();
362+
scan_sandboxes(store, |sandbox| {
363+
let spec = sandbox.spec.as_ref()?;
364+
if spec.providers.iter().any(|n| n == &provider_name) {
365+
Some(sandbox)
366+
} else {
367+
None
369368
}
370-
}
371-
Ok(sandboxes)
369+
})
370+
.await
372371
}
373372

374373
/// Merge an incoming map into an existing map.
@@ -1045,41 +1044,31 @@ fn has_errors(diagnostics: &[ProfileValidationDiagnostic]) -> bool {
10451044
}
10461045

10471046
async fn sandboxes_using_profile(store: &Store, profile_id: &str) -> Result<Vec<String>, Status> {
1048-
let mut blocking = Vec::new();
1049-
let mut offset = 0;
1050-
loop {
1051-
let records = store
1052-
.list(Sandbox::object_type(), 1000, offset)
1053-
.await
1054-
.map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?;
1055-
if records.is_empty() {
1056-
break;
1057-
}
1058-
offset = offset
1059-
.checked_add(
1060-
u32::try_from(records.len())
1061-
.map_err(|_| Status::internal("sandbox page size exceeded u32"))?,
1062-
)
1063-
.ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?;
1047+
// Collect all sandboxes that reference at least one provider — pagination
1048+
// is handled by `scan_sandboxes`; the async provider lookup happens below.
1049+
let candidates = scan_sandboxes(store, |sandbox| {
1050+
let has_providers = sandbox
1051+
.spec
1052+
.as_ref()
1053+
.is_some_and(|s| !s.providers.is_empty());
1054+
has_providers.then_some(sandbox)
1055+
})
1056+
.await?;
10641057

1065-
for record in records {
1066-
let sandbox = Sandbox::decode(record.payload.as_slice())
1067-
.map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?;
1068-
let Some(spec) = sandbox.spec.as_ref() else {
1058+
let mut blocking = Vec::new();
1059+
for sandbox in candidates {
1060+
let spec = sandbox.spec.as_ref().expect("filtered by scan_sandboxes");
1061+
for provider_name in &spec.providers {
1062+
let Some(provider) = store
1063+
.get_message_by_name::<Provider>(provider_name)
1064+
.await
1065+
.map_err(|e| Status::internal(format!("fetch provider failed: {e}")))?
1066+
else {
10691067
continue;
10701068
};
1071-
for provider_name in &spec.providers {
1072-
let Some(provider) = store
1073-
.get_message_by_name::<Provider>(provider_name)
1074-
.await
1075-
.map_err(|e| Status::internal(format!("fetch provider failed: {e}")))?
1076-
else {
1077-
continue;
1078-
};
1079-
if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) {
1080-
blocking.push(sandbox.object_name().to_string());
1081-
break;
1082-
}
1069+
if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) {
1070+
blocking.push(sandbox.object_name().to_string());
1071+
break;
10831072
}
10841073
}
10851074
}

crates/openshell-server/src/grpc/sandbox.rs

Lines changed: 54 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use std::pin::Pin;
3030
use std::sync::Arc;
3131
use std::sync::atomic::{AtomicBool, Ordering};
3232
use tokio::net::{TcpListener, TcpStream};
33-
use tokio::sync::mpsc;
33+
use tokio::sync::{mpsc, oneshot};
3434
use tokio_stream::wrappers::ReceiverStream;
3535
use tonic::{Request, Response, Status};
3636
use tracing::{debug, info, warn};
@@ -737,29 +737,10 @@ pub(super) async fn handle_exec_sandbox(
737737
let (tx, rx) = mpsc::channel::<Result<ExecSandboxEvent, Status>>(256);
738738
tokio::spawn(async move {
739739
// Wait for the supervisor's reverse CONNECT to deliver the relay stream.
740-
let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx)
741-
.await
742-
{
743-
Ok(Ok(Ok(stream))) => stream,
744-
Ok(Ok(Err(status))) => {
745-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandbox: relay target open failed");
746-
let _ = tx.send(Err(status)).await;
747-
return;
748-
}
749-
Ok(Err(_)) => {
750-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped");
751-
let _ = tx
752-
.send(Err(Status::unavailable("relay channel dropped")))
753-
.await;
754-
return;
755-
}
756-
Err(_) => {
757-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay open timed out");
758-
let _ = tx
759-
.send(Err(Status::deadline_exceeded("relay open timed out")))
760-
.await;
761-
return;
762-
}
740+
let Some(relay_stream) =
741+
await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ExecSandbox").await
742+
else {
743+
return;
763744
};
764745

765746
if let Err(err) = stream_exec_over_relay(
@@ -782,6 +763,41 @@ pub(super) async fn handle_exec_sandbox(
782763
Ok(Response::new(ReceiverStream::new(rx)))
783764
}
784765

766+
/// Wait for the supervisor's reverse CONNECT to deliver a relay stream.
767+
///
768+
/// Returns `Some(stream)` on success. On any failure the error is sent on `tx`
769+
/// and `None` is returned; the caller should then `return` immediately.
770+
async fn await_relay_stream<T: Send + 'static>(
771+
relay_rx: oneshot::Receiver<Result<tokio::io::DuplexStream, Status>>,
772+
tx: &mpsc::Sender<Result<T, Status>>,
773+
sandbox_id: &str,
774+
channel_id: &str,
775+
context: &str,
776+
) -> Option<tokio::io::DuplexStream> {
777+
match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx).await {
778+
Ok(Ok(Ok(stream))) => Some(stream),
779+
Ok(Ok(Err(status))) => {
780+
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "{context}: relay target open failed");
781+
let _ = tx.send(Err(status)).await;
782+
None
783+
}
784+
Ok(Err(_)) => {
785+
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay channel dropped");
786+
let _ = tx
787+
.send(Err(Status::unavailable("relay channel dropped")))
788+
.await;
789+
None
790+
}
791+
Err(_) => {
792+
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay open timed out");
793+
let _ = tx
794+
.send(Err(Status::deadline_exceeded("relay open timed out")))
795+
.await;
796+
None
797+
}
798+
}
799+
}
800+
785801
pub(super) async fn handle_forward_tcp(
786802
state: &Arc<ServerState>,
787803
request: Request<tonic::Streaming<TcpForwardFrame>>,
@@ -831,29 +847,10 @@ pub(super) async fn handle_forward_tcp(
831847
let (tx, rx) = mpsc::channel::<Result<TcpForwardFrame, Status>>(256);
832848
tokio::spawn(async move {
833849
let _connection_guard = connection_guard;
834-
let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx)
835-
.await
836-
{
837-
Ok(Ok(Ok(stream))) => stream,
838-
Ok(Ok(Err(status))) => {
839-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ForwardTcp: relay target open failed");
840-
let _ = tx.send(Err(status)).await;
841-
return;
842-
}
843-
Ok(Err(_)) => {
844-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay channel dropped");
845-
let _ = tx
846-
.send(Err(Status::unavailable("relay channel dropped")))
847-
.await;
848-
return;
849-
}
850-
Err(_) => {
851-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay open timed out");
852-
let _ = tx
853-
.send(Err(Status::deadline_exceeded("relay open timed out")))
854-
.await;
855-
return;
856-
}
850+
let Some(relay_stream) =
851+
await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ForwardTcp").await
852+
else {
853+
return;
857854
};
858855

859856
bridge_forward_tcp_stream(inbound, relay_stream, tx, &sandbox_id, &channel_id).await;
@@ -1179,29 +1176,16 @@ pub(super) async fn handle_exec_sandbox_interactive(
11791176

11801177
let (tx, rx) = mpsc::channel::<Result<ExecSandboxEvent, Status>>(256);
11811178
tokio::spawn(async move {
1182-
let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx)
1183-
.await
1184-
{
1185-
Ok(Ok(Ok(stream))) => stream,
1186-
Ok(Ok(Err(status))) => {
1187-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandboxInteractive: relay target open failed");
1188-
let _ = tx.send(Err(status)).await;
1189-
return;
1190-
}
1191-
Ok(Err(_)) => {
1192-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandboxInteractive: relay channel dropped");
1193-
let _ = tx
1194-
.send(Err(Status::unavailable("relay channel dropped")))
1195-
.await;
1196-
return;
1197-
}
1198-
Err(_) => {
1199-
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandboxInteractive: relay open timed out");
1200-
let _ = tx
1201-
.send(Err(Status::deadline_exceeded("relay open timed out")))
1202-
.await;
1203-
return;
1204-
}
1179+
let Some(relay_stream) = await_relay_stream(
1180+
relay_rx,
1181+
&tx,
1182+
&sandbox_id,
1183+
&channel_id,
1184+
"ExecSandboxInteractive",
1185+
)
1186+
.await
1187+
else {
1188+
return;
12051189
};
12061190

12071191
if let Err(err) = stream_interactive_exec_over_relay(

0 commit comments

Comments
 (0)