Skip to content

Commit 8c5b43b

Browse files
committed
make port forwarding wait instead of exit
1 parent 54f0b98 commit 8c5b43b

5 files changed

Lines changed: 64 additions & 51 deletions

File tree

coman/src/cli/exec.rs

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ use base64::prelude::*;
44
use color_eyre::Result;
55
use iroh::{
66
Endpoint, SecretKey,
7+
endpoint::ConnectionError,
78
protocol::{ProtocolHandler, Router},
89
};
9-
use iroh_ssh::IrohSsh;
1010
use pid1::Pid1Settings;
1111
use rust_supervisor::{ChildType, Supervisor, SupervisorConfig};
12-
use tokio::{net::TcpStream, task::JoinSet};
12+
use tokio::{io::AsyncWriteExt, net::TcpStream};
1313

1414
const SECRET_KEY_ENV: &str = "COMAN_IROH_SECRET";
1515
const PORT_FORWARD_ENV: &str = "COMAN_FORWARDED_PORTS";
16+
const SSH_PORT: u16 = 15263;
1617

1718
fn get_secret_key() -> Option<Vec<u8>> {
1819
if let Ok(secret) = std::env::var(SECRET_KEY_ENV) {
@@ -23,19 +24,6 @@ fn get_secret_key() -> Option<Vec<u8>> {
2324
}
2425
}
2526

26-
#[tokio::main]
27-
async fn run_ssh() -> Result<()> {
28-
let mut builder = IrohSsh::builder().accept_incoming(true).accept_port(15263);
29-
if let Some(secret_key) = get_secret_key() {
30-
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
31-
builder = builder.secret_key(secret_key);
32-
}
33-
let server = builder.build().await.expect("couldn't create iroh server");
34-
println!("{}@{}", whoami::username(), server.node_id());
35-
tokio::signal::ctrl_c().await?;
36-
Ok(())
37-
}
38-
3927
#[derive(Debug)]
4028
struct PortForwardHandler {
4129
port: u16,
@@ -56,7 +44,15 @@ impl ProtocolHandler for PortForwardHandler {
5644

5745
let (mut local_read, mut local_write) = output_stream.split();
5846

59-
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
47+
let a_to_b = async move {
48+
let res = tokio::io::copy(&mut local_read, &mut iroh_send).await;
49+
if res.is_ok() {
50+
iroh_send.flush().await.expect("couldn't flush stream");
51+
iroh_send.finish().expect("couldn't finish stream");
52+
iroh_send.stopped().await.expect("stream not properly stopped");
53+
}
54+
res
55+
};
6056
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
6157

6258
tokio::select! {
@@ -67,6 +63,19 @@ impl ProtocolHandler for PortForwardHandler {
6763
println!("Iroh->{port} stream ended: {result:?}");
6864
},
6965
};
66+
// wait for client to close connection so we don't close prematurely
67+
let res = tokio::time::timeout(Duration::from_secs(3), async move {
68+
let closed = connection.closed().await;
69+
if !matches!(closed, ConnectionError::ApplicationClosed(_)) {
70+
println!("endpoint disconnected witn an error: {closed:#}");
71+
} else {
72+
println!("connection closed");
73+
}
74+
})
75+
.await;
76+
if res.is_err() {
77+
println!("endpoint did not disconnect within 3 seconds");
78+
}
7079
}
7180
Err(e) => {
7281
println!("Failed to connect to local server {port}: {e}");
@@ -88,30 +97,35 @@ async fn port_forward() -> Result<()> {
8897
};
8998
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
9099
let secret_key = SecretKey::from_bytes(secret_key);
91-
if let Ok(forwarded_ports) = std::env::var(PORT_FORWARD_ENV) {
92-
println!("setting up port forwarding...");
93-
let mut join_set = JoinSet::new();
94-
for port in forwarded_ports.split(',') {
95-
let alpn: Vec<u8> = format!("/coman/{port}").into_bytes();
96-
let endpoint = Endpoint::builder()
97-
.secret_key(secret_key.clone())
98-
.alpns(vec![alpn.clone()])
99-
.bind()
100-
.await?;
101-
102-
let port = port.to_owned();
103-
join_set.spawn(async move {
104-
let handler = PortForwardHandler {
105-
port: port.parse::<u16>().expect("couldn't parse port"),
106-
};
107-
Router::builder(endpoint.clone()).accept(&alpn, handler).spawn();
108-
});
109-
}
110-
while let Some(res) = join_set.join_next().await {
111-
println!("Task joined: {res:?}");
112-
}
100+
let mut forwarded_ports = vec!["ssh".to_owned()];
101+
if let Ok(env_ports) = std::env::var(PORT_FORWARD_ENV) {
102+
forwarded_ports.extend(env_ports.split(',').map(|p| p.to_owned()).collect::<Vec<String>>());
103+
}
104+
let endpoint = Endpoint::builder().secret_key(secret_key.clone()).bind().await?;
105+
let id = endpoint.id();
106+
println!("endpoint: {id}");
107+
108+
println!("setting up port forwarding...");
109+
let mut builder = Router::builder(endpoint.clone());
110+
for port in forwarded_ports {
111+
let (port, alpn) = if port == "ssh" {
112+
(SSH_PORT, "/iroh/ssh".to_string())
113+
} else {
114+
(
115+
port.parse::<u16>().expect("couldn't parse port"),
116+
format!("/coman/{port}"),
117+
)
118+
};
119+
120+
let handler = PortForwardHandler { port };
121+
builder = builder.accept(alpn.clone().into_bytes(), handler);
122+
println!("set up port forwarding for port {port} ({alpn})");
113123
}
124+
let _router = builder.spawn();
125+
println!("port forwarding started");
114126

127+
let _ = tokio::signal::ctrl_c().await;
128+
println!("port forwarding stopped");
115129
Ok(())
116130
}
117131

@@ -125,11 +139,6 @@ pub(crate) async fn cli_exec_command(command: Vec<String>) -> Result<()> {
125139
.expect("Launch failed");
126140

127141
let mut supervisor = Supervisor::new(SupervisorConfig::default());
128-
supervisor.add_process("iroh-ssh", ChildType::Permanent, || {
129-
thread::spawn(|| {
130-
let _ = run_ssh();
131-
})
132-
});
133142
supervisor.add_process("port-forward", ChildType::Permanent, || {
134143
thread::spawn(|| {
135144
let _ = port_forward();

coman/src/cli/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
pub mod app;
22
pub mod exec;
3-
pub mod port_forward;
43
pub mod proxy;

coman/src/cli/port_forward.rs

Lines changed: 0 additions & 1 deletion
This file was deleted.

coman/src/cscs/cli.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ pub(crate) async fn cli_cscs_port_forward(
132132
platform: Option<ComputePlatform>,
133133
) -> Result<()> {
134134
let job_id = maybe_job_id_from_name(job, system.clone(), platform.clone()).await?;
135+
println!("running port forward for job {job_id}");
135136
cscs_port_forward(job_id, source_port, destination_port, system).await
136137
}
137138

coman/src/cscs/handlers.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use base64::prelude::*;
1414
use color_eyre::{Result, eyre::eyre};
1515
use eyre::Context;
1616
use futures::StreamExt;
17-
use iroh::{Endpoint, EndpointId, SecretKey, protocol::Router};
17+
use iroh::{Endpoint, EndpointId, SecretKey};
1818
use itertools::Itertools;
1919
use regex::Regex;
2020
use reqwest::Url;
@@ -202,25 +202,30 @@ async fn process_port_forward(endpoint_id: EndpointId, destination_port: u16, mu
202202
let alpn: Vec<u8> = format!("/coman/{destination_port}").into_bytes();
203203
let secret_key = SecretKey::generate(&mut rand::rng());
204204
let endpoint = Endpoint::builder().secret_key(secret_key).bind().await?;
205-
Router::builder(endpoint.clone()).spawn(); // start local iroh listener
205+
// let _router = Router::builder(endpoint.clone()).spawn(); // start local iroh listener
206206

207207
match endpoint.connect(endpoint_id, &alpn).await {
208208
Ok(connection) => {
209209
let (mut iroh_send, mut iroh_recv) = connection.open_bi().await?;
210210
let (mut local_read, mut local_write) = socket.split();
211211
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
212-
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
212+
let b_to_a = async move {
213+
let res = tokio::io::copy(&mut iroh_recv, &mut local_write).await;
214+
if res.is_ok() {
215+
local_write.flush().await.expect("couldn't flush socket");
216+
}
217+
res
218+
};
213219
println!("connection open");
214220

215221
tokio::select! {
216222
result = a_to_b => {
217-
let _ = result;
223+
let _= dbg!(result);
218224
},
219225
result = b_to_a => {
220-
let _ = result;
226+
let _= dbg!(result);
221227
},
222228
};
223-
println!("connection closed");
224229

225230
Ok(())
226231
}

0 commit comments

Comments
 (0)