@@ -4,6 +4,7 @@ use std::os::unix::fs::MetadataExt;
44use std:: os:: windows:: fs:: MetadataExt ;
55use std:: {
66 collections:: HashMap ,
7+ io:: { BufWriter , Read , Write } ,
78 path:: { Path , PathBuf } ,
89} ;
910
@@ -16,7 +17,7 @@ use reqwest::Url;
1617
1718use super :: api_client:: client:: { EdfSpec , ScriptSpec } ;
1819use crate :: {
19- config:: { ComputePlatform , Config } ,
20+ config:: { ComputePlatform , Config , get_data_dir } ,
2021 cscs:: {
2122 api_client:: {
2223 client:: { CscsApi , JobStartOptions } ,
@@ -169,12 +170,11 @@ async fn setup_ssh(
169170 base_path : & Path ,
170171 current_system : & str ,
171172 options : & JobStartOptions ,
172- ) -> Result < Option < ( PathBuf , String ) > > {
173+ ) -> Result < Option < ( PathBuf , SecretKey ) > > {
173174 if options. no_ssh {
174175 return Ok ( None ) ;
175176 }
176177 let secret = SecretKey :: generate ( & mut rand:: rng ( ) ) ;
177- let encoded_secret = BASE64_STANDARD . encode ( secret. to_bytes ( ) ) ;
178178
179179 let ssh_key = if let Some ( path) = options. ssh_key . clone ( ) {
180180 path. canonicalize ( ) . map ( Some ) . wrap_err ( "couldn't get ssh key path" ) ?
@@ -201,7 +201,7 @@ async fn setup_ssh(
201201 api_client
202202 . upload ( current_system, remote_path. clone ( ) , public_key. into_bytes ( ) )
203203 . await ?;
204- Ok ( Some ( ( remote_path, encoded_secret ) ) )
204+ Ok ( Some ( ( remote_path, secret ) ) )
205205 }
206206 None => Err ( eyre ! ( "couldn't find ssh public key, use `--ssh_key` to specify it" ) ) ,
207207 }
@@ -274,8 +274,8 @@ async fn handle_edf(
274274 current_system : & str ,
275275 envvars : & HashMap < String , String > ,
276276 coman_squash : & Option < PathBuf > ,
277- ssh_path : & Option < PathBuf > ,
278- iroh_secret : & Option < String > ,
277+ ssh_public_key_path : & Option < PathBuf > ,
278+ iroh_secret : & Option < SecretKey > ,
279279 workdir : & str ,
280280 options : & JobStartOptions ,
281281) -> Result < PathBuf > {
@@ -322,11 +322,12 @@ async fn handle_edf(
322322 context. insert ( "container_workdir" , & workdir) ;
323323 context. insert ( "env" , & envvars) ;
324324 context. insert ( "mount" , & mount) ;
325- context. insert ( "ssh_public_key" , & ssh_path ) ;
325+ context. insert ( "ssh_public_key" , & ssh_public_key_path ) ;
326326 context. insert ( "coman_squash" , & coman_squash) ;
327327 if let Some ( iroh_secret) = iroh_secret {
328328 // set iroh secret key
329- context. insert ( "iroh_secret" , & iroh_secret) ;
329+ let encoded_secret = BASE64_STANDARD . encode ( iroh_secret. to_bytes ( ) ) ;
330+ context. insert ( "iroh_secret" , & encoded_secret) ;
330331 }
331332
332333 let environment_file = tera. render ( "environment.toml" , & context) ?;
@@ -439,7 +440,9 @@ pub async fn cscs_job_start(
439440 let mut envvars = config. values . cscs . env . clone ( ) ;
440441 envvars. extend ( options. env . clone ( ) ) ;
441442
442- let ssh_values = setup_ssh ( & api_client, & base_path, current_system, & options) . await ?;
443+ let ( ssh_public_key_path, secret_key) = setup_ssh ( & api_client, & base_path, current_system, & options)
444+ . await ?
445+ . unzip ( ) ;
443446 let coman_squash = inject_coman_squash ( & api_client, & base_path, current_system, & options) . await ?;
444447
445448 let environment_path = handle_edf (
@@ -448,8 +451,8 @@ pub async fn cscs_job_start(
448451 current_system,
449452 & envvars,
450453 & coman_squash,
451- & ssh_values . clone ( ) . map ( |v| v . 0 ) ,
452- & ssh_values . map ( |v| v . 1 ) ,
454+ & ssh_public_key_path ,
455+ & secret_key ,
453456 & container_workdir,
454457 & options,
455458 )
@@ -468,9 +471,55 @@ pub async fn cscs_job_start(
468471 . await ?;
469472
470473 // start job
471- api_client
474+ let job_id = api_client
472475 . start_job ( current_system, account, & job_name, script_path, envvars, options)
473- . await ?;
476+ . await ?
477+ . ok_or ( eyre ! ( "didn't get job id for created job" ) ) ?;
478+
479+ if let Some ( secret_key) = secret_key {
480+ // store connection information in data dir and set up ssh connection
481+ let data_dir = get_data_dir ( ) ;
482+ std:: fs:: write (
483+ data_dir. join ( format ! ( "{}.endpoint" , job_id) ) ,
484+ format ! ( "{}" , secret_key. public( ) ) ,
485+ ) ?;
486+ let coman_ssh_config_path = data_dir. join ( "ssh_config" ) ;
487+ let coman_ssh_config = std:: fs:: OpenOptions :: new ( )
488+ . create ( true )
489+ . append ( true )
490+ . open ( coman_ssh_config_path. clone ( ) ) ?;
491+ let mut writer = BufWriter :: new ( coman_ssh_config) ;
492+ write ! (
493+ writer,
494+ "\n Host {}-{}\n Hostname {}\n User {}\n ProxyCommand coman proxy {}\n " ,
495+ job_name,
496+ job_id,
497+ secret_key. public( ) ,
498+ user_info. name,
499+ job_id
500+ ) ?;
501+ let ssh_dir = dirs:: home_dir ( ) . ok_or ( eyre ! ( "couldn't find home dir" ) ) ?. join ( ".ssh" ) ;
502+ let ssh_config_path = ssh_dir. join ( "config" ) ;
503+ let mut ssh_config = std:: fs:: OpenOptions :: new ( )
504+ . read ( true )
505+ . append ( true )
506+ . open ( ssh_config_path) ?;
507+ let mut content = String :: new ( ) ;
508+ ssh_config. read_to_string ( & mut content) ?;
509+ if !content. contains ( & format ! ( "Include {}" , coman_ssh_config_path. clone( ) . display( ) ) ) {
510+ let mut writer = BufWriter :: new ( ssh_config) ;
511+ write ! (
512+ writer,
513+ "\n \n #coman include\n Match all\n Include {}" ,
514+ coman_ssh_config_path. display( )
515+ ) ?;
516+ }
517+ println ! (
518+ "Use ssh {}@{}-{} to connect to the job" ,
519+ user_info. name, job_name, job_id
520+ ) ;
521+ }
522+
474523 Ok ( ( ) )
475524 }
476525 Err ( e) => Err ( e) ,
0 commit comments