@@ -4,15 +4,16 @@ use base64::prelude::*;
44use color_eyre:: Result ;
55use iroh:: {
66 Endpoint , SecretKey ,
7+ endpoint:: ConnectionError ,
78 protocol:: { ProtocolHandler , Router } ,
89} ;
9- use iroh_ssh:: IrohSsh ;
1010use pid1:: Pid1Settings ;
1111use rust_supervisor:: { ChildType , Supervisor , SupervisorConfig } ;
12- use tokio:: { net :: TcpStream , task :: JoinSet } ;
12+ use tokio:: { io :: AsyncWriteExt , net :: TcpStream } ;
1313
1414const SECRET_KEY_ENV : & str = "COMAN_IROH_SECRET" ;
1515const PORT_FORWARD_ENV : & str = "COMAN_FORWARDED_PORTS" ;
16+ const SSH_PORT : u16 = 15263 ;
1617
1718fn get_secret_key ( ) -> Option < Vec < u8 > > {
1819 if let Ok ( secret) = std:: env:: var ( SECRET_KEY_ENV ) {
@@ -23,19 +24,6 @@ fn get_secret_key() -> Option<Vec<u8>> {
2324 }
2425}
2526
26- #[ tokio:: main]
27- async fn run_ssh ( ) -> Result < ( ) > {
28- let mut builder = IrohSsh :: builder ( ) . accept_incoming ( true ) . accept_port ( 15263 ) ;
29- if let Some ( secret_key) = get_secret_key ( ) {
30- let secret_key: & [ u8 ; 32 ] = secret_key[ 0 ..32 ] . try_into ( ) . unwrap ( ) ;
31- builder = builder. secret_key ( secret_key) ;
32- }
33- let server = builder. build ( ) . await . expect ( "couldn't create iroh server" ) ;
34- println ! ( "{}@{}" , whoami:: username( ) , server. node_id( ) ) ;
35- tokio:: signal:: ctrl_c ( ) . await ?;
36- Ok ( ( ) )
37- }
38-
3927#[ derive( Debug ) ]
4028struct PortForwardHandler {
4129 port : u16 ,
@@ -56,7 +44,15 @@ impl ProtocolHandler for PortForwardHandler {
5644
5745 let ( mut local_read, mut local_write) = output_stream. split ( ) ;
5846
59- let a_to_b = async move { tokio:: io:: copy ( & mut local_read, & mut iroh_send) . await } ;
47+ let a_to_b = async move {
48+ let res = tokio:: io:: copy ( & mut local_read, & mut iroh_send) . await ;
49+ if res. is_ok ( ) {
50+ iroh_send. flush ( ) . await . expect ( "couldn't flush stream" ) ;
51+ iroh_send. finish ( ) . expect ( "couldn't finish stream" ) ;
52+ iroh_send. stopped ( ) . await . expect ( "stream not properly stopped" ) ;
53+ }
54+ res
55+ } ;
6056 let b_to_a = async move { tokio:: io:: copy ( & mut iroh_recv, & mut local_write) . await } ;
6157
6258 tokio:: select! {
@@ -67,6 +63,19 @@ impl ProtocolHandler for PortForwardHandler {
6763 println!( "Iroh->{port} stream ended: {result:?}" ) ;
6864 } ,
6965 } ;
66+ // wait for client to close connection so we don't close prematurely
67+ let res = tokio:: time:: timeout ( Duration :: from_secs ( 3 ) , async move {
68+ let closed = connection. closed ( ) . await ;
69+ if !matches ! ( closed, ConnectionError :: ApplicationClosed ( _) ) {
70+ println ! ( "endpoint disconnected witn an error: {closed:#}" ) ;
71+ } else {
72+ println ! ( "connection closed" ) ;
73+ }
74+ } )
75+ . await ;
76+ if res. is_err ( ) {
77+ println ! ( "endpoint did not disconnect within 3 seconds" ) ;
78+ }
7079 }
7180 Err ( e) => {
7281 println ! ( "Failed to connect to local server {port}: {e}" ) ;
@@ -88,30 +97,35 @@ async fn port_forward() -> Result<()> {
8897 } ;
8998 let secret_key: & [ u8 ; 32 ] = secret_key[ 0 ..32 ] . try_into ( ) . unwrap ( ) ;
9099 let secret_key = SecretKey :: from_bytes ( secret_key) ;
91- if let Ok ( forwarded_ports) = std:: env:: var ( PORT_FORWARD_ENV ) {
92- println ! ( "setting up port forwarding..." ) ;
93- let mut join_set = JoinSet :: new ( ) ;
94- for port in forwarded_ports. split ( ',' ) {
95- let alpn: Vec < u8 > = format ! ( "/coman/{port}" ) . into_bytes ( ) ;
96- let endpoint = Endpoint :: builder ( )
97- . secret_key ( secret_key. clone ( ) )
98- . alpns ( vec ! [ alpn. clone( ) ] )
99- . bind ( )
100- . await ?;
101-
102- let port = port. to_owned ( ) ;
103- join_set. spawn ( async move {
104- let handler = PortForwardHandler {
105- port : port. parse :: < u16 > ( ) . expect ( "couldn't parse port" ) ,
106- } ;
107- Router :: builder ( endpoint. clone ( ) ) . accept ( & alpn, handler) . spawn ( ) ;
108- } ) ;
109- }
110- while let Some ( res) = join_set. join_next ( ) . await {
111- println ! ( "Task joined: {res:?}" ) ;
112- }
100+ let mut forwarded_ports = vec ! [ "ssh" . to_owned( ) ] ;
101+ if let Ok ( env_ports) = std:: env:: var ( PORT_FORWARD_ENV ) {
102+ forwarded_ports. extend ( env_ports. split ( ',' ) . map ( |p| p. to_owned ( ) ) . collect :: < Vec < String > > ( ) ) ;
103+ }
104+ let endpoint = Endpoint :: builder ( ) . secret_key ( secret_key. clone ( ) ) . bind ( ) . await ?;
105+ let id = endpoint. id ( ) ;
106+ println ! ( "endpoint: {id}" ) ;
107+
108+ println ! ( "setting up port forwarding..." ) ;
109+ let mut builder = Router :: builder ( endpoint. clone ( ) ) ;
110+ for port in forwarded_ports {
111+ let ( port, alpn) = if port == "ssh" {
112+ ( SSH_PORT , "/iroh/ssh" . to_string ( ) )
113+ } else {
114+ (
115+ port. parse :: < u16 > ( ) . expect ( "couldn't parse port" ) ,
116+ format ! ( "/coman/{port}" ) ,
117+ )
118+ } ;
119+
120+ let handler = PortForwardHandler { port } ;
121+ builder = builder. accept ( alpn. clone ( ) . into_bytes ( ) , handler) ;
122+ println ! ( "set up port forwarding for port {port} ({alpn})" ) ;
113123 }
124+ let _router = builder. spawn ( ) ;
125+ println ! ( "port forwarding started" ) ;
114126
127+ let _ = tokio:: signal:: ctrl_c ( ) . await ;
128+ println ! ( "port forwarding stopped" ) ;
115129 Ok ( ( ) )
116130}
117131
@@ -125,11 +139,6 @@ pub(crate) async fn cli_exec_command(command: Vec<String>) -> Result<()> {
125139 . expect ( "Launch failed" ) ;
126140
127141 let mut supervisor = Supervisor :: new ( SupervisorConfig :: default ( ) ) ;
128- supervisor. add_process ( "iroh-ssh" , ChildType :: Permanent , || {
129- thread:: spawn ( || {
130- let _ = run_ssh ( ) ;
131- } )
132- } ) ;
133142 supervisor. add_process ( "port-forward" , ChildType :: Permanent , || {
134143 thread:: spawn ( || {
135144 let _ = port_forward ( ) ;
0 commit comments