Skip to content

Commit 3fa0af1

Browse files
authored
improve port forwarding (#54)
* make port forwarding wait instead of exit * update readme * actually check sha256 for managing coman sqsh
1 parent 54f0b98 commit 3fa0af1

11 files changed

Lines changed: 204 additions & 106 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ command = ["sleep", "1"] # command to execute within the container, i.e. the job
279279

280280
workdir = "/scratch" # working directory within container
281281

282+
port_forward = [12345, 8080] # ports to open in container for port-forwarding
283+
282284
# the sbatch script you want to execute
283285
# this gets templated with values specified in the {{}} and {% %} expressions (see https://keats.github.io/tera/docs/#templates for
284286
# more information on the template language). Note, this can also just be hardcoded without any template parameters.
@@ -358,6 +360,24 @@ Creating the ssh connection involves several steps, all handled by coman:
358360
- Including the SSH config in `.ssh/config` so it's accessible in other tools
359361
- Garbage collecting old SSH connections for jobs that are not running anymore
360362

363+
### Port Forwarding
364+
365+
Port forwarding consists of two steps:
366+
- configuring ports in the container that can be forwarded to
367+
- forwarding a local port to one of the configured ports
368+
369+
To configure forwardable ports in the container, either use the `-P <port>` flag or the `cscs.port_forward` config value.
370+
371+
To forward a local port, use the `coman cscs port-forward` command.
372+
373+
Example:
374+
```shell
375+
coman cscs job submit -i python -P 32100 -n myjob -- python3 -m http.server 32100 # run python built-in http server and add port for forwarding.
376+
coman cscs port-forward -s 32100 -d 32100 myjob # forward local 32100 to remote 32100 for job `myjob`
377+
378+
# open http://localhost:32100 in your browser, you should see a file listing
379+
```
380+
361381
## Development
362382

363383
### Prerequisites

coman/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ dirs = "6.0.0"
8383
iroh = "0.95.1"
8484
rand = "0.9.2"
8585
regex = "1.12.2"
86+
sha2 = "0.10.9"
8687

8788
[build-dependencies]
8889
anyhow = "1.0.90"

coman/src/cli/app.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,8 @@ pub enum CscsSystemCommands {
471471
},
472472
}
473473

