Skip to content

Commit 1978efd

Browse files
committed
Fix failing test test_no_stored_outputs and apply formatting fixes
1 parent b49496a commit 1978efd

3 files changed

Lines changed: 29 additions & 11 deletions

File tree

examples/knowledge-tuning/04_Knowledge_Mixing/Knowledge_Mixing.ipynb

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -160,9 +160,11 @@
160160
"\n",
161161
" # Filter out problematic questions\n",
162162
" ds = ds.filter(\n",
163-
" lambda x: \"...\" not in x[\"question\"]\n",
164-
" and \"<question>\" not in x[\"question\"]\n",
165-
" and \"<Insert question here>\" not in x[\"question\"]\n",
163+
" lambda x: (\n",
164+
" \"...\" not in x[\"question\"]\n",
165+
" and \"<question>\" not in x[\"question\"]\n",
166+
" and \"<Insert question here>\" not in x[\"question\"]\n",
167+
" )\n",
166168
" )\n",
167169
"\n",
168170
" # Clean response text\n",

examples/knowledge-tuning/04_Knowledge_Mixing/utils/knowledge_utils.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,8 @@ def sample_doc_qa(
100100
def _clean_response_text(df: pl.DataFrame) -> pl.DataFrame:
101101
"""Clean response text by removing markers and whitespace."""
102102
return df.with_columns(
103-
pl.col("response")
103+
pl
104+
.col("response")
104105
.str.replace_all(r"\[END\]", "")
105106
.str.replace_all(r"\[ANSWER\]", "")
106107
.str.strip_chars()
@@ -111,7 +112,8 @@ def _clean_response_text(df: pl.DataFrame) -> pl.DataFrame:
111112
def _create_metadata(df: pl.DataFrame) -> pl.Expr:
112113
"""Create metadata JSON structure."""
113114
return (
114-
pl.struct([
115+
pl
116+
.struct([
115117
pl.col("document").alias("sdg_document"),
116118
pl.lit("document_knowledge_qa").alias("dataset"),
117119
pl.col("raw_document"),
@@ -232,7 +234,8 @@ def generate_knowledge_qa_dataset(
232234
"reasoning",
233235
]
234236
messages_expr = (
235-
pl.struct(message_columns)
237+
pl
238+
.struct(message_columns)
236239
.map_elements(_create_messages_with_reasoning_no_document)
237240
.alias("messages")
238241
)
@@ -245,21 +248,24 @@ def generate_knowledge_qa_dataset(
245248
"reasoning",
246249
]
247250
messages_expr = (
248-
pl.struct(message_columns)
251+
pl
252+
.struct(message_columns)
249253
.map_elements(_create_messages_with_reasoning)
250254
.alias("messages")
251255
)
252256
elif keep_document_in_context:
253257
message_columns = ["question", "response", "document", "document_outline"]
254258
messages_expr = (
255-
pl.struct(message_columns)
259+
pl
260+
.struct(message_columns)
256261
.map_elements(_create_messages_without_reasoning)
257262
.alias("messages")
258263
)
259264
else:
260265
message_columns = ["question", "response", "document", "document_outline"]
261266
messages_expr = (
262-
pl.struct(message_columns)
267+
pl
268+
.struct(message_columns)
263269
.map_elements(_create_messages_without_reasoning_no_document)
264270
.alias("messages")
265271
)
@@ -307,7 +313,8 @@ def count_tokens(text: str) -> int:
307313
return len(tokenizer.encode(text))
308314

309315
return df.with_columns(
310-
pl.col(column_name)
316+
pl
317+
.col(column_name)
311318
.map_elements(apply_chat_template, return_dtype=pl.String)
312319
.map_elements(count_tokens, return_dtype=pl.Int32)
313320
.alias("token_length")

tests/validation/test_notebook_content.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,13 +43,22 @@ def test_no_execution_counts(notebook_path, relative_path):
4343

4444

4545
def test_no_stored_outputs(notebook_path, relative_path):
46-
"""Test that notebooks have no stored outputs (should be cleared)."""
46+
"""Test that notebooks have no stored outputs (should be cleared).
47+
48+
Cells with 'keep_output' tag in metadata are ignored.
49+
"""
4750
with open(notebook_path, encoding="utf-8") as f:
4851
nb = json.load(f)
4952

5053
cells_with_outputs = []
5154
for i, cell in enumerate(nb.get("cells", [])):
5255
if cell.get("cell_type") == "code":
56+
# Check if cell has keep_output tag
57+
metadata = cell.get("metadata", {})
58+
tags = metadata.get("tags", [])
59+
if "keep_output" in tags:
60+
continue
61+
5362
outputs = cell.get("outputs", [])
5463
if len(outputs) > 0:
5564
cells_with_outputs.append((i, len(outputs)))

0 commit comments

Comments
 (0)