Skip to content
This repository was archived by the owner on Apr 30, 2026. It is now read-only.

Commit 7087621

Browse files
authored
Merge pull request #203 from bbrowning/write-recipe-yamls
Introduce data mixing recipe yaml files
2 parents b292b7a + fcfacfc commit 7087621

7 files changed

Lines changed: 227 additions & 34 deletions

File tree

src/instructlab/sdg/datamixing.py

Lines changed: 97 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# Third Party
99
from datasets import Dataset, concatenate_datasets, load_dataset
10+
import yaml
1011

1112
# First Party
1213
from instructlab.sdg.logger_config import setup_logger
@@ -27,18 +28,12 @@ def _adjust_train_sample_size(ds: Dataset, num_samples: int):
2728
return pandas.dataset_from_pandas_dataframe(df)
2829

2930

30-
def _load_ds(path, sampling_size, num_proc):
31+
def _sample_ds(dataset, sampling_size, num_proc):
3132
"""
32-
Load a dataset from the given file path and select sampling_size
33-
number/ratio of samples from it, ensuring the loaded dataset has only
34-
ALLOWED_COLS columns in it with any additional columns moved to the
35-
metadata section.
33+
Select sampling_size number/ratio of samples from a dataset, ensuring
34+
the returned dataset has only ALLOWED_COLS columns in it with any
35+
additional columns moved to the metadata section.
3636
"""
37-
logger.info(f"Loading dataset from {path} ...")
38-
dataset = load_dataset("json", data_files=path, split="train")
39-
logger.info(f"Dataset columns: {dataset.column_names}")
40-
logger.info(f"Dataset loaded with {len(dataset)} samples")
41-
4237
if sampling_size != 1.0:
4338
if isinstance(sampling_size, int):
4439
num_samples = sampling_size
@@ -94,29 +89,63 @@ class Recipe:
9489
"""
9590

9691
def __init__(
97-
self, initial_datasets: Optional[list] = None, sys_prompt: Optional[str] = ""
92+
self, recipe_path: Optional[str] = None, sys_prompt: Optional[str] = ""
9893
):
99-
self.recipe = {
100-
"datasets": initial_datasets or [],
101-
"sys_prompt": sys_prompt,
102-
}
103-
self.sys_prompt = self.recipe.get("sys_prompt", "")
94+
self.recipe_path = recipe_path or ""
95+
self.sys_prompt = sys_prompt
96+
97+
# Defaults if no recipe path given or these values don't
98+
# exist in the given recipe file
99+
self.datasets = []
100+
if recipe_path is not None:
101+
recipe = self._load_recipe()
102+
if "datasets" in recipe:
103+
self.datasets = recipe["datasets"]
104+
104105
self.dataset_added = False
105106

107+
def _load_recipe(self):
108+
with open(self.recipe_path, encoding="utf-8") as fp:
109+
return yaml.safe_load(fp)
110+
111+
def _load_ds(self, path):
112+
"""
113+
Load a dataset from the given location. If a jsonl file is
114+
given, we load the dataset from disk. Otherwise, we load the
115+
path given from HuggingFace. Relative paths are resolved
116+
respective to the directory the recipe yaml itself resides in.
117+
"""
118+
if not os.path.isabs(path):
119+
path = os.path.join(os.path.dirname(self.recipe_path), path)
120+
logger.info(f"Loading dataset from {path} ...")
121+
dataset = load_dataset("json", data_files=path, split="train")
122+
logger.info(f"Dataset columns: {dataset.column_names}")
123+
logger.info(f"Dataset loaded with {len(dataset)} samples")
124+
return dataset
125+
126+
def _load_and_sample_datasets(self, num_proc):
127+
"""
128+
Load and sample all the datasets in this recipe, taking
129+
into account the desired sampling size from each individual
130+
dataset to control the overall mix of samples in the final
131+
dataset.
132+
"""
133+
return [
134+
_sample_ds(
135+
self._load_ds(dataset["path"]), dataset["sampling_size"], num_proc
136+
)
137+
for dataset in self.datasets
138+
]
139+
106140
def _create_mixed_dataset(self, num_proc):
107141
"""
108-
Create the mixed dataset from its list of included datasets, taking
109-
into account the desired sampling size from each individual dataset
110-
to control the overall mix of samples in the final dataset.
142+
Create the final mixed dataset by loading, sampling, and
143+
concatenating all datasets in this recipe
111144
"""
112145
if not self.dataset_added:
113146
logger.error("No dataset added to the recipe")
114147

115-
mixed_ds = [
116-
_load_ds(dataset["path"], dataset["sampling_size"], num_proc)
117-
for dataset in self.recipe["datasets"]
118-
]
119-
148+
mixed_ds = self._load_and_sample_datasets(num_proc)
120149
mixed_ds = concatenate_datasets(mixed_ds)
121150
mixed_ds = mixed_ds.map(
122151
_add_system_message,
@@ -143,7 +172,20 @@ def add_dataset(self, path, sampling_size):
143172
of the samples, and so on.
144173
"""
145174
self.dataset_added = True
146-
self.recipe["datasets"].append({"path": path, "sampling_size": sampling_size})
175+
self.datasets.append({"path": path, "sampling_size": sampling_size})
176+
177+
def save_recipe(self, output_path):
178+
recipe = {
179+
"datasets": self.datasets,
180+
"metadata": {"sys_prompt": self.sys_prompt},
181+
}
182+
with open(output_path, "w", encoding="utf-8") as fp:
183+
yaml.dump(recipe, fp)
184+
# Update this instance's recipe_path to reflect the path we
185+
# just saved it to so that any subsequent loading of datasets
186+
# (like via save_mixed_dataset) resolves paths relative to the
187+
# saved recipe_path.
188+
self.recipe_path = output_path
147189