474+
pub const COMAN_VERSION: &str = env!("CARGO_PKG_VERSION");
475+
474476
const VERSION_MESSAGE: &str = concat!(
475477
env!("CARGO_PKG_VERSION"),
476478
"-",

coman/src/cli/exec.rs

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ use base64::prelude::*;
44
use color_eyre::Result;
55
use iroh::{
66
Endpoint, SecretKey,
7+
endpoint::ConnectionError,
78
protocol::{ProtocolHandler, Router},
89
};
9-
use iroh_ssh::IrohSsh;
1010
use pid1::Pid1Settings;
1111
use rust_supervisor::{ChildType, Supervisor, SupervisorConfig};
12-
use tokio::{net::TcpStream, task::JoinSet};
12+
use tokio::{io::AsyncWriteExt, net::TcpStream};
1313

1414
const SECRET_KEY_ENV: &str = "COMAN_IROH_SECRET";
1515
const PORT_FORWARD_ENV: &str = "COMAN_FORWARDED_PORTS";
16+
const SSH_PORT: u16 = 15263;
1617

1718
fn get_secret_key() -> Option<Vec<u8>> {
1819
if let Ok(secret) = std::env::var(SECRET_KEY_ENV) {
@@ -23,19 +24,6 @@ fn get_secret_key() -> Option<Vec<u8>> {
2324
}
2425
}
2526

26-
#[tokio::main]
27-
async fn run_ssh() -> Result<()> {
28-
let mut builder = IrohSsh::builder().accept_incoming(true).accept_port(15263);
29-
if let Some(secret_key) = get_secret_key() {
30-
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
31-
builder = builder.secret_key(secret_key);
32-
}
33-
let server = builder.build().await.expect("couldn't create iroh server");
34-
println!("{}@{}", whoami::username(), server.node_id());
35-
tokio::signal::ctrl_c().await?;
36-
Ok(())
37-
}
38-
3927
#[derive(Debug)]
4028
struct PortForwardHandler {
4129
port: u16,
@@ -56,7 +44,15 @@ impl ProtocolHandler for PortForwardHandler {
5644

5745
let (mut local_read, mut local_write) = output_stream.split();
5846

59-
let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
47+
let a_to_b = async move {
48+
let res = tokio::io::copy(&mut local_read, &mut iroh_send).await;
49+
if res.is_ok() {
50+
iroh_send.flush().await.expect("couldn't flush stream");
51+
iroh_send.finish().expect("couldn't finish stream");
52+
iroh_send.stopped().await.expect("stream not properly stopped");
53+
}
54+
res
55+
};
6056
let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
6157

6258
tokio::select! {
@@ -67,6 +63,19 @@ impl ProtocolHandler for PortForwardHandler {
6763
println!("Iroh->{port} stream ended: {result:?}");
6864
},
6965
};
66+
// wait for client to close connection so we don't close prematurely
67+
let res = tokio::time::timeout(Duration::from_secs(3), async move {
68+
let closed = connection.closed().await;
69+
if !matches!(closed, ConnectionError::ApplicationClosed(_)) {
70+
println!("endpoint disconnected witn an error: {closed:#}");
71+
} else {
72+
println!("connection closed");
73+
}
74+
})
75+
.await;
76+
if res.is_err() {
77+
println!("endpoint did not disconnect within 3 seconds");
78+
}
7079
}
7180
Err(e) => {
7281
println!("Failed to connect to local server {port}: {e}");
@@ -88,30 +97,35 @@ async fn port_forward() -> Result<()> {
8897
};
8998
let secret_key: &[u8; 32] = secret_key[0..32].try_into().unwrap();
9099
let secret_key = SecretKey::from_bytes(secret_key);
91-
if let Ok(forwarded_ports) = std::env::var(PORT_FORWARD_ENV) {
92-
println!("setting up port forwarding...");
93-
let mut join_set = JoinSet::new();
94-
for port in forwarded_ports.split(',') {
95-
let alpn: Vec<u8> = format!("/coman/{port}").into_bytes();
96-
let endpoint = Endpoint::builder()
97-
.secret_key(secret_key.clone())
98-
.alpns(vec![alpn.clone()])
99-
.bind()
100-
.await?;
101-
102-
let port = port.to_owned();
103-
join_set.spawn(async move {
104-
let handler = PortForwardHandler {
105-
port: port.parse::<u16>().expect("couldn't parse port"),
106-
};
107-
Router::builder(endpoint.clone()).accept(&alpn, handler).spawn();
108-
});
109-
}
110-
while let Some(res) = join_set.join_next().await {
111-
println!("Task joined: {res:?}");
112-
}
100+
let mut forwarded_ports = vec!["ssh".to_owned()];
101+
if let Ok(env_ports) = std::env::var(PORT_FORWARD_ENV) {
102+
forwarded_ports.extend(env_ports.split(',').map(|p| p.to_owned()).collect::<Vec<String>>());
103+
}
104+
let endpoint = Endpoint::builder().secret_key(secret_key.clone()).bind().await?;
105+
let id = endpoint.id();
106+
println!("endpoint: {id}");
107+
108+
println!("setting up port forwarding...");
109+
let mut builder = Router::builder(endpoint.clone());
110+
for port in forwarded_ports {
111+
let (port, alpn) = if port == "ssh" {
112+
(SSH_PORT, "/iroh/ssh".to_string())
113+
} else {
114+
(
115+
port.parse::<u16>().expect("couldn't parse port"),
116+
format!("/coman/{port}"),
117+
)
118+
};
119+
120+
let handler = PortForwardHandler { port };
121+
builder = builder.accept(alpn.clone().into_bytes(), handler);
122+
println!("set up port forwarding for port {port} ({alpn})");
113123
}
124+
let _router = builder.spawn();
125+
println!("port forwarding started");
114126

127+
let _ = tokio::signal::ctrl_c().await;
128+
println!("port forwarding stopped");
115129
Ok(())
116130
}
117131

@@ -125,11 +139,6 @@ pub(crate) async fn cli_exec_command(command: Vec<String>) -> Result<()> {
125139
.expect("Launch failed");
126140

127141
let mut supervisor = Supervisor::new(SupervisorConfig::default());
128-
supervisor.add_process("iroh-ssh", ChildType::Permanent, || {
129-
thread::spawn(|| {
130-
let _ = run_ssh();
131-
})
132-
});
133142
supervisor.add_process("port-forward", ChildType::Permanent, || {
134143
thread::spawn(|| {
135144
let _ = port_forward();

coman/src/cli/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
pub mod app;
22
pub mod exec;
3-
pub mod port_forward;
43
pub mod proxy;

coman/src/cli/port_forward.rs

Lines changed: 0 additions & 1 deletion
This file was deleted.

coman/src/cscs/api_client/client.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use firecrest_client::{
99
get_compute_system_jobs, post_compute_system_job,
1010
},
1111
filesystem_api::{
12-
delete_filesystem_ops_rm, get_filesystem_ops_download, get_filesystem_ops_ls, get_filesystem_ops_stat,
13-
get_filesystem_ops_tail, post_filesystem_ops_mkdir, post_filesystem_ops_upload,
12+
delete_filesystem_ops_rm, get_filesystem_ops_checksum, get_filesystem_ops_download, get_filesystem_ops_ls,
13+
get_filesystem_ops_stat, get_filesystem_ops_tail, post_filesystem_ops_mkdir, post_filesystem_ops_upload,
1414
post_filesystem_transfer_download, post_filesystem_transfer_upload, put_filesystem_ops_chmod,
1515
},
1616
status_api::{get_status_systems, get_status_userinfo},
@@ -231,6 +231,11 @@ impl CscsApi {
231231
None => Ok(vec![]),
232232
}
233233
}
234+
pub async fn checksum(&self, system_name: &str, path: PathBuf) -> Result<Option<String>> {
235+
get_filesystem_ops_checksum(&self.client, system_name, path)
236+
.await
237+
.wrap_err("couldn't stat file")
238+
}
234239
pub async fn stat_path(&self, system_name: &str, path: PathBuf) -> Result<Option<FileStat>> {
235240
let result = get_filesystem_ops_stat(&self.client, system_name, path)
236241
.await

coman/src/cscs/cli.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ pub(crate) async fn cli_cscs_port_forward(
132132
platform: Option<ComputePlatform>,
133133
) -> Result<()> {
134134
let job_id = maybe_job_id_from_name(job, system.clone(), platform.clone()).await?;
135+
println!("running port forward for job {job_id}");
135136
cscs_port_forward(job_id, source_port, destination_port, system).await
136137
}
137138

0 commit comments

Comments
 (0)