Skip to content

Commit 01c0ae5

Browse files
committed
make port forwarding wait instead of exit
1 parent 54f0b98 commit 01c0ae5

2 files changed

Lines changed: 25 additions & 42 deletions

File tree

coman/src/cli/exec.rs

Lines changed: 24 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ use iroh::{
66
Endpoint, SecretKey,
77
protocol::{ProtocolHandler, Router},
88
};
9-
use iroh_ssh::IrohSsh;
109
use pid1::Pid1Settings;
1110
use rust_supervisor::{ChildType, Supervisor, SupervisorConfig};
12-
use tokio::{net::TcpStream, task::JoinSet};
11+
use tokio::net::TcpStream;
1312

1413
const SECRET_KEY_ENV: &str = "COMAN_IROH_SECRET";
1514
const PORT_FORWARD_ENV: &str = "COMAN_FORWARDED_PORTS";
15+
const SSH_PORT: u16 = 15263;
1616

1717
fn get_secret_key() -> Option<Vec<u8>> {
1818
if let Ok(secret) = std::env::var(SECRET_KEY_ENV) {
@@ -23,19 +23,6 @@ fn get_secret_key() -> Option<Vec<u8>> {
2323
}
2424
}
2525

26-
#[tokio::main]
27-
async fn run_ssh() -> Result<()> {
28-
let mut builder = IrohSsh::builder().accept_incoming(true).accept_port(15263);
29-
if let Some(secret_key) = get_secret_key() {
30-
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
31-
builder = builder.secret_key(secret_key);
32-
}
33-
let server = builder.build().await.expect("couldn't create iroh server");
34-
println!("{}@{}", whoami::username(), server.node_id());
35-
tokio::signal::ctrl_c().await?;
36-
Ok(())
37-
}
38-
3926
#[derive(Debug)]
4027
struct PortForwardHandler {
4128
port: u16,
@@ -88,30 +75,30 @@ async fn port_forward() -> Result<()> {
8875
};
8976
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
9077
let secret_key = SecretKey::from_bytes(secret_key);
91-
if let Ok(forwarded_ports) = std::env::var(PORT_FORWARD_ENV) {
92-
println!("setting up port forwarding...");
93-
let mut join_set = JoinSet::new();
94-
for port in forwarded_ports.split(',') {
95-
let alpn: Vec<u8> = format!("/coman/{port}").into_bytes();
96-
let endpoint = Endpoint::builder()
97-
.secret_key(secret_key.clone())
98-
.alpns(vec![alpn.clone()])
99-
.bind()
100-
.await?;
101-
102-
let port = port.to_owned();
103-
join_set.spawn(async move {
104-
let handler = PortForwardHandler {
105-
port: port.parse::<u16>().expect("couldn't parse port"),
106-
};
107-
Router::builder(endpoint.clone()).accept(&alpn, handler).spawn();
108-
});
109-
}
110-
while let Some(res) = join_set.join_next().await {
111-
println!("Task joined: {res:?}");
112-
}
78+
let mut forwarded_ports = vec!["ssh".to_owned()];
79+
if let Ok(env_ports) = std::env::var(PORT_FORWARD_ENV) {
80+
forwarded_ports.extend(env_ports.split(',').map(|p| p.to_owned()).collect::<Vec<String>>());
81+
}
82+
let endpoint = Endpoint::builder().secret_key(secret_key.clone()).bind().await?;
83+
84+
println!("setting up port forwarding...");
85+
let mut builder = Router::builder(endpoint.clone());
86+
for port in forwarded_ports {
87+
let alpn: Vec<u8> = format!("/coman/{port}").into_bytes();
88+
let port = if port == "ssh" {
89+
SSH_PORT
90+
} else {
91+
port.parse::<u16>().expect("couldn't parse port")
92+
};
93+
94+
let handler = PortForwardHandler { port };
95+
builder = builder.accept(alpn, handler);
96+
println!("set up port forwarding for port {port}");
11397
}
98+
builder.spawn();
99+
println!("port forwarding started");
114100

101+
let _ = tokio::signal::ctrl_c().await;
115102
Ok(())
116103
}
117104

@@ -125,11 +112,6 @@ pub(crate) async fn cli_exec_command(command: Vec<String>) -> Result<()> {
125112
.expect("Launch failed");
126113

127114
let mut supervisor = Supervisor::new(SupervisorConfig::default());
128-
supervisor.add_process("iroh-ssh", ChildType::Permanent, || {
129-
thread::spawn(|| {
130-
let _ = run_ssh();
131-
})
132-
});
133115
supervisor.add_process("port-forward", ChildType::Permanent, || {
134116
thread::spawn(|| {
135117
let _ = port_forward();

coman/src/cscs/cli.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ pub(crate) async fn cli_cscs_port_forward(
132132
platform: Option<ComputePlatform>,
133133
) -> Result<()> {
134134
let job_id = maybe_job_id_from_name(job, system.clone(), platform.clone()).await?;
135+
println!("running port forward for job {job_id}");
135136
cscs_port_forward(job_id, source_port, destination_port, system).await
136137
}
137138

0 commit comments

Comments
 (0)