Skip to content

Commit 92d83e2

Browse files
authored
Fix read-only share enforcement and auth throttling (#61)
## Summary - replace WebRTC read-only share filtering with a deny-by-default C2S opcode allowlist shared by legacy and mux forwarding paths - add regression coverage for read-only opcode classification, including the previously missed quit/surface text/resize/focus mutations - add shared auth throttling for config/gateway WebSocket and WebTransport authentication, including unauthenticated connection caps and per-peer lockouts after repeated failures ## Validation - `direnv exec . cargo fmt --check` - `direnv exec . cargo test -p blit-webrtc-forwarder` - `direnv exec . cargo test -p blit-webserver` - `direnv exec . cargo test -p blit-gateway` *(with a temporary valid `js/ui/dist/index.html.br` generated for the existing `include_bytes!` requirement, then removed)* - `direnv exec . cargo check -p blit-cli` *(same temporary dist asset setup)*
1 parent f72578e commit 92d83e2

6 files changed

Lines changed: 495 additions & 142 deletions

File tree

crates/cli/src/interactive.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ struct BrowserState {
146146
/// Broadcast notification triggered on SIGINT so active WebSocket
147147
/// handlers can send `S2C_QUIT` before the process exits.
148148
shutdown: Arc<tokio::sync::Notify>,
149+
/// Shared auth throttle for config WebSocket handshakes.
150+
auth_throttle: blit_webserver::config::AuthThrottle,
149151
}
150152

151153
pub async fn run_browser(port: Option<u16>, hub: &str) {
@@ -195,6 +197,7 @@ pub async fn run_browser(port: Option<u16>, hub: &str) {
195197
hub: hub.to_string(),
196198
ssh_pool,
197199
shutdown: shutdown.clone(),
200+
auth_throttle: blit_webserver::config::AuthThrottle::new(),
198201
});
199202

200203
// Reconcile destinations whenever blit.remotes changes (from the
@@ -247,6 +250,10 @@ pub async fn run_browser(port: Option<u16>, hub: &str) {
247250
Some(&state.remotes),
248251
None,
249252
&[],
253+
blit_webserver::config::AuthContext {
254+
throttle: &state.auth_throttle,
255+
peer: "local",
256+
},
250257
)
251258
.await;
252259
})

crates/gateway/src/lib.rs

Lines changed: 85 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use axum::extract::ws::{Message, WebSocket};
2-
use axum::extract::{FromRequest, State, WebSocketUpgrade};
2+
use axum::extract::{ConnectInfo, FromRequest, State, WebSocketUpgrade};
33
use axum::response::{IntoResponse, Response};
44
use axum::routing::get;
5+
use axum::serve::ListenerExt;
56
use futures_util::{SinkExt, StreamExt};
67
use std::collections::HashMap;
8+
use std::net::SocketAddr;
79
use std::sync::{Arc, LazyLock};
810
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
911
#[cfg(unix)]
@@ -63,36 +65,6 @@ async fn connect_ipc(path: &str) -> Result<IpcStream, String> {
6365
}
6466
}
6567

66-
/// Wraps TcpListener to set TCP_NODELAY on every accepted connection,
67-
/// disabling Nagle's algorithm for low-latency frame delivery.
68-
struct NoDelayListener(tokio::net::TcpListener);
69-
70-
impl axum::serve::Listener for NoDelayListener {
71-
type Io = tokio::net::TcpStream;
72-
type Addr = std::net::SocketAddr;
73-
74-
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
75-
{
76-
loop {
77-
match self.0.accept().await {
78-
Ok((stream, addr)) => {
79-
let _ = stream.set_nodelay(true);
80-
return (stream, addr);
81-
}
82-
Err(e) => {
83-
eprintln!("accept error: {e}");
84-
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
85-
}
86-
}
87-
}
88-
}
89-
}
90-
91-
fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
92-
self.0.local_addr()
93-
}
94-
}
95-
9668
const INDEX_HTML_BR: &[u8] = include_bytes!("../../../js/ui/dist/index.html.br");
9769

9870
static INDEX_ETAG: LazyLock<String> = LazyLock::new(|| blit_webserver::html_etag(INDEX_HTML_BR));
@@ -126,6 +98,8 @@ struct Config {
12698
/// Broadcast notification triggered on SIGINT/SIGTERM so active
12799
/// WebSocket/WebTransport handlers can send `S2C_QUIT` before exit.
128100
shutdown: Arc<tokio::sync::Notify>,
101+
/// Shared auth throttle for config and gateway transports.
102+
auth_throttle: blit_webserver::config::AuthThrottle,
129103
}
130104

131105
impl Config {
@@ -506,6 +480,7 @@ pub async fn run() {
506480
hub_url,
507481
webrtc_enabled,
508482
shutdown: shutdown.clone(),
483+
auth_throttle: blit_webserver::config::AuthThrottle::new(),
509484
});
510485

