@@ -4,15 +4,16 @@ use base64::prelude::*;
44use color_eyre:: Result ;
55use iroh:: {
66 Endpoint , SecretKey ,
7+ endpoint:: ConnectionError ,
78 protocol:: { ProtocolHandler , Router } ,
89} ;
9- use iroh_ssh:: IrohSsh ;
1010use pid1:: Pid1Settings ;
1111use rust_supervisor:: { ChildType , Supervisor , SupervisorConfig } ;
12- use tokio:: { net :: TcpStream , task :: JoinSet } ;
12+ use tokio:: { io :: AsyncWriteExt , net :: TcpStream } ;
1313
1414const SECRET_KEY_ENV : & str = "COMAN_IROH_SECRET" ;
1515const PORT_FORWARD_ENV : & str = "COMAN_FORWARDED_PORTS" ;
16+ const SSH_PORT : u16 = 15263 ;
1617
1718fn get_secret_key ( ) -> Option < Vec < u8 > > {
1819 if let Ok ( secret) = std:: env:: var ( SECRET_KEY_ENV ) {
@@ -23,19 +24,6 @@ fn get_secret_key() -> Option<Vec<u8>> {
2324 }
2425}
2526
26- #[ tokio:: main]
27- async fn run_ssh ( ) -> Result < ( ) > {
28- let mut builder = IrohSsh :: builder ( ) . accept_incoming ( true ) . accept_port ( 15263 ) ;
29- if let Some ( secret_key) = get_secret_key ( ) {
30- let secret_key: & [ u8 ; 32 ] = secret_key[ 0 ..32 ] . try_into ( ) . unwrap ( ) ;
31- builder = builder. secret_key ( secret_key) ;
32- }
33- let server = builder. build ( ) . await . expect ( "couldn't create iroh server" ) ;
34- println ! ( "{}@{}" , whoami:: username( ) , server. node_id( ) ) ;
35- tokio:: signal:: ctrl_c ( ) . await ?;
36- Ok ( ( ) )
37- }
38-
3927#[ derive( Debug ) ]
4028struct PortForwardHandler {
4129 port : u16 ,
@@ -56,7 +44,13 @@ impl ProtocolHandler for PortForwardHandler {
5644
5745 let ( mut local_read, mut local_write) = output_stream. split ( ) ;
5846
59- let a_to_b = async move { tokio:: io:: copy ( & mut local_read, & mut iroh_send) . await } ;
47+ let a_to_b = async move {
48+ let res = tokio:: io:: copy ( & mut local_read, & mut iroh_send) . await ;
49+ if res. is_ok ( ) {
50+ iroh_send. flush ( ) . await . expect ( "couldn't flush stream" ) ;
51+ }
52+ res
53+ } ;
6054 let b_to_a = async move { tokio:: io:: copy ( & mut iroh_recv, & mut local_write) . await } ;
6155
6256 tokio:: select! {
@@ -67,6 +61,19 @@ impl ProtocolHandler for PortForwardHandler {
6761 println!( "Iroh->{port} stream ended: {result:?}" ) ;
6862 } ,
6963 } ;
64+ // wait for client to close connection so we don't close prematurely
65+ let res = tokio:: time:: timeout ( Duration :: from_secs ( 3 ) , async move {
66+ let closed = connection. closed ( ) . await ;
67+ if !matches ! ( closed, ConnectionError :: ApplicationClosed ( _) ) {
68+ println ! ( "endpoint disconnected witn an error: {closed:#}" ) ;
69+ } else {
70+ println ! ( "connection closed" ) ;
71+ }
72+ } )
73+ . await ;
74+ if res. is_err ( ) {
75+ println ! ( "endpoint did not disconnect within 3 seconds" ) ;
76+ }
7077 }
7178 Err ( e) => {
7279 println ! ( "Failed to connect to local server {port}: {e}" ) ;
@@ -88,30 +95,35 @@ async fn port_forward() -> Result<()> {
8895 } ;
8996 let secret_key: & [ u8 ; 32 ] = secret_key[ 0 ..32 ] . try_into ( ) . unwrap ( ) ;
9097 let secret_key = SecretKey :: from_bytes ( secret_key) ;
91- if let Ok ( forwarded_ports) = std:: env:: var ( PORT_FORWARD_ENV ) {
92- println ! ( "setting up port forwarding..." ) ;
93- let mut join_set = JoinSet :: new ( ) ;
94- for port in forwarded_ports. split ( ',' ) {
95- let alpn: Vec < u8 > = format ! ( "/coman/{port}" ) . into_bytes ( ) ;
96- let endpoint = Endpoint :: builder ( )
97- . secret_key ( secret_key. clone ( ) )
98- . alpns ( vec ! [ alpn. clone( ) ] )
99- . bind ( )
100- . await ?;
101-
102- let port = port. to_owned ( ) ;
103- join_set. spawn ( async move {
104- let handler = PortForwardHandler {
105- port : port. parse :: < u16 > ( ) . expect ( "couldn't parse port" ) ,
106- } ;
107- Router :: builder ( endpoint. clone ( ) ) . accept ( & alpn, handler) . spawn ( ) ;
108- } ) ;
109- }
110- while let Some ( res) = join_set. join_next ( ) . await {
111- println ! ( "Task joined: {res:?}" ) ;
112- }
98+ let mut forwarded_ports = vec ! [ "ssh" . to_owned( ) ] ;
99+ if let Ok ( env_ports) = std:: env:: var ( PORT_FORWARD_ENV ) {
100+ forwarded_ports. extend ( env_ports. split ( ',' ) . map ( |p| p. to_owned ( ) ) . collect :: < Vec < String > > ( ) ) ;
101+ }
102+ let endpoint = Endpoint :: builder ( ) . secret_key ( secret_key. clone ( ) ) . bind ( ) . await ?;
103+ let id = endpoint. id ( ) ;
104+ println ! ( "endpoint: {id}" ) ;
105+
106+ println ! ( "setting up port forwarding..." ) ;
107+ let mut builder = Router :: builder ( endpoint. clone ( ) ) ;
108+ for port in forwarded_ports {
109+ let ( port, alpn) = if port == "ssh" {
110+ ( SSH_PORT , "/iroh/ssh" . to_string ( ) )
111+ } else {
112+ (
113+ port. parse :: < u16 > ( ) . expect ( "couldn't parse port" ) ,
114+ format ! ( "/coman/{port}" ) ,
115+ )
116+ } ;
117+
118+ let handler = PortForwardHandler { port } ;
119+ builder = builder. accept ( alpn. clone ( ) . into_bytes ( ) , handler) ;
120+ println ! ( "set up port forwarding for port {port} ({alpn})" ) ;
113121 }
122+ let _router = builder. spawn ( ) ;
123+ println ! ( "port forwarding started" ) ;
114124
125+ let _ = tokio:: signal:: ctrl_c ( ) . await ;
126+ println ! ( "port forwarding stopped" ) ;
115127 Ok ( ( ) )
116128}
117129
@@ -125,11 +137,6 @@ pub(crate) async fn cli_exec_command(command: Vec<String>) -> Result<()> {
125137 . expect ( "Launch failed" ) ;
126138
127139 let mut supervisor = Supervisor :: new ( SupervisorConfig :: default ( ) ) ;
128- supervisor. add_process ( "iroh-ssh" , ChildType :: Permanent , || {
129- thread:: spawn ( || {
130- let _ = run_ssh ( ) ;
131- } )
132- } ) ;
133140 supervisor. add_process ( "port-forward" , ChildType :: Permanent , || {
134141 thread:: spawn ( || {
135142 let _ = port_forward ( ) ;
0 commit comments