Skip to content

Commit 1678fe7

Browse files
committed
fix(cli): validate background forward cleanup pid
Signed-off-by: Shiju <shiju@nvidia.com>
1 parent 301909f commit 1678fe7

2 files changed

Lines changed: 184 additions & 32 deletions

File tree

crates/openshell-cli/src/ssh.rs

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ use nix::sys::signal::{SaFlags, SigAction, SigHandler, SigSet, Signal, sigaction
1010
#[cfg(unix)]
1111
use nix::unistd::Pid;
1212
use openshell_core::ObjectId;
13+
#[cfg(unix)]
14+
use openshell_core::forward::pid_matches_forward;
1315
use openshell_core::forward::{
1416
ForwardSpec, build_proxy_command, find_ssh_forward_pid, format_gateway_url,
1517
resolve_ssh_gateway, shell_escape, validate_ssh_session_response, write_forward_pid,
@@ -31,11 +33,15 @@ use tokio::net::TcpStream;
3133
use tokio::process::{Child, Command as TokioCommand};
3234
use tokio_stream::wrappers::ReceiverStream;
3335

34-
/// Time budget for a forward to become healthy after `ssh` starts: covers both
35-
/// backgrounded-PID discovery and listener readiness, in foreground and
36-
/// background.
37-
const FORWARD_STARTUP_GRACE_PERIOD: Duration = Duration::from_secs(2);
38-
/// Delay between listener/PID probes within the grace period.
36+
/// Time budget for finding the OpenSSH background process after `ssh -f`
37+
/// returns. PID discovery is separate from listener readiness so missing
38+
/// process tracking still fails quickly.
39+
const FORWARD_PID_DISCOVERY_TIMEOUT: Duration = Duration::from_secs(2);
40+
/// Time budget for the local listener to become reachable after `ssh` starts.
41+
/// This is a user-visible readiness deadline for both foreground and background
42+
/// forwards, not a soft cleanup grace period.
43+
const FORWARD_LISTENER_READINESS_TIMEOUT: Duration = Duration::from_secs(10);
44+
/// Delay between listener/PID probes within the configured timeout.
3945
const FORWARD_LISTENER_PROBE_INTERVAL: Duration = Duration::from_millis(50);
4046
/// Per-attempt connect timeout, so one hung probe cannot consume the whole
4147
/// grace period.
@@ -378,11 +384,11 @@ pub async fn sandbox_forward(
378384
)
379385
})?;
380386

381-
if let Err(err) = wait_for_forward_listener(spec, FORWARD_STARTUP_GRACE_PERIOD)
387+
if let Err(err) = wait_for_forward_listener(spec, FORWARD_LISTENER_READINESS_TIMEOUT)
382388
.await
383389
.wrap_err("ssh process started but local forward listener was not reachable")
384390
{
385-
terminate_forward_pid(pid);
391+
terminate_forward_pid(pid, port, &session.sandbox_id);
386392
return Err(err);
387393
}
388394

