Skip to content
112 changes: 92 additions & 20 deletions src/execution/remote/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ impl ExecutionBackend for RemoteExecutor {
let cell_idx = cell_index.context("cell_index required for remote execution")?;
let ydoc = self.ydoc.as_mut().context("Y.js client not connected")?;
let http = reqwest::Client::new();
let client_writes = !ydoc.server_writes_outputs();

// 1. Fire execute request
let msg_id = ws
Expand All @@ -187,6 +188,7 @@ impl ExecutionBackend for RemoteExecutor {

// 2. Watch for changes on the ydoc for this cell
let mut outputs: Vec<nbformat::v4::Output> = Vec::new();
let mut kernel_outputs: Vec<nbformat::v4::Output> = Vec::new();
let mut fetched_urls: HashSet<String> = HashSet::new();
let mut seen_indices: HashSet<usize> = HashSet::new();
let mut idle_received = false;
Expand Down Expand Up @@ -227,28 +229,20 @@ impl ExecutionBackend for RemoteExecutor {
}

if idle_received {
let has_error = outputs
.iter()
.any(|o| matches!(o, nbformat::v4::Output::Error(_)));
let error_info = outputs.iter().find_map(|o| {
if let nbformat::v4::Output::Error(err) = o {
Some(ExecutionError {
ename: err.ename.clone(),
evalue: err.evalue.clone(),
traceback: err.traceback.clone(),
})
} else {
None
}
});
return if has_error {
Ok(ExecutionResult::error(outputs, ec, error_info.unwrap()))
} else {
Ok(ExecutionResult::success(outputs, ec))
};
return Self::build_result(outputs, ec);
}
}

// When the server doesn't write outputs and kernel is done,
// write collected outputs to Y.js ourselves, sync, then let
// the read loop above pick them up on the next iteration.
if client_writes && idle_received && !ec_ready && !kernel_outputs.is_empty() {
ydoc.update_cell_outputs(cell_idx, kernel_outputs.clone())?;
ydoc.update_cell_execution_count(cell_idx, expected_ec)?;
ydoc.sync().await?;
continue;
}

// 4. Wait for new messages
if idle_received {
match tokio::time::timeout_at(deadline, ydoc.recv_update()).await {
Expand All @@ -273,7 +267,13 @@ impl ExecutionBackend for RemoteExecutor {
idle_received = true;
}
}
_ => {}
_ => {
if client_writes {
if let Some(output) = Self::kernel_msg_to_output(&msg.content) {
kernel_outputs.push(output);
}
}
}
}
}
}
Expand All @@ -285,6 +285,11 @@ impl ExecutionBackend for RemoteExecutor {
}
}

// Fallback: if we collected kernel outputs but never wrote them
if client_writes && !kernel_outputs.is_empty() {
return Self::build_result(kernel_outputs, expected_ec);
}

let ec = ydoc
.read_cell_outputs(cell_idx)
.ok()
Expand All @@ -310,3 +315,70 @@ impl ExecutionBackend for RemoteExecutor {
Ok(())
}
}

impl RemoteExecutor {
fn build_result(
outputs: Vec<nbformat::v4::Output>,
ec: Option<i64>,
) -> Result<ExecutionResult> {
let has_error = outputs
.iter()
.any(|o| matches!(o, nbformat::v4::Output::Error(_)));
let error_info = outputs.iter().find_map(|o| {
if let nbformat::v4::Output::Error(err) = o {
Some(ExecutionError {
ename: err.ename.clone(),
evalue: err.evalue.clone(),
traceback: err.traceback.clone(),
})
} else {
None
}
});
if has_error {
Ok(ExecutionResult::error(outputs, ec, error_info.unwrap()))
} else {
Ok(ExecutionResult::success(outputs, ec))
}
}

fn kernel_msg_to_output(content: &JupyterMessageContent) -> Option<nbformat::v4::Output> {
match content {
JupyterMessageContent::StreamContent(stream) => {
let name = match stream.name {
jupyter_protocol::Stdio::Stdout => "stdout".to_string(),
jupyter_protocol::Stdio::Stderr => "stderr".to_string(),
};
Some(nbformat::v4::Output::Stream {
name,
text: nbformat::v4::MultilineString(stream.text.clone()),
})
}
JupyterMessageContent::ExecuteResult(result) => {
let json = serde_json::json!({
"output_type": "execute_result",
"execution_count": result.execution_count.value(),
"data": result.data,
"metadata": result.metadata
});
serde_json::from_value(json).ok()
}
JupyterMessageContent::DisplayData(display) => {
let json = serde_json::json!({
"output_type": "display_data",
"data": display.data,
"metadata": display.metadata
});
serde_json::from_value(json).ok()
}
JupyterMessageContent::ErrorOutput(error) => {
Some(nbformat::v4::Output::Error(nbformat::v4::ErrorOutput {
ename: error.ename.clone(),
evalue: error.evalue.clone(),
traceback: error.traceback.clone(),
}))
}
_ => None,
}
}
}
3 changes: 0 additions & 3 deletions src/execution/remote/output_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::collections::HashMap;
use yrs::{Any, Array, ArrayPrelim, ArrayRef, Map, MapPrelim, TransactionMut};

