Skip to content

Commit 02c13c9

Browse files
committed
write outputs to ydoc client-side for jupyter-collaboration
1 parent ca52a46 commit 02c13c9

3 files changed

Lines changed: 99 additions & 25 deletions

File tree

src/execution/remote/mod.rs

Lines changed: 92 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ impl ExecutionBackend for RemoteExecutor {
179179
let cell_idx = cell_index.context("cell_index required for remote execution")?;
180180
let ydoc = self.ydoc.as_mut().context("Y.js client not connected")?;
181181
let http = reqwest::Client::new();
182+
let client_writes = !ydoc.server_writes_outputs();
182183

183184
// 1. Fire execute request
184185
let msg_id = ws
@@ -187,6 +188,7 @@ impl ExecutionBackend for RemoteExecutor {
187188

188189
// 2. Watch for changes on the ydoc for this cell
189190
let mut outputs: Vec<nbformat::v4::Output> = Vec::new();
191+
let mut kernel_outputs: Vec<nbformat::v4::Output> = Vec::new();
190192
let mut fetched_urls: HashSet<String> = HashSet::new();
191193
let mut seen_indices: HashSet<usize> = HashSet::new();
192194
let mut idle_received = false;
@@ -227,28 +229,20 @@ impl ExecutionBackend for RemoteExecutor {
227229
}
228230

229231
if idle_received {
230-
let has_error = outputs
231-
.iter()
232-
.any(|o| matches!(o, nbformat::v4::Output::Error(_)));
233-
let error_info = outputs.iter().find_map(|o| {
234-
if let nbformat::v4::Output::Error(err) = o {
235-
Some(ExecutionError {
236-
ename: err.ename.clone(),
237-
evalue: err.evalue.clone(),
238-
traceback: err.traceback.clone(),
239-
})
240-
} else {
241-
None
242-
}
243-
});
244-
return if has_error {
245-
Ok(ExecutionResult::error(outputs, ec, error_info.unwrap()))
246-
} else {
247-
Ok(ExecutionResult::success(outputs, ec))
248-
};
232+
return Self::build_result(outputs, ec);
249233
}
250234
}
251235

236+
// When the server doesn't write outputs and kernel is done,
237+
// write collected outputs to Y.js ourselves, sync, then let
238+
// the read loop above pick them up on the next iteration.
239+
if client_writes && idle_received && !ec_ready && !kernel_outputs.is_empty() {
240+
ydoc.update_cell_outputs(cell_idx, kernel_outputs.clone())?;
241+
ydoc.update_cell_execution_count(cell_idx, expected_ec)?;
242+
ydoc.sync().await?;
243+
continue;
244+
}
245+
252246
// 4. Wait for new messages
253247
if idle_received {
254248
match tokio::time::timeout_at(deadline, ydoc.recv_update()).await {
@@ -273,7 +267,13 @@ impl ExecutionBackend for RemoteExecutor {
273267
idle_received = true;
274268
}
275269
}
276-
_ => {}
270+
_ => {
271+
if client_writes {
272+
if let Some(output) = Self::kernel_msg_to_output(&msg.content) {
273+
kernel_outputs.push(output);
274+
}
275+
}
276+
}
277277
}
278278
}
279279
}
@@ -285,6 +285,11 @@ impl ExecutionBackend for RemoteExecutor {
285285
}
286286
}
287287

