Open · wants to merge 1 commit into main
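For context (not part of the diff): a minimal usage sketch of the new option, assuming the constructor signature introduced below. The student program, metric, trainset, and the bootstrapped_data.jsonl path are illustrative placeholders, not something this PR ships.

import dspy
from dspy.teleprompt import BootstrapFinetune

# Illustrative program, metric, and trainset; only bootstrapped_data_path is new in this PR.
student = dspy.ChainOfThought("question -> answer")
student.set_lm(dspy.LM("openai/gpt-4o-mini"))

trainset = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]

def exact_match(example, prediction, trace=None):
    return example.answer == prediction.answer

optimizer = BootstrapFinetune(
    metric=exact_match,
    bootstrapped_data_path="bootstrapped_data.jsonl",  # reuse pre-saved data instead of re-bootstrapping
)
compiled_student = optimizer.compile(student, trainset=trainset)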
dspy/teleprompt/bootstrap_finetune.py (29 changes: 24 additions & 5 deletions)
@@ -44,6 +44,7 @@ def __init__(
         adapter: Optional[Union[Adapter, Dict[LM, Adapter]]] = None,
         exclude_demos: bool = False,
         num_threads: Optional[int] = None,
+        bootstrapped_data_path: Optional[str] = None,
     ):
         # TODO(feature): Inputs train_kwargs (a dict with string keys) and
         # adapter (Adapter) can depend on the LM they are used with. We are
@@ -58,6 +59,7 @@ def __init__(
         self.adapter: Dict[LM, Adapter] = self.convert_to_lm_dict(adapter)
         self.exclude_demos = exclude_demos
         self.num_threads = num_threads
+        self.bootstrapped_data_path = bootstrapped_data_path
 
     def compile(
         self, student: Program, trainset: List[Example], teacher: Optional[Union[Program, List[Program]]] = None
@@ -74,18 +76,26 @@ def compile(
         teachers = teacher if isinstance(teacher, list) else [teacher]
         teachers = [prepare_teacher(student, t) for t in teachers]
         num_threads = self.num_threads or dspy.settings.num_threads
-        for t in teachers:
-            trace_data += bootstrap_trace_data(program=t, dataset=trainset, metric=self.metric, num_threads=num_threads)
+        if self.bootstrapped_data_path:
+            logger.info(f"Loading bootstrapped training data from: {self.bootstrapped_data_path}")
+            train_data = self.load_data(self.bootstrapped_data_path)
+        else:
+            for t in teachers:
+                trace_data += bootstrap_trace_data(program=t, dataset=trainset, metric=self.metric, num_threads=num_threads)
 
         logger.info("Preparing the train data...")
         key_to_data = {}
         for pred_ind, pred in enumerate(student.predictors()):
             data_pred_ind = None if self.multitask else pred_ind
             training_key = (pred.lm, data_pred_ind)
             if training_key not in key_to_data:
-                train_data, data_format = self._prepare_finetune_data(
-                    trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
-                )
+                if self.bootstrapped_data_path:
+                    adapter = self.adapter[pred.lm] or settings.adapter or ChatAdapter()
+                    data_format = infer_data_format(adapter)
+                else:
+                    train_data, data_format = self._prepare_finetune_data(
+                        trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
+                    )
                 logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
                 finetune_kwargs = {
                     "lm": pred.lm,
@@ -128,6 +138,15 @@ def compile(
         student._compiled = True
         return student
 
+    @staticmethod
+    def load_data(path: str) -> List[Dict[str, Any]]:
+        import json
+        data = []
+        with open(path, "r") as f:
+            for line in f:
+                data.append(json.loads(line))
+        return data
+
     @staticmethod
     def finetune_lms(finetune_dict) -> Dict[Any, LM]:
         num_jobs = len(finetune_dict)
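A note on the expected file format (my reading of the diff, not stated explicitly in the PR): load_data treats the path as JSON Lines, one JSON object per line, and returns the parsed records as train_data without further transformation, so each record must already match the data_format inferred from the adapter (for a ChatAdapter, typically an OpenAI-style chat record). A hypothetical round-trip sketch:

import json

from dspy.teleprompt import BootstrapFinetune

# Hypothetical chat-format records; the real schema must match whatever the
# fine-tuning provider expects for the inferred data_format.
records = [
    {
        "messages": [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
        ]
    },
]

# Write one JSON object per line (JSON Lines).
with open("bootstrapped_data.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# Read it back with the helper added in this PR.
loaded = BootstrapFinetune.load_data("bootstrapped_data.jsonl")
assert loaded == records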