Skip to content

Commit 9b82006

Browse files
committed
Improve stdio tunnel on windows
- Handle CTRL+C to exit properly - Restore terminal mode at exit - Use logger to stderr
1 parent 0595e23 commit 9b82006

File tree

3 files changed

+41
-28
lines changed

3 files changed

+41
-28
lines changed

src/main.rs

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -724,27 +724,28 @@ async fn main() {
724724
let args = Wstunnel::parse();
725725

726726
// Setup logging
727-
match &args.commands {
728-
// Disable logging if there is a stdio tunnel
729-
Commands::Client(args)
730-
if args
731-
.local_to_remote
732-
.iter()
733-
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
734-
.count()
735-
> 0 => {}
736-
_ => {
737-
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
738-
if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) {
739-
env_filter =
740-
env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive"));
741-
}
742-
tracing_subscriber::fmt()
743-
.with_ansi(args.no_color.is_none())
744-
.with_env_filter(env_filter)
745-
.init();
746-
}
727+
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
728+
if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) {
729+
env_filter = env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive"));
747730
}
731+
let logger = tracing_subscriber::fmt()
732+
.with_ansi(args.no_color.is_none())
733+
.with_env_filter(env_filter);
734+
735+
// stdio tunnel capture stdio, so need to log into stderr
736+
if let Commands::Client(args) = &args.commands {
737+
if args
738+
.local_to_remote
739+
.iter()
740+
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
741+
.count()
742+
> 0
743+
{
744+
logger.with_writer(io::stderr).init();
745+
}
746+
} else {
747+
logger.init();
748+
};
748749

749750
match args.commands {
750751
Commands::Client(args) => {
@@ -1018,7 +1019,7 @@ async fn main() {
10181019
});
10191020
}
10201021
#[cfg(not(unix))]
1021-
LocalProtocol::Unix { path } => {
1022+
LocalProtocol::Unix { .. } => {
10221023
panic!("Unix socket is not available for non Unix platform")
10231024
}
10241025
LocalProtocol::Stdio

src/stdio.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
pub mod server {
33

44
use tokio_fd::AsyncFd;
5+
use tracing::info;
56
pub async fn run_server() -> Result<(AsyncFd, AsyncFd), anyhow::Error> {
6-
eprintln!("Starting STDIO server");
7+
info!("Starting STDIO server");
78

89
let stdin = AsyncFd::try_from(nix::libc::STDIN_FILENO)?;
910
let stdout = AsyncFd::try_from(nix::libc::STDOUT_FILENO)?;
@@ -15,31 +16,39 @@ pub mod server {
1516
#[cfg(not(unix))]
1617
pub mod server {
1718
use bytes::BytesMut;
19+
use log::error;
20+
use scopeguard::guard;
1821
use std::io::{Read, Write};
1922
use std::{io, thread};
2023
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
2124
use tokio::task::LocalSet;
2225
use tokio_stream::wrappers::UnboundedReceiverStream;
2326
use tokio_util::io::StreamReader;
27+
use tracing::info;
2428

2529
pub async fn run_server() -> Result<(impl AsyncRead, impl AsyncWrite), anyhow::Error> {
26-
eprintln!("Starting STDIO server");
30+
info!("Starting STDIO server. Press ctrl+c twice to exit");
2731

2832
crossterm::terminal::enable_raw_mode()?;
2933

3034
let stdin = io::stdin();
3135
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
3236
thread::spawn(move || {
37+
let _restore_terminal = guard((), move |_| {
38+
let _ = crossterm::terminal::disable_raw_mode();
39+
});
3340
let stdin = stdin;
3441
let mut stdin = stdin.lock();
3542
let mut buf = [0u8; 65536];
43+
3644
loop {
37-
let n = stdin.read(&mut buf).unwrap();
38-
if n == 0 {
45+
let n = stdin.read(&mut buf).unwrap_or(0);
46+
if n == 0 || (n == 1 && buf[0] == 3) {
47+
// ctrl+c send char 3
3948
break;
4049
}
4150
if let Err(err) = send.send(Result::<_, io::Error>::Ok(BytesMut::from(&buf[..n]))) {
42-
eprintln!("Failed send inout: {:?}", err);
51+
error!("Failed send inout: {:?}", err);
4352
break;
4453
}
4554
}
@@ -50,6 +59,9 @@ pub mod server {
5059
let rt = tokio::runtime::Handle::current();
5160
thread::spawn(move || {
5261
let task = async move {
62+
let _restore_terminal = guard((), move |_| {
63+
let _ = crossterm::terminal::disable_raw_mode();
64+
});
5365
let mut stdout = io::stdout().lock();
5466
let mut buf = [0u8; 65536];
5567
loop {
@@ -62,7 +74,7 @@ pub mod server {
6274
}
6375

6476
if let Err(err) = stdout.write_all(&buf[..n]) {
65-
eprintln!("Failed to write to stdout: {:?}", err);
77+
error!("Failed to write to stdout: {:?}", err);
6678
break;
6779
};
6880
let _ = stdout.flush();

src/tunnel/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ async fn run_tunnel(
167167
Ok((remote, Box::pin(local_rx), Box::pin(local_tx)))
168168
}
169169
#[cfg(not(unix))]
170-
LocalProtocol::ReverseUnix { ref path } => {
170+
LocalProtocol::ReverseUnix { .. } => {
171171
error!("Received an unsupported target protocol {:?}", remote);
172172
Err(anyhow::anyhow!("Invalid upgrade request"))
173173
}

0 commit comments

Comments
 (0)