288+
// Fallback: if we collected kernel outputs but never wrote them
289+
if client_writes && !kernel_outputs.is_empty() {
290+
return Self::build_result(kernel_outputs, expected_ec);
291+
}
292+
288293
let ec = ydoc
289294
.read_cell_outputs(cell_idx)
290295
.ok()
@@ -310,3 +315,70 @@ impl ExecutionBackend for RemoteExecutor {
310315
Ok(())
311316
}
312317
}
318+
319+
impl RemoteExecutor {
320+
fn build_result(
321+
outputs: Vec<nbformat::v4::Output>,
322+
ec: Option<i64>,
323+
) -> Result<ExecutionResult> {
324+
let has_error = outputs
325+
.iter()
326+
.any(|o| matches!(o, nbformat::v4::Output::Error(_)));
327+
let error_info = outputs.iter().find_map(|o| {
328+
if let nbformat::v4::Output::Error(err) = o {
329+
Some(ExecutionError {
330+
ename: err.ename.clone(),
331+
evalue: err.evalue.clone(),
332+
traceback: err.traceback.clone(),
333+
})
334+
} else {
335+
None
336+
}
337+
});
338+
if has_error {
339+
Ok(ExecutionResult::error(outputs, ec, error_info.unwrap()))
340+
} else {
341+
Ok(ExecutionResult::success(outputs, ec))
342+
}
343+
}
344+
345+
fn kernel_msg_to_output(content: &JupyterMessageContent) -> Option<nbformat::v4::Output> {
346+
match content {
347+
JupyterMessageContent::StreamContent(stream) => {
348+
let name = match stream.name {
349+
jupyter_protocol::Stdio::Stdout => "stdout".to_string(),
350+
jupyter_protocol::Stdio::Stderr => "stderr".to_string(),
351+
};
352+
Some(nbformat::v4::Output::Stream {
353+
name,
354+
text: nbformat::v4::MultilineString(stream.text.clone()),
355+
})
356+
}
357+
JupyterMessageContent::ExecuteResult(result) => {
358+
let json = serde_json::json!({
359+
"output_type": "execute_result",
360+
"execution_count": result.execution_count.value(),
361+
"data": result.data,
362+
"metadata": result.metadata
363+
});
364+
serde_json::from_value(json).ok()
365+
}
366+
JupyterMessageContent::DisplayData(display) => {
367+
let json = serde_json::json!({
368+
"output_type": "display_data",
369+
"data": display.data,
370+
"metadata": display.metadata
371+
});
372+
serde_json::from_value(json).ok()
373+
}
374+
JupyterMessageContent::ErrorOutput(error) => {
375+
Some(nbformat::v4::Output::Error(nbformat::v4::ErrorOutput {
376+
ename: error.ename.clone(),
377+
evalue: error.evalue.clone(),
378+
traceback: error.traceback.clone(),
379+
}))
380+
}
381+
_ => None,
382+
}
383+
}
384+
}

src/execution/remote/output_conversion.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use std::collections::HashMap;
66
use yrs::{Any, Array, ArrayPrelim, ArrayRef, Map, MapPrelim, TransactionMut};
77

88
/// Convert an nbformat Output to a MapPrelim that can be inserted into the outputs array
9-
#[allow(dead_code)]
109
pub fn output_to_map_prelim(output: &Output) -> MapPrelim {
1110
match output {
1211
Output::Stream { name, text } => MapPrelim::from([
@@ -92,7 +91,6 @@ fn json_to_any(value: &JsonValue) -> Any {
9291
}
9392

9493
/// Update a cell's outputs in the Y.js document
95-
#[allow(dead_code)]
9694
pub fn update_cell_outputs(
9795
txn: &mut TransactionMut,
9896
cells_array: &ArrayRef,
@@ -136,7 +134,6 @@ pub fn update_cell_outputs(
136134
}
137135

138136
/// Update a cell's execution_count in the Y.js document
139-
#[allow(dead_code)]
140137
pub fn update_cell_execution_count(
141138
txn: &mut TransactionMut,
142139
cells_array: &ArrayRef,

src/execution/remote/ydoc.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ pub struct YDocClient {
6666
file_id: String,
6767
/// Track the document state when we last synced, so we only send changes
6868
last_state: StateVector,
69+
/// Whether the server writes outputs to Y.js (JSD does, jupyter-collaboration doesn't)
70+
server_writes_outputs: bool,
6971
}
7072

7173
impl YDocClient {
@@ -89,6 +91,7 @@ impl YDocClient {
8991
ws: ws_stream,
9092
file_id,
9193
last_state: StateVector::default(),
94+
server_writes_outputs: session_id.is_none(),
9295
};
9396

9497
// Step 4: Perform Y.js sync handshake with timeout
@@ -345,8 +348,11 @@ impl YDocClient {
345348
Ok(())
346349
}
347350

351+
pub fn server_writes_outputs(&self) -> bool {
352+
self.server_writes_outputs
353+
}
354+
348355
/// Update cell outputs in the Y.js document
349-
#[allow(dead_code)]
350356
pub fn update_cell_outputs(&mut self, cell_index: usize, outputs: Vec<Output>) -> Result<()> {
351357
let cells_array: ArrayRef = self.doc.get_or_insert_array("cells");
352358
let mut txn = self.doc.transact_mut();
@@ -358,7 +364,6 @@ impl YDocClient {
358364
}
359365

360366
/// Update cell execution_count in the Y.js document
361-
#[allow(dead_code)]
362367
pub fn update_cell_execution_count(
363368
&mut self,
364369
cell_index: usize,

0 commit comments

Comments
 (0)