511486
// --- Reconcile destinations whenever blit.remotes changes ---
@@ -551,7 +526,9 @@ pub async fn run() {
551526
eprintln!("blit gateway: cannot bind to {addr}: {e}");
552527
std::process::exit(1);
553528
});
554-
let listener = NoDelayListener(tcp);
529+
let listener = tcp.tap_io(|stream| {
530+
let _ = stream.set_nodelay(true);
531+
});
555532
eprintln!(
556533
"listening on {addr} (WebSocket{}){}",
557534
if quic_enabled { " + WebTransport" } else { "" },
@@ -564,7 +541,11 @@ pub async fn run() {
564541

565542
blit_sd_notify::notify_ready(false);
566543

567-
let graceful = axum::serve(listener, app).with_graceful_shutdown(async move {
544+
let graceful = axum::serve(
545+
listener,
546+
app.into_make_service_with_connect_info::<SocketAddr>(),
547+
)
548+
.with_graceful_shutdown(async move {
568549
#[cfg(unix)]
569550
{
570551
use tokio::signal::unix::{SignalKind, signal};
@@ -697,6 +678,11 @@ fn mux_error(ch: u16, msg: &str) -> Vec<u8> {
697678
}
698679

699680
async fn root_handler(State(state): State<AppState>, request: axum::extract::Request) -> Response {
681+
let auth_peer = request
682+
.extensions()
683+
.get::<ConnectInfo<SocketAddr>>()
684+
.map(|ConnectInfo(addr)| addr.ip().to_string())
685+
.unwrap_or_else(|| "unknown".to_string());
700686
let path = request.uri().path().to_string();
701687

702688
if let Some(resp) = blit_webserver::try_font_route(&path, state.cors_origin.as_deref()) {
@@ -726,6 +712,10 @@ async fn root_handler(State(state): State<AppState>, request: axum::extract::Req
726712
Some(&state.remotes),
727713
transform,
728714
&extra_init,
715+
blit_webserver::config::AuthContext {
716+
throttle: &state.auth_throttle,
717+
peer: &auth_peer,
718+
},
729719
)
730720
.await;
731721
}),
@@ -735,15 +725,15 @@ async fn root_handler(State(state): State<AppState>, request: axum::extract::Req
735725
match WebSocketUpgrade::from_request(request, &state).await {
736726
Ok(ws) => ws
737727
.max_message_size(MAX_FRAME_SIZE + 2) // +2 for channel ID prefix
738-
.on_upgrade(move |socket| handle_mux_ws(socket, state)),
728+
.on_upgrade(move |socket| handle_mux_ws(socket, state, auth_peer)),
739729
Err(e) => e.into_response(),
740730
}
741731
} else if is_ws {
742732
let dest_name = resolve_destination_name(&path);
743733
match WebSocketUpgrade::from_request(request, &state).await {
744734
Ok(ws) => ws
745735
.max_message_size(MAX_FRAME_SIZE)
746-
.on_upgrade(move |socket| handle_ws(socket, state, dest_name)),
736+
.on_upgrade(move |socket| handle_ws(socket, state, dest_name, auth_peer)),
747737
Err(e) => e.into_response(),
748738
}
749739
} else {
@@ -768,35 +758,21 @@ fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
768758
std::hint::black_box(diff) == 0
769759
}
770760

771-
async fn handle_ws(mut ws: WebSocket, state: AppState, dest_name: Option<String>) {
772-
let authed = match tokio::time::timeout(std::time::Duration::from_secs(30), async {
773-
loop {
774-
match ws.recv().await {
775-
Some(Ok(Message::Text(pass))) => {
776-
if constant_time_eq(pass.trim().as_bytes(), state.passphrase.as_bytes()) {
777-
break true;
778-
} else {
779-
let _ = ws.send(Message::Text("auth".into())).await;
780-
let _ = ws.close().await;
781-
break false;
782-
}
783-
}
784-
Some(Ok(Message::Ping(d))) => {
785-
let _ = ws.send(Message::Pong(d)).await;
786-
}
787-
_ => break false,
788-
}
789-
}
790-
})
761+
async fn handle_ws(
762+
mut ws: WebSocket,
763+
state: AppState,
764+
dest_name: Option<String>,
765+
auth_peer: String,
766+
) {
767+
if !blit_webserver::config::authenticate_text_ws(
768+
&mut ws,
769+
&state.passphrase,
770+
&state.auth_throttle,
771+
&auth_peer,
772+
None,
773+
)
791774
.await
792775
{
793-
Ok(result) => result,
794-
Err(_) => {
795-
let _ = ws.close().await;
796-
false
797-
}
798-
};
799-
if !authed {
800776
return;
801777
}
802778

@@ -906,36 +882,17 @@ impl MuxChannelState {
906882
}
907883
}
908884

909-
async fn handle_mux_ws(mut ws: WebSocket, state: AppState) {
885+
async fn handle_mux_ws(mut ws: WebSocket, state: AppState, auth_peer: String) {
910886
// --- Authentication (identical to handle_ws) ---
911-
let authed = match tokio::time::timeout(std::time::Duration::from_secs(30), async {
912-
loop {
913-
match ws.recv().await {
914-
Some(Ok(Message::Text(pass))) => {
915-
if constant_time_eq(pass.trim().as_bytes(), state.passphrase.as_bytes()) {
916-
break true;
917-
} else {
918-
let _ = ws.send(Message::Text("auth".into())).await;
919-
let _ = ws.close().await;
920-
break false;
921-
}
922-
}
923-
Some(Ok(Message::Ping(d))) => {
924-
let _ = ws.send(Message::Pong(d)).await;
925-
}
926-
_ => break false,
927-
}
928-
}
929-
})
887+
if !blit_webserver::config::authenticate_text_ws(
888+
&mut ws,
889+
&state.passphrase,
890+
&state.auth_throttle,
891+
&auth_peer,
892+
None,
893+
)
930894
.await
931895
{
932-
Ok(result) => result,
933-
Err(_) => {
934-
let _ = ws.close().await;
935-
false
936-
}
937-
};
938-
if !authed {
939896
return;
940897
}
941898

@@ -1410,6 +1367,7 @@ async fn wt_authenticate(
14101367
send: &mut wt::SendStream,
14111368
recv: &mut wt::RecvStream,
14121369
passphrase: &str,
1370+
guard: blit_webserver::config::AuthAttemptGuard,
14131371
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
14141372
let auth_result = tokio::time::timeout(std::time::Duration::from_secs(30), async {
14151373
let mut len_buf = [0u8; 2];
@@ -1418,7 +1376,7 @@ async fn wt_authenticate(
14181376
.map_err(|e| format!("auth read len: {e}"))?;
14191377
let pass_len = u16::from_le_bytes(len_buf) as usize;
14201378
if pass_len > 4096 {
1421-
return Err::<(), String>("passphrase too long".into());
1379+
return Err::<bool, String>("passphrase too long".into());
14221380
}
14231381
let mut pass_buf = vec![0u8; pass_len];
14241382
recv.read_exact(&mut pass_buf)
@@ -1428,16 +1386,26 @@ async fn wt_authenticate(
14281386

14291387
if !constant_time_eq(pass.trim().as_bytes(), passphrase.as_bytes()) {
14301388
send.write_all(&[0]).await.ok();
1431-
return Err("authentication failed".into());
1389+
return Ok(false);
14321390
}
1433-
Ok(())
1391+
Ok(true)
14341392
})
14351393
.await;
14361394

14371395
match auth_result {
1438-
Ok(Ok(())) => {}
1439-
Ok(Err(e)) => return Err(e.into()),
1440-
Err(_) => return Err("authentication timed out".into()),
1396+
Ok(Ok(true)) => guard.record_success(),
1397+
Ok(Ok(false)) => {
1398+
guard.record_failure();
1399+
return Err("authentication failed".into());
1400+
}
1401+
Ok(Err(e)) => {
1402+
guard.record_failure();
1403+
return Err(e.into());
1404+
}
1405+
Err(_) => {
1406+
guard.record_failure();
1407+
return Err("authentication timed out".into());
1408+
}
14411409
}
14421410
send.write_all(&[1])
14431411
.await
@@ -1450,13 +1418,32 @@ async fn handle_webtransport_session(
14501418
state: AppState,
14511419
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
14521420
let path = request.url.path().to_string();
1421+
let auth_peer = request.conn().remote_address().ip().to_string();
14531422
let is_mux = is_mux_path(&path);
14541423
let dest_name = resolve_destination_name(&path);
1424+
let Some(auth_guard) = state.auth_throttle.begin(auth_peer.clone()) else {
1425+
request
1426+
.reject(axum::http::StatusCode::TOO_MANY_REQUESTS)
1427+
.await?;
1428+
return Ok(());
1429+
};
14551430
let session = request.ok().await?;
14561431

1457-
let (mut send, mut recv) = session.accept_bi().await?;
1432+
let (mut send, mut recv) =
1433+
match tokio::time::timeout(std::time::Duration::from_secs(30), session.accept_bi()).await {
1434+
Ok(Ok(streams)) => streams,
1435+
Ok(Err(e)) => {
1436+
auth_guard.record_failure();
1437+
return Err(e.into());
1438+
}
1439+
Err(_) => {
1440+
auth_guard.record_failure();
1441+
session.close(1, b"authentication timed out");
1442+
return Err("authentication timed out".into());
1443+
}
1444+
};
14581445

1459-
wt_authenticate(&mut send, &mut recv, &state.passphrase).await?;
1446+
wt_authenticate(&mut send, &mut recv, &state.passphrase, auth_guard).await?;
14601447

14611448
if is_mux {
14621449
return handle_mux_wt(send, recv, state).await;
@@ -1716,6 +1703,7 @@ mod tests {
17161703
hub_url: blit_webrtc_forwarder::normalize_hub(blit_webrtc_forwarder::DEFAULT_HUB_URL),
17171704
webrtc_enabled: false,
17181705
shutdown: Arc::new(tokio::sync::Notify::new()),
1706+
auth_throttle: blit_webserver::config::AuthThrottle::new(),
17191707
})
17201708
}
17211709

0 commit comments

Comments
 (0)