148190
def save_mixed_dataset(self, output_path, num_proc):
149191
"""
@@ -398,18 +440,34 @@ class DataMixer:
398440
# once.
399441
NUM_SYNTH_SKILLS = 30
400442

401-
def __init__(self, output_dir, date_suffix, sys_prompt, num_procs):
443+
def __init__(self, data_dirs, output_dir, date_suffix, sys_prompt, num_procs):
444+
self.data_dirs = data_dirs
402445
self.output_dir = output_dir
403446
self.sys_prompt = sys_prompt
404447
self.date_suffix = date_suffix
405448
self.num_procs = num_procs
406449

407-
self.knowledge_recipe = Recipe(sys_prompt=self.sys_prompt)
408-
self.skills_recipe = Recipe(sys_prompt=self.sys_prompt)
450+
self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
451+
self.skills_recipe = self._load_default_recipe("skills.yaml")
409452

453+
self.output_file_knowledge_recipe = f"knowledge_recipe_{date_suffix}.yaml"
454+
self.output_file_skills_recipe = f"skills_recipe_{date_suffix}.yaml"
410455
self.output_file_mixed_knowledge = f"knowledge_train_msgs_{date_suffix}.jsonl"
411456
self.output_file_mixed_skills = f"skills_train_msgs_{date_suffix}.jsonl"
412457

458+
def _load_default_recipe(self, yaml_basename):
459+
"""
460+
Load a default system recipe from e.g. /usr/share/instructlab/sdg/default_data_recipes
461+
if it exists, otherwise return an empty recipe.
462+
"""
463+
for d in self.data_dirs:
464+
default_recipe_path = os.path.join(d, "default_data_recipes", yaml_basename)
465+
if os.path.exists(default_recipe_path):
466+
return Recipe(
467+
recipe_path=default_recipe_path, sys_prompt=self.sys_prompt
468+
)
469+
return Recipe(sys_prompt=self.sys_prompt)
470+
413471
def _gen_leaf_node_data(
414472
self, leaf_node_data, recipe, output_file_leaf_node, sampling_size=1.0
415473
):
@@ -420,7 +478,7 @@ def _gen_leaf_node_data(
420478
"""
421479
output_file = os.path.join(self.output_dir, output_file_leaf_node)
422480
leaf_node_data.to_json(output_file, orient="records", lines=True)
423-
recipe.add_dataset(output_file, sampling_size)
481+
recipe.add_dataset(output_file_leaf_node, sampling_size)
424482