/// Convert an nbformat Output to a MapPrelim that can be inserted into the outputs array
#[allow(dead_code)]
pub fn output_to_map_prelim(output: &Output) -> MapPrelim {
match output {
Output::Stream { name, text } => MapPrelim::from([
Expand Down Expand Up @@ -92,7 +91,6 @@ fn json_to_any(value: &JsonValue) -> Any {
}

/// Update a cell's outputs in the Y.js document
#[allow(dead_code)]
pub fn update_cell_outputs(
txn: &mut TransactionMut,
cells_array: &ArrayRef,
Expand Down Expand Up @@ -136,7 +134,6 @@ pub fn update_cell_outputs(
}

/// Update a cell's execution_count in the Y.js document
#[allow(dead_code)]
pub fn update_cell_execution_count(
txn: &mut TransactionMut,
cells_array: &ArrayRef,
Expand Down
14 changes: 13 additions & 1 deletion src/execution/remote/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use jupyter_protocol::messaging::{JupyterMessage, JupyterMessageContent};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
use tokio_tungstenite::{connect_async, tungstenite::Message};

/// WebSocket connection to a Jupyter kernel
Expand All @@ -23,7 +25,17 @@ pub struct KernelWebSocket {
impl KernelWebSocket {
/// Connect to a kernel via WebSocket
pub async fn connect(ws_url: &str) -> Result<Self> {
let (ws_stream, _) = connect_async(ws_url)
let mut request = ws_url
.into_client_request()
.context("Invalid kernel WebSocket URL")?;
request.headers_mut().insert(
SEC_WEBSOCKET_PROTOCOL,
"v1.kernel.websocket.jupyter.org"
.parse()
.expect("valid header value"),
);

let (ws_stream, _) = connect_async(request)
.await
.context("Failed to connect to kernel WebSocket")?;

Expand Down
114 changes: 89 additions & 25 deletions src/execution/remote/ydoc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ struct FileIdResponse {
id: String,
}

#[derive(Debug, Deserialize)]
struct CollabSessionResponse {
#[serde(rename = "fileId")]
file_id: String,
#[serde(rename = "sessionId")]
session_id: String,
}

/// Y.js document client for syncing notebook changes with Jupyter Server
pub struct YDocClient {
doc: Doc,
Expand All @@ -58,16 +66,18 @@ pub struct YDocClient {
file_id: String,
/// Track the document state when we last synced, so we only send changes
last_state: StateVector,
/// Whether the server writes outputs to Y.js (JSD does, jupyter-collaboration doesn't)
server_writes_outputs: bool,
}

impl YDocClient {
/// Connect to Y.js room for the given notebook
pub async fn connect(server_url: String, token: String, notebook_path: String) -> Result<Self> {
// Step 1: Get file ID from FileID API
let file_id = Self::get_file_id(&server_url, &token, &notebook_path).await?;
// Step 1: Get file ID (and session ID if using jupyter-collaboration)
let (file_id, session_id) = Self::get_file_id(&server_url, &token, &notebook_path).await?;

// Step 2: Connect to room WebSocket
let ws_url = Self::build_room_ws_url(&server_url, &file_id, &token)?;
let ws_url = Self::build_room_ws_url(&server_url, &file_id, &token, session_id.as_deref())?;

let (ws_stream, _) = connect_async(&ws_url)
.await
Expand All @@ -81,6 +91,7 @@ impl YDocClient {
ws: ws_stream,
file_id,
last_state: StateVector::default(),
server_writes_outputs: session_id.is_none(),
};

// Step 4: Perform Y.js sync handshake with timeout
Expand All @@ -94,13 +105,18 @@ impl YDocClient {
}
}

/// Get unique file ID for notebook path via FileID API
async fn get_file_id(server_url: &str, token: &str, notebook_path: &str) -> Result<String> {
let url = format!("{}/api/fileid/index", server_url);

/// Returns (file_id, session_id). session_id is present only via jupyter-collaboration.
async fn get_file_id(
server_url: &str,
token: &str,
notebook_path: &str,
) -> Result<(String, Option<String>)> {
let http_client = HttpClient::new();

// Try jupyter-server-documents: POST /api/fileid/index (create-if-not-exists)
let index_url = format!("{}/api/fileid/index", server_url);
let response = http_client
.post(&url)
.post(&index_url)
.query(&[("path", notebook_path)])
.header("Authorization", format!("token {}", token))
.send()
Expand All @@ -117,28 +133,68 @@ impl YDocClient {
}
})?;

if !response.status().is_success() {
if response.status().is_success() {
let file_id_response: FileIdResponse = response
.json()
.await
.context("Failed to parse FileID API response")?;
return Ok((file_id_response.id, None));
}

if response.status().as_u16() != 404 {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!(
"FileID API request failed with status {}: {}. \
Make sure jupyter-server-documents is installed: \
pip install jupyter-server-documents",
"FileID API request failed with status {}: {}",
status,
error_text
);
}

let file_id_response: FileIdResponse = response
.json()
// Fallback: try jupyter-collaboration session endpoint (indexes the file if needed)
let mut session_url = Url::parse(server_url).context("Invalid server URL")?;
session_url.set_path(&format!("/api/collaboration/session/{}", notebook_path));
let response = http_client
.put(session_url)
.header("Authorization", format!("token {}", token))
.header("Content-Type", "application/json")
.body(r#"{"format":"json","type":"notebook"}"#)
.send()
.await
.context("Failed to parse FileID API response")?;
.map_err(|e| anyhow::anyhow!("Failed to call collaboration session API: {}", e))?;

if response.status().is_success() {
let session_response: CollabSessionResponse = response
.json()
.await
.context("Failed to parse collaboration session response")?;
return Ok((session_response.file_id, Some(session_response.session_id)));
}

Ok(file_id_response.id)
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
if status.as_u16() == 404 {
anyhow::bail!(
"FileID API request failed with status {}: {}. \
Make sure jupyter-server-documents or jupyter-collaboration is installed: \
pip install jupyter-server-documents (or pip install jupyter-collaboration)",
status,
error_text
);
}
anyhow::bail!(
"Collaboration session API request failed with status {}: {}",
status,
error_text
);
}

/// Build WebSocket URL for Y.js room
fn build_room_ws_url(server_url: &str, file_id: &str, token: &str) -> Result<String> {
fn build_room_ws_url(
server_url: &str,
file_id: &str,
token: &str,
session_id: Option<&str>,
) -> Result<String> {
// Parse base URL to extract host and port
let base_url = Url::parse(server_url).context("Invalid server URL")?;

Expand All @@ -159,12 +215,18 @@ impl YDocClient {
"ws"
};

let ws_url = format!(
"{}://{}:{}/api/collaboration/room/json:notebook:{}?token={}",
ws_scheme, host, port, file_id, token
);
let mut ws_url = Url::parse(&format!(
"{}://{}:{}/api/collaboration/room/json:notebook:{}",
ws_scheme, host, port, file_id
))
.context("Failed to build WebSocket URL")?;

Ok(ws_url)
ws_url.query_pairs_mut().append_pair("token", token);
if let Some(sid) = session_id {
ws_url.query_pairs_mut().append_pair("sessionId", sid);
}

Ok(ws_url.to_string())
}

/// Perform Y.js sync protocol handshake
Expand Down Expand Up @@ -286,8 +348,11 @@ impl YDocClient {
Ok(())
}

pub fn server_writes_outputs(&self) -> bool {
self.server_writes_outputs
}

/// Update cell outputs in the Y.js document
#[allow(dead_code)]
pub fn update_cell_outputs(&mut self, cell_index: usize, outputs: Vec<Output>) -> Result<()> {
let cells_array: ArrayRef = self.doc.get_or_insert_array("cells");
let mut txn = self.doc.transact_mut();
Expand All @@ -299,7 +364,6 @@ impl YDocClient {
}

/// Update cell execution_count in the Y.js document
#[allow(dead_code)]
pub fn update_cell_execution_count(
&mut self,
cell_index: usize,
Expand Down
Loading