Skip to content

Commit 1240a42

Browse files
committed
make port forwarding wait instead of exit
1 parent 54f0b98 commit 1240a42

5 files changed

Lines changed: 62 additions & 51 deletions

File tree

coman/src/cli/exec.rs

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ use base64::prelude::*;
44
use color_eyre::Result;
55
use iroh::{
66
Endpoint, SecretKey,
7+
endpoint::ConnectionError,
78
protocol::{ProtocolHandler, Router},
89
};
9-
use iroh_ssh::IrohSsh;
1010
use pid1::Pid1Settings;
1111
use rust_supervisor::{ChildType, Supervisor, SupervisorConfig};
12-
use tokio::{net::TcpStream, task::JoinSet};
12+
use tokio::{io::AsyncWriteExt, net::TcpStream};
1313

1414
const SECRET_KEY_ENV: &str = "COMAN_IROH_SECRET";
1515
const PORT_FORWARD_ENV: &str = "COMAN_FORWARDED_PORTS";
16+
const SSH_PORT: u16 = 15263;
1617

1718
fn get_secret_key() -> Option<Vec<u8>> {
1819
if let Ok(secret) = std::env::var(SECRET_KEY_ENV) {
@@ -23,19 +24,6 @@ fn get_secret_key() -> Option<Vec<u8>> {
2324
}
2425
}
2526

26-
#[tokio::main]
27-
async fn run_ssh() -> Result<()> {
28-
let mut builder = IrohSsh::builder().accept_incoming(true).accept_port(15263);
29-
if let Some(secret_key) = get_secret_key() {
30-
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
31-
builder = builder.secret_key(secret_key);
32-
}
33-
let server = builder.build().await.expect("couldn't create iroh server");
34-
println!("{}@{}", whoami::username(), server.node_id());
35-
tokio::signal::ctrl_c().await?;
36-
Ok(())
37-
}
38-
3927
#[derive(Debug)]
4028
struct PortForwardHandler {
4129
port: u16,
@@ -56,7 +44,13 @@ impl ProtocolHandler for PortForwardHandler {
5644

5745
let (mut local_read, mut local_write) = output_stream.split();
5846

59-
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
47+
let a_to_b = async move {
48+
let res = tokio::io::copy(&mut local_read, &mut iroh_send).await;
49+
if res.is_ok() {
50+
iroh_send.flush().await.expect("couldn't flush stream");
51+
}
52+
res
53+
};
6054
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
6155

6256
tokio::select! {
@@ -67,6 +61,19 @@ impl ProtocolHandler for PortForwardHandler {
6761
println!("Iroh->{port} stream ended: {result:?}");
6862
},
6963
};
64+
// wait for client to close connection so we don't close prematurely
65+
let res = tokio::time::timeout(Duration::from_secs(3), async move {
66+
let closed = connection.closed().await;
67+
if !matches!(closed, ConnectionError::ApplicationClosed(_)) {
68+
println!("endpoint disconnected witn an error: {closed:#}");
69+
} else {
70+
println!("connection closed");
71+
}
72+
})
73+
.await;
74+
if res.is_err() {
75+
println!("endpoint did not disconnect within 3 seconds");
76+
}
7077
}
7178
Err(e) => {
7279
println!("Failed to connect to local server {port}: {e}");
@@ -88,30 +95,35 @@ async fn port_forward() -> Result<()> {
8895
};
8996
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
9097
let secret_key = SecretKey::from_bytes(secret_key);
91-
if let Ok(forwarded_ports) = std::env::var(PORT_FORWARD_ENV) {
92-
println!("setting up port forwarding...");
93-
let mut join_set = JoinSet::new();
94-
for port in forwarded_ports.split(',') {
95-
let alpn: Vec<u8> = format!("/coman/{port}").into_bytes();
96-
let endpoint = Endpoint::builder()
97-
.secret_key(secret_key.clone())
98-
.alpns(vec![alpn.clone()])
99-
.bind()
100-
.await?;
101-
102-
let port = port.to_owned();
103-
join_set.spawn(async move {
104-
let handler = PortForwardHandler {
105-
port: port.parse::<u16>().expect("couldn't parse port"),
106-
};
107-
Router::builder(endpoint.clone()).accept(&alpn, handler).spawn();
108-
});
109-
}
110-
while let Some(res) = join_set.join_next().await {
111-
println!("Task joined: {res:?}");
112-
}
98+
let mut forwarded_ports = vec!["ssh".to_owned()];
99+
if let Ok(env_ports) = std::env::var(PORT_FORWARD_ENV) {
100+
forwarded_ports.extend(env_ports.split(',').map(|p| p.to_owned()).collect::<Vec<String>>());
101+
}
102+
let endpoint = Endpoint::builder().secret_key(secret_key.clone()).bind().await?;
103+
let id = endpoint.id();
104+
println!("endpoint: {id}");
105+
106+
println!("setting up port forwarding...");
107+
let mut builder = Router::builder(endpoint.clone());
108+
for port in forwarded_ports {
109+
let (port, alpn) = if port == "ssh" {
110+
(SSH_PORT, "/iroh/ssh".to_string())
111+
} else {
112+
(
113+
port.parse::<u16>().expect("couldn't parse port"),
114+
format!("/coman/{port}"),
115+
)
116+
};
117+
118+
let handler = PortForwardHandler { port };
119+
builder = builder.accept(alpn.clone().into_bytes(), handler);
120+
println!("set up port forwarding for port {port} ({alpn})");
113121
}
122+
let _router = builder.spawn();
123+
println!("port forwarding started");
114124

125+
let _ = tokio::signal::ctrl_c().await;
126+
println!("port forwarding stopped");
115127
Ok(())
116128
}
117129

@@ -125,11 +137,6 @@ pub(crate) async fn cli_exec_command(command: Vec<String>) -> Result<()> {
125137
.expect("Launch failed");
126138

127139
let mut supervisor = Supervisor::new(SupervisorConfig::default());
128-
supervisor.add_process("iroh-ssh", ChildType::Permanent, || {
129-
thread::spawn(|| {
130-
let _ = run_ssh();
131-
})
132-
});
133140
supervisor.add_process("port-forward", ChildType::Permanent, || {
134141
thread::spawn(|| {
135142
let _ = port_forward();

coman/src/cli/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
pub mod app;
22
pub mod exec;
3-
pub mod port_forward;
43
pub mod proxy;

coman/src/cli/port_forward.rs

Lines changed: 0 additions & 1 deletion
This file was deleted.

coman/src/cscs/cli.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ pub(crate) async fn cli_cscs_port_forward(
132132
platform: Option<ComputePlatform>,
133133
) -> Result<()> {
134134
let job_id = maybe_job_id_from_name(job, system.clone(), platform.clone()).await?;
135+
println!("running port forward for job {job_id}");
135136
cscs_port_forward(job_id, source_port, destination_port, system).await
136137
}
137138

coman/src/cscs/handlers.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use base64::prelude::*;
1414
use color_eyre::{Result, eyre::eyre};
1515
use eyre::Context;
1616
use futures::StreamExt;
17-
use iroh::{Endpoint, EndpointId, SecretKey, protocol::Router};
17+
use iroh::{Endpoint, EndpointId, SecretKey};
1818
use itertools::Itertools;
1919
use regex::Regex;
2020
use reqwest::Url;
@@ -202,25 +202,30 @@ async fn process_port_forward(endpoint_id: EndpointId, destination_port: u16, mu
202202
let alpn: Vec<u8> = format!("/coman/{destination_port}").into_bytes();
203203
let secret_key = SecretKey::generate(&mut rand::rng());
204204
let endpoint = Endpoint::builder().secret_key(secret_key).bind().await?;
205-
Router::builder(endpoint.clone()).spawn(); // start local iroh listener
205+
// let _router = Router::builder(endpoint.clone()).spawn(); // start local iroh listener
206206

207207
match endpoint.connect(endpoint_id, &alpn).await {
208208
Ok(connection) => {
209209
let (mut iroh_send, mut iroh_recv) = connection.open_bi().await?;
210210
let (mut local_read, mut local_write) = socket.split();
211211
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
212-
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
212+
let b_to_a = async move {
213+
let res = tokio::io::copy(&mut iroh_recv, &mut local_write).await;
214+
if res.is_ok() {
215+
local_write.flush().await.expect("couldn't flush socket");
216+
}
217+
res
218+
};
213219
println!("connection open");
214220

215221
tokio::select! {
216222
result = a_to_b => {
217-
let _ = result;
223+
let _= dbg!(result);
218224
},
219225
result = b_to_a => {
220-
let _ = result;
226+
let _= dbg!(result);
221227
},
222228
};
223-
println!("connection closed");
224229

225230
Ok(())
226231
}

0 commit comments

Comments
 (0)