Skip to content
This repository was archived by the owner on Apr 30, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/instructlab/sdg/datamixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,16 @@ def save_mixed_dataset(self, output_path, num_proc):
as a jsonl file.
"""
mixed_ds = self._create_mixed_dataset(num_proc)

# filter out any records where the any message content is None
mixed_ds = mixed_ds.filter(
lambda x: all(
message.get("content")
for message in x["messages"]
if message.get("role") != "system"
)
)

mixed_ds.to_json(output_path, orient="records", lines=True)
logger.info(f"Mixed Dataset saved to {output_path}")

Expand Down
11 changes: 11 additions & 0 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def _gen_train_data(
if len(synth_example.get("context", "")) > 0:
user += "\n" + synth_example["context"]
assistant = _unescape(_get_response_hack(synth_example))
# filter out any assistant message that is empty
if not assistant:
continue

train_entry = {
"system": system_prompt,
"user": _unescape(user),
Expand Down Expand Up @@ -594,6 +598,13 @@ def postprocess_taxonomy(
if leaf_node_type == "knowledge":
is_knowledge = True

if is_knowledge:
# Filter out rows with no document, they cause errors in the datamixing code
for i in range(len(samples) - 1, -1, -1):
if not samples[i].get("document"):
logger.warning("Removing sample without document: %s", samples[i])
samples.pop(i)

samples_ds = Dataset.from_list(samples)
logger.debug("Postprocessing from samples: %s", samples_ds)
all_generated_data.append(samples_ds)
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/test_datamixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,67 @@ def test_mix_instructlab_07x_precomputed_skills_with_unmask(tmp_path):
assert (
sample.get("unmask", None) is not None
), "Mixed sample does not have unmask"


def test_save_mixed_dataset_with_none_content(tmp_path):
"""
Test that we filter out mixed dataset records where any message content is None.
"""

# Create a knowledge dataset
knowledge_dataset = load_auxiliary_dataset()
number_of_records = len(knowledge_dataset)
# append a record with content=None and content="", both should be filtered out
knowledge_dataset = knowledge_dataset.add_item(
{
"id": "test_001",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of Ireland?"},
{"role": "assistant", "content": None},
],
}
)
knowledge_dataset = knowledge_dataset.add_item(
{
"id": "test_002",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of Ireland?"},
{"role": "assistant", "content": "Dublin"},
],
}
)

knowledge_dataset = knowledge_dataset.add_item(
{
"id": "test_003",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of Ireland?"},
{"role": "assistant", "content": ""},
],
}
)

Comment thread
aakankshaduggal marked this conversation as resolved.
knowledge_path = os.path.join(tmp_path, "knowledge.jsonl")
jldump(knowledge_dataset, knowledge_path)

output_path = os.path.join(tmp_path, "output.jsonl")
recipe = Recipe()
recipe.add_dataset(knowledge_path, 1.0)
recipe.save_mixed_dataset(output_path, TEST_NUM_PROCS)

# Ensure the mixed dataset is saved correctly
mixed_samples = load_dataset("json", data_files=output_path, split="train")

# the row with content=None should have been removed
assert (
len(mixed_samples) == number_of_records + 1
), f"Expected {number_of_records + 1} records in mixed dataset"

# None of the mixed samples should have content=None
for sample in mixed_samples:
assert all(
[message.get("content") is not None for message in sample["messages"]]
), "Mixed sample has content=None"
48 changes: 48 additions & 0 deletions tests/unit/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
from instructlab.sdg import LLMBlock, PipelineContext
from instructlab.sdg.generate_data import (
_context_init,
_gen_train_data,
_locate_docling_models,
_sdg_init,
generate_data,
)
from instructlab.sdg.utils.json import jlload

# Local
from ..taxonomy import load_test_skills
Expand Down Expand Up @@ -586,3 +588,49 @@ def test_locate_docling_models_config_not_found(testdata_path):
os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("nonexistent_dir"))
docling_model_path = _locate_docling_models()
assert docling_model_path is None


class TestGenTrainData(unittest.TestCase):
"""Test the _gen_train_data function with small synthetic examples."""

def setUp(self):
self.test_dir = tempfile.mkdtemp()
self.system_prompt = "Test system prompt"

def tearDown(self):
shutil.rmtree(self.test_dir)

def test_gen_train_data_with_empty_response(self):
"""Test _gen_train_data with synthetic examples with blank responses."""
# Create mock synthetic examples with blank responses
machine_instruction_data = [
[
{"question": "Q1", "response": "", "context": "C1"},
{"question": "Q2", "response": "A2", "context": "C2"},
]
]

output_file_train = os.path.join(self.test_dir, "train_test.jsonl")
output_file_messages = os.path.join(self.test_dir, "messages_test.jsonl")

# Call the function
_gen_train_data(
machine_instruction_data,
output_file_train,
output_file_messages,
self.system_prompt,
)

# Verify train file was created and only has a single sample
self.assertTrue(os.path.exists(output_file_train))
train_data = jlload(output_file_train)
self.assertEqual(len(train_data), 1)

# Check first sample
first_sample = train_data[0]
self.assertEqual(first_sample["system"], self.system_prompt)
self.assertEqual(first_sample["user"], "Q2\nC2")
self.assertEqual(first_sample["assistant"], "A2")

# Verify messages file was created and has correct content
self.assertTrue(os.path.exists(output_file_messages))