425483
def collect(self, leaf_node_path, new_generated_data, is_knowledge):
426484
if is_knowledge:
@@ -459,20 +517,27 @@ def collect(self, leaf_node_path, new_generated_data, is_knowledge):
459517
sampling_size=self.NUM_SYNTH_SKILLS,
460518
)
461519

462-
def _gen_mixed_data(self, recipe, output_file_mixed):
520+
def _gen_mixed_data(self, recipe, output_file_recipe, output_file_data):
463521
"""
464522
Mix the generated leaf node data into a single dataset and write it to
465523
disk. The heavy lifting is delegated to the Recipe class.
466524
"""
467525
if recipe.dataset_added:
526+
full_recipe_path = os.path.join(self.output_dir, output_file_recipe)
527+
recipe.save_recipe(full_recipe_path)
468528
recipe.save_mixed_dataset(
469-
os.path.join(self.output_dir, output_file_mixed),
529+
os.path.join(self.output_dir, output_file_data),
470530
self.num_procs,
471531
)
472532

473533
def generate(self):
474-
self._gen_mixed_data(self.knowledge_recipe, self.output_file_mixed_knowledge)
534+
self._gen_mixed_data(
535+
self.knowledge_recipe,
536+
self.output_file_knowledge_recipe,
537+
self.output_file_mixed_knowledge,
538+
)
475539
self._gen_mixed_data(
476540
self.skills_recipe,
541+
self.output_file_skills_recipe,
477542
self.output_file_mixed_skills,
478543
)

src/instructlab/sdg/generate_data.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,20 @@ def load_pipeline(yaml_basename):
241241
)
242242

243243

