Skip to content
This repository was archived by the owner on Apr 30, 2026. It is now read-only.

Commit 20ac841

Browse files
authored
Merge pull request #231 from danmcp/checkpointleafnodes
Separate checkpoints by leaf nodes
2 parents 2dcbec7 + ba9da90 commit 20ac841

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

src/instructlab/sdg/generate_data.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -398,7 +398,7 @@ def generate_data(
398 398   logger.debug("Samples: %s", samples)
399 399   ds = Dataset.from_list(samples)
400 400   logger.debug("Dataset: %s", ds)
401     - new_generated_data = pipe.generate(ds)
    401 + new_generated_data = pipe.generate(ds, leaf_node_path)
402 402   if len(new_generated_data) == 0:
403 403   raise EmptyDatasetError(
404 404   "Pipeline stopped: Empty dataset after running pipe"

src/instructlab/sdg/pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -131,15 +131,21 @@ def from_file(cls, ctx, pipeline_yaml):
131 131   pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
132 132   return cls(ctx, pipeline_yaml, *_parse_pipeline_config_file(pipeline_yaml))
133 133
134     - def generate(self, dataset) -> Dataset:
    134 + def generate(self, dataset, checkpoint_name=None) -> Dataset:
135 135   """
136 136   Generate the dataset by running the pipeline steps.
137 137   dataset: the input dataset
    138 + checkpoint_name: unique subdir name for the checkpoint within checkpoint_dir
138 139   """
139 140
140 141   # The checkpointer allows us to resume from where we left off
141 142   # Saving the output of pipe instances along the way
142     - checkpointer = Checkpointer(self.ctx.checkpoint_dir, self.ctx.save_freq)
    143 + checkpoint_dir = None
    144 + if self.ctx.checkpoint_dir is not None and checkpoint_name is not None:
    145 + # Separate checkpoints with sub directories
    146 + checkpoint_dir = os.path.join(self.ctx.checkpoint_dir, checkpoint_name)
    147 +
    148 + checkpointer = Checkpointer(checkpoint_dir, self.ctx.save_freq)
143 149   dataset, pre_generated_data = checkpointer.load(dataset)
144 150
145 151   # If not batching, simply delegate to _generate_single

0 commit comments

Comments
 (0)