Skip to content

Commit 1f8f9e5

Browse files
committed
make port forwarding wait instead of exit
1 parent 54f0b98 commit 1f8f9e5

5 files changed

Lines changed: 41 additions & 49 deletions

File tree

coman/src/cli/exec.rs

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ use iroh::{
66
Endpoint, SecretKey,
77
protocol::{ProtocolHandler, Router},
88
};
9-
use iroh_ssh::IrohSsh;
109
use pid1::Pid1Settings;
1110
use rust_supervisor::{ChildType, Supervisor, SupervisorConfig};
12-
use tokio::{net::TcpStream, task::JoinSet};
11+
use tokio::{io::AsyncWriteExt, net::TcpStream};
1312

1413
const SECRET_KEY_ENV: &str = "COMAN_IROH_SECRET";
1514
const PORT_FORWARD_ENV: &str = "COMAN_FORWARDED_PORTS";
15+
const SSH_PORT: u16 = 15263;
1616

1717
fn get_secret_key() -> Option<Vec<u8>> {
1818
if let Ok(secret) = std::env::var(SECRET_KEY_ENV) {
@@ -23,19 +23,6 @@ fn get_secret_key() -> Option<Vec<u8>> {
2323
}
2424
}
2525

26-
#[tokio::main]
27-
async fn run_ssh() -> Result<()> {
28-
let mut builder = IrohSsh::builder().accept_incoming(true).accept_port(15263);
29-
if let Some(secret_key) = get_secret_key() {
30-
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
31-
builder = builder.secret_key(secret_key);
32-
}
33-
let server = builder.build().await.expect("couldn't create iroh server");
34-
println!("{}@{}", whoami::username(), server.node_id());
35-
tokio::signal::ctrl_c().await?;
36-
Ok(())
37-
}
38-
3926
#[derive(Debug)]
4027
struct PortForwardHandler {
4128
port: u16,
@@ -56,7 +43,13 @@ impl ProtocolHandler for PortForwardHandler {
5643

5744
let (mut local_read, mut local_write) = output_stream.split();
5845

59-
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
46+
let a_to_b = async move {
47+
let res = tokio::io::copy(&mut local_read, &mut iroh_send).await;
48+
if res.is_ok() {
49+
iroh_send.flush().await.expect("couldn't flush stream");
50+
}
51+
res
52+
};
6053
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
6154

6255
tokio::select! {
@@ -88,30 +81,35 @@ async fn port_forward() -> Result<()> {
8881
};
8982
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
9083
let secret_key = SecretKey::from_bytes(secret_key);
91-
if let Ok(forwarded_ports) = std::env::var(PORT_FORWARD_ENV) {
92-
println!("setting up port forwarding...");
93-
let mut join_set = JoinSet::new();
94-
for port in forwarded_ports.split(',') {
95-
let alpn: Vec<u8> = format!("/coman/{port}").into_bytes();
96-
let endpoint = Endpoint::builder()
97-
.secret_key(secret_key.clone())
98-
.alpns(vec![alpn.clone()])
99-
.bind()
100-
.await?;
101-
102-
let port = port.to_owned();
103-
join_set.spawn(async move {
104-
let handler = PortForwardHandler {
105-
port: port.parse::<u16>().expect("couldn't parse port"),
106-
};
107-
Router::builder(endpoint.clone()).accept(&alpn, handler).spawn();
108-
});
109-
}
110-
while let Some(res) = join_set.join_next().await {
111-
println!("Task joined: {res:?}");
112-
}
84+
let mut forwarded_ports = vec!["ssh".to_owned()];
85+
if let Ok(env_ports) = std::env::var(PORT_FORWARD_ENV) {
86+
forwarded_ports.extend(env_ports.split(',').map(|p| p.to_owned()).collect::<Vec<String>>());
87+
}
88+
let endpoint = Endpoint::builder().secret_key(secret_key.clone()).bind().await?;
89+
let id = endpoint.id();
90+
println!("endpoint: {id}");
91+
92+
println!("setting up port forwarding...");
93+
let mut builder = Router::builder(endpoint.clone());
94+
for port in forwarded_ports {
95+
let (port, alpn) = if port == "ssh" {
96+
(SSH_PORT, "/iroh/ssh".to_string())
97+
} else {
98+
(
99+
port.parse::<u16>().expect("couldn't parse port"),
100+
format!("/coman/{port}"),
101+
)
102+
};
103+
104+
let handler = PortForwardHandler { port };
105+
builder = builder.accept(alpn.clone().into_bytes(), handler);
106+
println!("set up port forwarding for port {port} ({alpn})");
113107
}
108+
let _router = builder.spawn();
109+
println!("port forwarding started");
114110

111+
let _ = tokio::signal::ctrl_c().await;
112+
println!("port forwarding stopped");
115113
Ok(())
116114
}
117115

@@ -125,11 +123,6 @@ pub(crate) async fn cli_exec_command(command: Vec<String>) -> Result<()> {
125123
.expect("Launch failed");
126124

127125
let mut supervisor = Supervisor::new(SupervisorConfig::default());
128-
supervisor.add_process("iroh-ssh", ChildType::Permanent, || {
129-
thread::spawn(|| {
130-
let _ = run_ssh();
131-
})
132-
});
133126
supervisor.add_process("port-forward", ChildType::Permanent, || {
134127
thread::spawn(|| {
135128
let _ = port_forward();

coman/src/cli/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
pub mod app;
22
pub mod exec;
3-
pub mod port_forward;
43
pub mod proxy;

coman/src/cli/port_forward.rs

Lines changed: 0 additions & 1 deletion
This file was deleted.

coman/src/cscs/cli.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ pub(crate) async fn cli_cscs_port_forward(
132132
platform: Option<ComputePlatform>,
133133
) -> Result<()> {
134134
let job_id = maybe_job_id_from_name(job, system.clone(), platform.clone()).await?;
135+
println!("running port forward for job {job_id}");
135136
cscs_port_forward(job_id, source_port, destination_port, system).await
136137
}
137138

coman/src/cscs/handlers.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use base64::prelude::*;
1414
use color_eyre::{Result, eyre::eyre};
1515
use eyre::Context;
1616
use futures::StreamExt;
17-
use iroh::{Endpoint, EndpointId, SecretKey, protocol::Router};
17+
use iroh::{Endpoint, EndpointId, SecretKey};
1818
use itertools::Itertools;
1919
use regex::Regex;
2020
use reqwest::Url;
@@ -202,7 +202,7 @@ async fn process_port_forward(endpoint_id: EndpointId, destination_port: u16, mu
202202
let alpn: Vec<u8> = format!("/coman/{destination_port}").into_bytes();
203203
let secret_key = SecretKey::generate(&mut rand::rng());
204204
let endpoint = Endpoint::builder().secret_key(secret_key).bind().await?;
205-
Router::builder(endpoint.clone()).spawn(); // start local iroh listener
205+
// let _router = Router::builder(endpoint.clone()).spawn(); // start local iroh listener
206206

207207
match endpoint.connect(endpoint_id, &alpn).await {
208208
Ok(connection) => {
@@ -214,10 +214,10 @@ async fn process_port_forward(endpoint_id: EndpointId, destination_port: u16, mu
214214

215215
tokio::select! {
216216
result = a_to_b => {
217-
let _ = result;
217+
let _= result;
218218
},
219219
result = b_to_a => {
220-
let _ = result;
220+
let _= result;
221221
},
222222
};
223223
println!("connection closed");

0 commit comments

Comments
 (0)