244+
def _mixer_init(ctx, output_dir, date_suffix):
245+
pd = platformdirs.PlatformDirs(
246+
appname=os.path.join("instructlab", "sdg"), multipath=True
247+
)
248+
data_dirs = list(pd.iter_data_dirs())
249+
return DataMixer(
250+
data_dirs,
251+
output_dir,
252+
date_suffix,
253+
_SYS_PROMPT,
254+
ctx.dataset_num_procs,
255+
)
256+
257+
244258
# This is part of the public API, and used by instructlab.
245259
# TODO - parameter removal needs to be done in sync with a CLI change.
246260
# pylint: disable=unused-argument
@@ -342,7 +356,7 @@ def generate_data(
342356

343357
mmlu_bench_pipe = mmlubench_pipe_init(ctx)
344358

345-
mixer = DataMixer(output_dir, date_suffix, _SYS_PROMPT, ctx.dataset_num_procs)
359+
mixer = _mixer_init(ctx, output_dir, date_suffix)
346360

347361
if console_output:
348362
logger.info(

tests/test_datamixing.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,31 @@
22
Unit tests for the top-level datamixing module.
33
"""
44

5+
# Standard
6+
from importlib import resources
7+
from unittest.mock import patch
8+
import os
9+
510
# Third Party
611
from datasets import Dataset
712

813
# First Party
9-
from instructlab.sdg.datamixing import _add_extra_contexts_to_samples
14+
from instructlab.sdg.datamixing import DataMixer, Recipe, _add_extra_contexts_to_samples
15+
16+
# We mock out the actual things that use num_procs anyway, but just
17+
# for a consistent value in the tests...
18+
TEST_NUM_PROCS = 4
19+
TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "testdata")
20+
TEST_RECIPE_PATH = os.path.join(TEST_DATA_DIR, "relative_path_recipe.yaml")
21+
TEST_SAMPLES_ABS_PATH = os.path.join(TEST_DATA_DIR, "datasets/samples.jsonl")
22+
23+
24+
def _empty_recipe(self):
25+
return {}
26+
27+
28+
def _noop_sample(dataset, _sampling_size, _num_procs):
29+
return dataset
1030

1131

1232
def _fake_context(msg_id):
@@ -18,6 +38,89 @@ def _fake_context(msg_id):
1838
}
1939

2040

41+
def test_datamixer_can_load_default_recipes():
42+
"""
43+
Test that DataMixer can load default recipe files by pointing
44+
it at a simple set of test recipe files under the testdata/
45+
directory.
46+
"""
47+
date_suffix = "2024-07-25T15_52_10"
48+
prompt = "You are a useful AI assistant."
49+
mixer = DataMixer(
50+
[TEST_DATA_DIR], TEST_DATA_DIR, date_suffix, prompt, TEST_NUM_PROCS
51+
)
52+
assert mixer.knowledge_recipe.datasets[0]["path"] == "test/knowledge.jsonl"
53+
assert mixer.skills_recipe.datasets[0]["path"] == "test/skills.jsonl"
54+
55+
56+
def test_recipe_init_with_empty_params_adds_dataset():
57+
"""
58+
Test that an empty-initialized recipe can add datasets
59+
"""
60+
recipe = Recipe()
61+
recipe.add_dataset("testdata/datasets/samples.jsonl", 1.0)
62+
assert recipe.dataset_added
63+
64+
65+
def test_recipe_init_with_empty_params_loads_abs_dataset():
66+
"""
67+
Test that an empty-initialized recipe can load datasets from
68+
absolute file paths.
69+
"""
70+
recipe = Recipe()
71+
dataset = recipe._load_ds(TEST_SAMPLES_ABS_PATH)
72+
assert dataset is not None
73+
74+
75+
def test_recipe_init_with_empty_params_loads_rel_dataset():
76+
"""
77+
Test that an empty-initialized recipe looks for dataset files relative
78+
to the current working directory (as opposed to blowing up because of
79+
no recipe_path given).
80+
"""
81+
recipe = Recipe()
82+
rel_path = os.path.relpath(TEST_SAMPLES_ABS_PATH)
83+
dataset = recipe._load_ds(rel_path)
84+
assert dataset is not None
85+
86+
87+
@patch.object(Recipe, "_load_recipe", _empty_recipe)
88+
def test_init_with_empty_recipe_files():
89+
"""
90+
Test that we can initialize a Recipe that points to a recipe
91+
file that does not contain one or more of our expected keys, and
92+
that instead of blowing up (like with a KeyError) we just use sane
93+
defaults.
94+
"""
95+
recipe = Recipe(recipe_path=TEST_RECIPE_PATH)
96+
assert len(recipe.datasets) == 0
97+
assert recipe.sys_prompt == ""
98+
99+
100+
@patch("instructlab.sdg.datamixing._sample_ds", _noop_sample)
101+
def test_load_ds_with_relative_jsonl_path():
102+
"""
103+
Test that the _load_ds function can load datasets from jsonl
104+
files referenced with a path relative to the recipe file
105+
"""
106+
recipe = Recipe(recipe_path=TEST_RECIPE_PATH)
107+
dataset = recipe._load_and_sample_datasets(TEST_NUM_PROCS)
108+
assert dataset is not None
109+
110+
111+
@patch("instructlab.sdg.datamixing._sample_ds", _noop_sample)
112+
def test_load_ds_with_absolute_jsonl_path():
113+
"""
114+
Test that the _load_ds function can load datasets from jsonl
115+
files referenced with an absolute dataset path
116+
"""
117+
recipe = Recipe(recipe_path=TEST_RECIPE_PATH)
118+
# Patch an absolute path into our recipe before loading it
119+
recipe.datasets[0]["path"] = TEST_SAMPLES_ABS_PATH
120+
dataset = recipe._load_and_sample_datasets(TEST_NUM_PROCS)
121+
assert dataset is not None
122+
123+
21124
def test_add_extra_contexts_to_samples_with_one_sample():
22125
"""
23126
Test _add_extra_contexts_to_samples doesn't error out when
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"id": "abc123", "messages": [], "metadata": {}}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
datasets:
2+
- path: test/knowledge.jsonl
3+
sampling_size: 1.0
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
datasets:
2+
- path: test/skills.jsonl
3+
sampling_size: 1.0
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
datasets:
2+
- path: datasets/samples.jsonl
3+
sampling_size: 1.0
4+
sys_prompt: I am a reliable AI assistant.

0 commit comments

Comments
 (0)