Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions otel-worker-cli/src/commands/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio_tungstenite::tungstenite::Message;
Expand Down Expand Up @@ -109,17 +110,29 @@ pub async fn handle_command(args: Args) -> Result<()> {

#[derive(Clone)]
struct McpState {
/// A single ApiClient that is used by all the [`McpSession`]'s. This allows
/// for connection re-use. ApiClient itself is a [`Arc`] wrapped, hence we
/// just use it without having to wrap it ourselves in a [`Arc`].
api_client: ApiClient,

/// This stores a reference to all the sessions by its session id.
sessions: Arc<RwLock<HashMap<String, McpSession>>>,

/// [`shutdown`] indicates that the server has received a terminate signal
/// and is in the process of shutting down. This means that no new
/// connections are accepted.
shutdown: Arc<AtomicBool>,
}

impl McpState {
fn new(api_client: ApiClient) -> Self {
let sessions = Arc::new(RwLock::new(HashMap::new()));
let shutdown = Arc::new(AtomicBool::new(false));

Self {
api_client,
sessions,
shutdown,
}
}

Expand Down Expand Up @@ -160,6 +173,28 @@ impl McpState {
let message = ServerMessage::Notification(notification);
self.broadcast(message).await
}

/// Initiate the shutdown process.
async fn shutdown(&self) {
if self.is_shutting_down() {
return;
}

// Mark the instance as shutdown
self.shutdown.store(true, Ordering::Relaxed);

// Go through all the sessions and drop them. This makes sure that the
// receiver will also be closed since all the transmitters will be gone.
let mut sessions = self.sessions.write().await;
for (session_id, _mcp_session) in sessions.drain() {
debug!(?session_id, "Closing session");
}
}

/// Checks if the current server instance is shutting down.
fn is_shutting_down(&self) -> bool {
self.shutdown.load(Ordering::Relaxed)
}
}

#[derive(Debug, Clone)]
Expand Down
42 changes: 24 additions & 18 deletions otel-worker-cli/src/commands/mcp/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ use super::McpState;
use anyhow::{Context, Result};
use axum::extract::{MatchedPath, Query, Request, State};
use axum::middleware::{self, Next};
use axum::response::sse::Event;
use axum::response::sse::{Event, KeepAlive};
use axum::response::{IntoResponse, Response, Sse};
use axum::routing::{get, post};
use axum::Json;
use futures::{Stream, StreamExt};
use futures::StreamExt;
use http::StatusCode;
use rust_mcp_schema::schema_utils::ClientMessage;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::process::exit;
use std::time::{Duration, Instant};
use std::time::Instant;
use tokio::net::TcpListener;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, info, info_span, warn, Instrument};
Expand All @@ -27,10 +27,10 @@ pub(crate) async fn serve(listen_address: &str, state: McpState) -> Result<()> {
"Starting MCP server",
);

let mcp_service = build_mcp_service(state);
let mcp_service = build_mcp_service(state.clone());

axum::serve(listener, mcp_service)
.with_graceful_shutdown(shutdown_requested())
.with_graceful_shutdown(shutdown_requested(state))
.await?;

Ok(())
Expand Down Expand Up @@ -86,36 +86,40 @@ async fn json_rpc_handler(
}

#[tracing::instrument(skip(state))]
async fn sse_handler(
State(mut state): State<McpState>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
async fn sse_handler(State(mut state): State<McpState>) -> Response {
debug!("MCP client connected to the SSE handler");

// Do not accept anymore clients if the server is shutting down.
if state.is_shutting_down() {
return (StatusCode::SERVICE_UNAVAILABLE, "server is shutting down").into_response();
}

let (session_id, messages) = state.register_session().await;

// This message needs to be send as soon as the client accesses the page.
let initial_event = futures::stream::once(async move {
let querystring = serde_urlencoded::to_string(JsonRpcQuery::new(Some(session_id)))
.expect("querystring encoding is expected to work");
Ok(Event::default()
// We need to explicitly specify the error type, since we are not
// constructing this anywhere

Event::default()
.event("endpoint")
.data(format!("/messages?{querystring}")))
.data(format!("/messages?{querystring}"))
});

// This stream will contain all the ServerMessages which are converted to
// Sse Events.
let events = ReceiverStream::new(messages).map(|message| {
Ok(Event::default()
Event::default()
.event("message")
.json_data(message)
.expect("unable to serialize data"))
.expect("unable to serialize data")
});

Sse::new(initial_event.chain(events)).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(5))
.text("keep-alive-text"),
)
Sse::new(initial_event.chain(events).map(Ok::<Event, Infallible>))
.keep_alive(KeepAlive::new())
.into_response()
}

async fn log_and_metrics(req: Request, next: Next) -> impl IntoResponse {
Expand Down Expand Up @@ -151,13 +155,15 @@ async fn log_and_metrics(req: Request, next: Next) -> impl IntoResponse {
/// Another SIGINT listener task is spawned just before resolving this task,
/// which will forcefully exit the application. This is to prevent not being
/// able to shutdown, if the graceful shutdown doesn't work.
async fn shutdown_requested() {
async fn shutdown_requested(state: McpState) {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen for ctrl-c");

info!("Received SIGINT, shutting down api server");

state.shutdown().await;

// Monitor for another SIGINT, and force shutdown if received.
tokio::spawn(async {
tokio::signal::ctrl_c()
Expand Down
Loading