Skip to content

websocket proxy: ping/pong downstream clients and proactively disconnect #264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions websocket-proxy/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use crate::rate_limit::Ticket;
use axum::extract::ws::Message;
use axum::extract::ws::WebSocket;
use axum::Error;
use std::net::IpAddr;
use std::time::Duration;
use std::time::Instant;

pub struct ClientConnection {
client_addr: IpAddr,
_ticket: Ticket,
pub(crate) websocket: WebSocket,
last_pong: Instant,
}

impl ClientConnection {
Expand All @@ -15,13 +19,30 @@ impl ClientConnection {
client_addr,
_ticket: ticket,
websocket,
last_pong: Instant::now(),
}
}

pub async fn send(&mut self, data: String) -> Result<(), Error> {
self.websocket.send(data.into_bytes().into()).await
}

pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
self.websocket.recv().await
}

pub async fn ping(&mut self) -> Result<(), Error> {
self.websocket.send(Message::Ping(vec![].into())).await
}

pub fn update_pong(&mut self) {
self.last_pong = Instant::now();
}

pub fn is_healthy(&self, timeout: Duration) -> bool {
self.last_pong.elapsed() < timeout
}

pub fn id(&self) -> String {
self.client_addr.to_string()
}
Expand Down
93 changes: 72 additions & 21 deletions websocket-proxy/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use crate::client::ClientConnection;
use crate::metrics::Metrics;
use std::sync::Arc;
use axum::extract::ws::Message;
use std::{sync::Arc, time::Duration};
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::broadcast::Sender;
use tracing::{info, trace, warn};

const PING_INTERVAL: Duration = Duration::from_secs(10);
const PONG_TIMEOUT: Duration = Duration::from_secs(10);
const HEALTHCHECK_INTERVAL: Duration = Duration::from_secs(2);

#[derive(Clone)]
pub struct Registry {
sender: Sender<String>,
Expand All @@ -24,31 +29,77 @@ impl Registry {
metrics.new_connections.increment(1);

tokio::spawn(async move {
let mut ping_timer = tokio::time::interval(PING_INTERVAL);
let mut healthcheck_timer = tokio::time::interval(HEALTHCHECK_INTERVAL);

loop {
match receiver.recv().await {
Ok(msg) => match client.send(msg.clone()).await {
Ok(_) => {
trace!(message = "message sent to client", client = client.id());
metrics.sent_messages.increment(1);
tokio::select! {
// Forward messages from upstream to client
upstream_msg = receiver.recv() => {
match upstream_msg {
Ok(msg) => match client.send(msg.clone()).await {
Ok(_) => {
trace!(message = "message sent to client", client = client.id());
metrics.sent_messages.increment(1);
}
Err(e) => {
warn!(
message = "failed to send data to client",
client = client.id(),
error = e.to_string()
);
metrics.failed_messages.increment(1);
break;
}
},
Err(RecvError::Closed) => {
info!(message = "upstream connection closed", client = client.id());
break;
}
Err(RecvError::Lagged(_)) => {
info!(message = "client is lagging", client = client.id());
metrics.lag_events.increment(1);
receiver = receiver.resubscribe();
}
}
}

// Handle incoming messages from client (pongs)
client_msg = client.recv() => {
match client_msg {
Some(Ok(msg)) => {
if let Message::Pong(_) = msg {
trace!(message = "received pong from client", client = client.id());
client.update_pong();
}
}
Some(Err(e)) => {
warn!(message = "error receiving message from client", client = client.id(), error = e.to_string());
break;
}
None => {
info!(message = "client connection closed", client = client.id());
break;
}
}
Err(e) => {
warn!(
message = "failed to send data to client",
client = client.id(),
error = e.to_string()
);
metrics.failed_messages.increment(1);
}

// Send pings to client periodically
_ = ping_timer.tick() => {
if let Err(e) = client.ping().await {
warn!(message = "failed to send ping to client", client = client.id(), error = e.to_string());
break;
}
},
Err(RecvError::Closed) => {
info!(message = "upstream connection closed", client = client.id());
break;

trace!(message = "ping sent to client", client = client.id());
}
Err(RecvError::Lagged(_)) => {
info!(message = "client is lagging", client = client.id());
metrics.lag_events.increment(1);
receiver = receiver.resubscribe();

// Check client health
_ = healthcheck_timer.tick() => {
if !client.is_healthy(PONG_TIMEOUT) {
warn!(message = "client health check failed", client = client.id());
break;
}
}
}
}
Expand Down