@@ -412,7 +418,7 @@ pub async fn sandbox_forward(
412418
/// session) means forwarding never came up, so it errors instead of waiting
413419
/// out the grace period.
414420
async fn wait_for_foreground_forward_start(child: &mut Child, spec: &ForwardSpec) -> Result<()> {
415-
let listener = wait_for_forward_listener(spec, FORWARD_STARTUP_GRACE_PERIOD);
421+
let listener = wait_for_forward_listener(spec, FORWARD_LISTENER_READINESS_TIMEOUT);
416422
tokio::pin!(listener);
417423
tokio::select! {
418424
result = &mut listener => result,
@@ -439,7 +445,7 @@ async fn wait_for_foreground_forward_start(child: &mut Child, spec: &ForwardSpec
439445
/// so the PID is unknown when `command.status()` returns and must be discovered
440446
/// afterward. Returns `None` if it never appears within the grace period.
441447
async fn wait_for_ssh_forward_pid(sandbox_id: &str, port: u16) -> Option<u32> {
442-
let deadline = tokio::time::Instant::now() + FORWARD_STARTUP_GRACE_PERIOD;
448+
let deadline = tokio::time::Instant::now() + FORWARD_PID_DISCOVERY_TIMEOUT;
443449
loop {
444450
if let Some(pid) = find_ssh_forward_pid(sandbox_id, port) {
445451
return Some(pid);
@@ -513,20 +519,26 @@ fn forward_probe_host(spec: &ForwardSpec) -> &str {
513519
/// are ignored: the process may already be exiting, and the caller surfaces the
514520
/// original listener error regardless.
515521
#[cfg(unix)]
516-
fn terminate_forward_pid(pid: u32) {
522+
fn terminate_forward_pid(pid: u32, port: u16, sandbox_id: &str) {
517523
let Ok(raw_pid) = i32::try_from(pid) else {
518524
return;
519525
};
520526
if raw_pid <= 0 {
521527
return;
522528
}
529+
if !pid_matches_forward(pid, port, Some(sandbox_id)) {
530+
// The PID came from a process-table scan, not a file we own. Re-check
531+
// immediately before signaling so a stale or spoofed match is left
532+
// untouched instead of terminating an unrelated process.
533+
return;
534+
}
523535

524536
let _ = nix::sys::signal::kill(Pid::from_raw(raw_pid), Signal::SIGTERM);
525537
}
526538

527539
/// Non-Unix builds cannot manage OpenSSH process IDs with Unix signals.
528540
#[cfg(not(unix))]
529-
fn terminate_forward_pid(_pid: u32) {}
541+
fn terminate_forward_pid(_pid: u32, _port: u16, _sandbox_id: &str) {}
530542

531543
fn foreground_forward_started_message(name: &str, spec: &ForwardSpec) -> String {
532544
format!(
@@ -1793,6 +1805,48 @@ mod tests {
17931805
assert!(text.contains("local forward listener did not open"));
17941806
}
17951807

1808+
#[cfg(unix)]
1809+
#[test]
1810+
fn terminate_forward_pid_skips_process_that_no_longer_matches_forward() {
1811+
let dir = tempfile::tempdir().unwrap();
1812+
let terminated_path = dir.path().join("terminated");
1813+
let mut child = Command::new("python3")
1814+
.arg("-c")
1815+
.arg(
1816+
r#"
1817+
import pathlib
1818+
import signal
1819+
import sys
1820+
import time
1821+
1822+
terminated_path = pathlib.Path(sys.argv[1])
1823+
1824+
def stop(_signum, _frame):
1825+
terminated_path.write_text("terminated")
1826+
raise SystemExit(0)
1827+
1828+
signal.signal(signal.SIGTERM, stop)
1829+
1830+
while True:
1831+
time.sleep(1)
1832+
"#,
1833+
)
1834+
.arg(&terminated_path)
1835+
.spawn()
1836+
.unwrap();
1837+
std::thread::sleep(Duration::from_millis(100));
1838+
1839+
terminate_forward_pid(child.id(), 43210, "id-spoofed-forward");
1840+
std::thread::sleep(Duration::from_millis(200));
1841+
1842+
assert!(
1843+
!terminated_path.exists(),
1844+
"mismatched process should not receive SIGTERM"
1845+
);
1846+
let _ = child.kill();
1847+
let _ = child.wait();
1848+
}
1849+
17961850
#[test]
17971851
fn split_sandbox_path_separates_parent_and_basename() {
17981852
assert_eq!(

crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs

Lines changed: 119 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -727,10 +727,17 @@ fn install_fake_pgrep_no_match(dir: &TempDir) -> std::path::PathBuf {
727727
install_executable_script(dir, "pgrep", "#!/bin/sh\nexit 1\n")
728728
}
729729

730-
async fn wait_for_file(path: &std::path::Path, timeout: Duration) -> bool {
730+
async fn wait_for_process_exit(pid: u32, timeout: Duration) -> bool {
731731
let deadline = Instant::now() + timeout;
732732
loop {
733-
if path.exists() {
733+
let alive = std::process::Command::new("ps")
734+
.arg("-p")
735+
.arg(pid.to_string())
736+
.stdout(std::process::Stdio::null())
737+
.stderr(std::process::Stdio::null())
738+
.status()
739+
.is_ok_and(|status| status.success());
740+
if !alive {
734741
return true;
735742
}
736743
if Instant::now() >= deadline {
@@ -844,39 +851,87 @@ exit 1
844851
ssh_path
845852
}
846853

847-
fn install_fake_unreachable_forwarding_ssh(dir: &TempDir) -> std::path::PathBuf {
854+
struct FakeUnreachableForward {
855+
log_path: std::path::PathBuf,
856+
pid_path: std::path::PathBuf,
857+
}
858+
859+
fn install_fake_unreachable_forwarding_ssh(dir: &TempDir) -> FakeUnreachableForward {
860+
let log_path = dir.path().join("fake-forward.log");
848861
let pid_path = dir.path().join("fake-forward.pid");
849-
let terminated_path = dir.path().join("fake-forward.terminated");
862+
let ready_path = dir.path().join("fake-forward.ready");
850863
install_executable_script(
851864
dir,
852865
"ssh",
853866
r#"#!/bin/sh
854867
set -eu
855868
856-
nohup python3 -c '
869+
forward=""
870+
sandbox_id=""
871+
previous=""
872+
873+
for arg in "$@"; do
874+
if [ "$previous" = "-L" ]; then
875+
forward="$arg"
876+
previous=""
877+
continue
878+
fi
879+
880+
if [ "$previous" = "-o" ]; then
881+
case "$arg" in
882+
ProxyCommand=*)
883+
sandbox_id="$(printf '%s\n' "$arg" | sed -n 's/.*--sandbox-id \([^ ]*\).*/\1/p')"
884+
;;
885+
esac
886+
previous=""
887+
continue
888+
fi
889+
890+
case "$arg" in
891+
-L|-o)
892+
previous="$arg"
893+
;;
894+
esac
895+
done
896+
897+
if [ -z "$forward" ] || [ -z "$sandbox_id" ]; then
898+
exit 1
899+
fi
900+
901+
trap '' HUP
902+
python3 -c '
857903
import pathlib
858904
import signal
859905
import sys
860906
import time
861907
862-
terminated_path = pathlib.Path(sys.argv[1])
908+
ready_path = pathlib.Path(sys.argv[1])
863909
864-
def stop(_signum, _frame):
865-
terminated_path.write_text("terminated")
866-
raise SystemExit(0)
867-
868-
signal.signal(signal.SIGTERM, stop)
869-
signal.signal(signal.SIGINT, stop)
870910
signal.signal(signal.SIGHUP, signal.SIG_IGN)
871911
912+
ready_path.write_text("ready")
913+
872914
while True:
873915
time.sleep(1)
874-
' '@TERMINATED_PATH@' >/dev/null 2>&1 &
875-
echo $! > '@PID_PATH@'
916+
' '@READY_PATH@' ssh ssh-proxy --sandbox-id "$sandbox_id" -L "$forward" >'@LOG_PATH@' 2>&1 &
917+
pid="$!"
918+
i=0
919+
while [ "$i" -lt 100 ]; do
920+
if [ -e '@READY_PATH@' ]; then
921+
break
922+
fi
923+
i=$((i + 1))
924+
sleep 0.05
925+
done
926+
if [ ! -e '@READY_PATH@' ]; then
927+
exit 1
928+
fi
929+
echo "$pid" > '@PID_PATH@'
876930
877931
exit 0
878932
"#
879-
.replace("@TERMINATED_PATH@", &terminated_path.display().to_string())
933+
.replace("@LOG_PATH@", &log_path.display().to_string())
934+
.replace("@READY_PATH@", &ready_path.display().to_string())
880935
.replace("@PID_PATH@", &pid_path.display().to_string()),
881936
);
882937

@@ -896,7 +951,7 @@ exit 1
896951
),
897952
);
898953

899-
terminated_path
954+
FakeUnreachableForward { log_path, pid_path }
900955
}
901956

902957
fn test_env(fake_ssh_dir: &TempDir, xdg_dir: &TempDir) -> EnvVarGuard {
@@ -1553,14 +1608,37 @@ async fn sandbox_forward_background_fails_when_pid_is_not_discoverable() {
15531608
);
15541609
}
15551610

1611+
#[tokio::test]
1612+
async fn sandbox_forward_foreground_fails_when_ssh_exits_before_listener_opens() {
1613+
let server = run_server().await;
1614+
let fake_ssh_dir = tempfile::tempdir().unwrap();
1615+
let xdg_dir = tempfile::tempdir().unwrap();
1616+
let _env = test_env(&fake_ssh_dir, &xdg_dir);
1617+
let tls = test_tls(&server);
1618+
install_fake_ssh(&fake_ssh_dir);
1619+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1620+
let forward_port = listener.local_addr().unwrap().port();
1621+
drop(listener);
1622+
1623+
let spec = openshell_core::forward::ForwardSpec::new(forward_port);
1624+
let err = run::sandbox_forward(&server.endpoint, "foreground-forward", &spec, false, &tls)
1625+
.await
1626+
.expect_err("foreground forward should fail when ssh exits before listener readiness");
1627+
let msg = format!("{err}");
1628+
assert!(
1629+
msg.contains("ssh exited before local forward listener opened"),
1630+
"error should explain that ssh exited before listener readiness, got: {msg}",
1631+
);
1632+
}
1633+
15561634
#[tokio::test]
15571635
async fn sandbox_forward_background_terminates_discovered_pid_when_listener_never_opens() {
15581636
let server = run_server().await;
15591637
let fake_ssh_dir = tempfile::tempdir().unwrap();
15601638
let xdg_dir = tempfile::tempdir().unwrap();
15611639
let _env = test_env(&fake_ssh_dir, &xdg_dir);
15621640
let tls = test_tls(&server);
1563-
let terminated_path = install_fake_unreachable_forwarding_ssh(&fake_ssh_dir);
1641+
let fake_forward = install_fake_unreachable_forwarding_ssh(&fake_ssh_dir);
15641642
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
15651643
let forward_port = listener.local_addr().unwrap().port();
15661644
drop(listener);
@@ -1578,10 +1656,30 @@ async fn sandbox_forward_background_terminates_discovered_pid_when_listener_neve
15781656
openshell_core::forward::read_forward_pid("unreachable-forward", forward_port).is_none(),
15791657
"unreachable background forwards must not write a PID file",
15801658
);
1581-
assert!(
1582-
wait_for_file(&terminated_path, Duration::from_secs(2)).await,
1583-
"discovered background SSH process should be terminated after listener failure",
1584-
);
1659+
let pid = fs::read_to_string(&fake_forward.pid_path)
1660+
.expect("fake forward should record a PID")
1661+
.trim()
1662+
.parse::<u32>()
1663+
.expect("fake forward PID should be numeric");
1664+
if !wait_for_process_exit(pid, Duration::from_secs(2)).await {
1665+
let log = fs::read_to_string(&fake_forward.log_path).unwrap_or_default();
1666+
let command = std::process::Command::new("ps")
1667+
.arg("-ww")
1668+
.arg("-o")
1669+
.arg("command=")
1670+
.arg("-p")
1671+
.arg(pid.to_string())
1672+
.output()
1673+
.ok()
1674+
.map(|output| String::from_utf8_lossy(&output.stdout).to_string())
1675+
.unwrap_or_default();
1676+
panic!(
1677+
"discovered background SSH process should exit after listener failure cleanup; pid={}, command={}, log={}",
1678+
pid,
1679+
command.trim(),
1680+
log.trim(),
1681+
);
1682+
}
15851683
}
15861684

15871685
#[tokio::test]

0 commit comments

Comments
 (0)