Skip to content

Commit 54f0b98

Browse files
committed
fix portforwarding
1 parent b252cad commit 54f0b98

2 files changed

Lines changed: 66 additions & 37 deletions

File tree

coman/src/cli/exec.rs

Lines changed: 58 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use std::{thread, time::Duration};
22

33
use base64::prelude::*;
44
use color_eyre::Result;
5-
use iroh::{Endpoint, SecretKey};
5+
use iroh::{
6+
Endpoint, SecretKey,
7+
protocol::{ProtocolHandler, Router},
8+
};
69
use iroh_ssh::IrohSsh;
710
use pid1::Pid1Settings;
811
use rust_supervisor::{ChildType, Supervisor, SupervisorConfig};
@@ -29,11 +32,55 @@ async fn run_ssh() -> Result<()> {
2932
}
3033
let server = builder.build().await.expect("couldn't create iroh server");
3134
println!("{}@{}", whoami::username(), server.node_id());
32-
loop {
33-
tokio::time::sleep(Duration::from_secs(60)).await;
34-
}
35+
tokio::signal::ctrl_c().await?;
36+
Ok(())
37+
}
38+
39+
#[derive(Debug)]
40+
struct PortForwardHandler {
41+
port: u16,
3542
}
3643

44+
impl ProtocolHandler for PortForwardHandler {
45+
async fn accept(&self, connection: iroh::endpoint::Connection) -> Result<(), iroh::protocol::AcceptError> {
46+
let endpoint_id = connection.remote_id();
47+
let port = self.port;
48+
49+
match connection.accept_bi().await {
50+
Ok((mut iroh_send, mut iroh_recv)) => {
51+
println!("Accepted bidirectional stream from {endpoint_id}");
52+
53+
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
54+
Ok(mut output_stream) => {
55+
println!("Connected to local server on port {}", port);
56+
57+
let (mut local_read, mut local_write) = output_stream.split();
58+
59+
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
60+
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
61+
62+
tokio::select! {
63+
result = a_to_b => {
64+
println!("{port}->Iroh stream ended: {result:?}");
65+
},
66+
result = b_to_a => {
67+
println!("Iroh->{port} stream ended: {result:?}");
68+
},
69+
};
70+
}
71+
Err(e) => {
72+
println!("Failed to connect to local server {port}: {e}");
73+
}
74+
}
75+
}
76+
Err(e) => {
77+
println!("Failed to accept bidirectional stream {port}: {e}");
78+
}
79+
}
80+
81+
Ok(())
82+
}
83+
}
3784
#[tokio::main]
3885
async fn port_forward() -> Result<()> {
3986
let Some(secret_key) = get_secret_key() else {
@@ -42,45 +89,22 @@ async fn port_forward() -> Result<()> {
4289
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
4390
let secret_key = SecretKey::from_bytes(secret_key);
4491
if let Ok(forwarded_ports) = std::env::var(PORT_FORWARD_ENV) {
92+
println!("setting up port forwarding...");
4593
let mut join_set = JoinSet::new();
4694
for port in forwarded_ports.split(',') {
4795
let alpn: Vec<u8> = format!("/coman/{port}").into_bytes();
4896
let endpoint = Endpoint::builder()
4997
.secret_key(secret_key.clone())
50-
.alpns(vec![alpn])
98+
.alpns(vec![alpn.clone()])
5199
.bind()
52100
.await?;
101+
53102
let port = port.to_owned();
54103
join_set.spawn(async move {
55-
while let Some(incoming) = endpoint.accept().await {
56-
let connection = incoming.await.unwrap();
57-
match connection.accept_bi().await {
58-
Ok((mut iroh_send, mut iroh_recv)) => {
59-
match TcpStream::connect(format!("127.0.0.1:{port}")).await {
60-
Ok(mut stream) => {
61-
let (mut local_read, mut local_write) = stream.split();
62-
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
63-
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
64-
65-
tokio::select! {
66-
result = a_to_b => {
67-
println!("{port}->Iroh stream ended: {result:?}");
68-
},
69-
result = b_to_a => {
70-
println!("Iroh->{port} stream ended: {result:?}");
71-
},
72-
};
73-
}
74-
Err(e) => {
75-
println!("Failed to connect to {port}: {e:?}");
76-
}
77-
}
78-
}
79-
Err(e) => {
80-
println!("Failed to accept stream to {port}: {e:?}");
81-
}
82-
}
83-
}
104+
let handler = PortForwardHandler {
105+
port: port.parse::<u16>().expect("couldn't parse port"),
106+
};
107+
Router::builder(endpoint.clone()).accept(&alpn, handler).spawn();
84108
});
85109
}
86110
while let Some(res) = join_set.join_next().await {

coman/src/cscs/handlers.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use base64::prelude::*;
1414
use color_eyre::{Result, eyre::eyre};
1515
use eyre::Context;
1616
use futures::StreamExt;
17-
use iroh::{Endpoint, EndpointId, SecretKey};
17+
use iroh::{Endpoint, EndpointId, SecretKey, protocol::Router};
1818
use itertools::Itertools;
1919
use regex::Regex;
2020
use reqwest::Url;
@@ -189,6 +189,7 @@ pub async fn cscs_port_forward(
189189
return Err(eyre!("invalid endpoint id length"));
190190
})?;
191191
let listener = TcpListener::bind(format!("127.0.0.1:{source_port}")).await?;
192+
println!("forwarding connection for port {source_port}");
192193

193194
loop {
194195
let (socket, _) = listener.accept().await?;
@@ -197,16 +198,19 @@ pub async fn cscs_port_forward(
197198
}
198199

199200
async fn process_port_forward(endpoint_id: EndpointId, destination_port: u16, mut socket: TcpStream) -> Result<()> {
201+
println!("accepted connection for destination port {destination_port}");
200202
let alpn: Vec<u8> = format!("/coman/{destination_port}").into_bytes();
201-
202-
let endpoint = Endpoint::bind().await?;
203+
let secret_key = SecretKey::generate(&mut rand::rng());
204+
let endpoint = Endpoint::builder().secret_key(secret_key).bind().await?;
205+
Router::builder(endpoint.clone()).spawn(); // start local iroh listener
203206

204207
match endpoint.connect(endpoint_id, &alpn).await {
205208
Ok(connection) => {
206209
let (mut iroh_send, mut iroh_recv) = connection.open_bi().await?;
207210
let (mut local_read, mut local_write) = socket.split();
208211
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
209212
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
213+
println!("connection open");
210214

211215
tokio::select! {
212216
result = a_to_b => {
@@ -216,6 +220,7 @@ async fn process_port_forward(endpoint_id: EndpointId, destination_port: u16, mu
216220
let _ = result;
217221
},
218222
};
223+
println!("connection closed");
219224

220225
Ok(())
221226
}

0 commit comments

Comments
 (0)