Commit 15dc34c
Updates to PR from feedback.
1 parent 9ec145f commit 15dc34c

3 files changed: +98 -59 lines changed


contrib/hamilton/contrib/user/skrawcz/fine_tuning/README.md

+7 -7
@@ -1,15 +1,17 @@
 # Purpose of this module
 This module shows you how to fine-tune an LLM model. This code is inspired by this [fine-tuning code](https://github.com/dagster-io/dagster_llm_finetune/tree/main).

-Specifically the code here, from an approach standpoint, shows Supervised Fine-Tuning (SFT) for dialogue. This approach instructs the model to be more
+Specifically the code here, shows Supervised Fine-Tuning (SFT) for dialogue. This approach instructs the model to be more
 useful to directly respond to a question, rather than optimizing over an entire dialogue. SFT is the most common type of fine-tuning,
-as the other two options: Pre-training for Completion, and RLHF, required more to work. Pre-training requires more computational power,
+as the other two options, Pre-training for Completion, and RLHF, required more to work. Pre-training requires more computational power,
 while RLHF requires higher-quality dialogue data.

 This code should work on a regular CPU (in a docker container), which will allow you to test out the code locally without
 any additional setup. This specific approach this code uses is [LoRA](https://arxiv.org/abs/2106.09685) (low-rank adaptation of large language models), which
 means that only a subset of the LLM's parameters are tweaked and prevents over-fitting.

+Note: if you have issues running this on MacOS, reach out, we might be able to help.
+
 ## What is fine-tuning?
 Fine-tuning is when a pre-trained model, in this context a foundational model, is customized using additional data to
 adjust its responses for a specific task. This is a good way to adjust an off-the-shelf, i.e. pretrained model, to provide
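For reference, the LoRA approach the README above describes is typically wired up with the `peft` library: a low-rank adapter is attached to the base model so only a small fraction of parameters train. The sketch below is not part of this commit; the model id and the hyperparameters (`r`, `lora_alpha`, `target_modules`) are illustrative assumptions.

```python
# Minimal LoRA sketch with peft (illustrative values, not taken from this commit).
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM

base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")  # assumed model id
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # seq2seq fine-tuning, as in this module
    r=16,                             # rank of the low-rank update matrices
    lora_alpha=32,                    # scaling factor for the update
    lora_dropout=0.05,
    target_modules=["q", "v"],        # attention projections to adapt (T5-style naming)
)
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # only a small fraction of weights are trainable
```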
@@ -28,7 +30,7 @@ It shows a basic process of:

 a. Loading data and tokenizing it and setting up some tokenization parameters.

-b. Splitting data into training, validation, and inference sets.
+b. Splitting data into training, validation, and hold out sets.

 c. Fine-tuning the model using LoRA.

@@ -63,14 +65,12 @@ You would then pass in as _inputs_ to execution `"data_path"=PATH_TO_THIS_FILE`
 that the transformers library supports for `AutoModelForSeq2SeqLM` models.
 - Run the code.

-Because there's no configuration that changes the shape of the DAG, you can run the code like this:
-
 ```python
 # instantiate the driver with this module however you want
 result = dr.execute(
     [ # some suggested outputs
         "save_best_models",
-        "inference_set_predictions",
+        "hold_out_set_predictions",
         "training_and_validation_set_metrics",
         "finetuned_model_on_validation_set",
     ],
@@ -122,7 +122,7 @@ docker run YOUR_IMAGE_NAME
 - `{"start": "presaved"}` Use this if you want to load an already fine-tuned model and then just eval it.

 # Limitations
-The code here cannot guarantee groundbreaking performance for your specific use case,
+The code here will likely not solve all your LLM troubles,
 but it can show you how to fine-tune an LLM using parameter-efficient techniques such as LoRA.

 This code is currently set up to work with dataset and transformer libraries. It could be modified to work with other libraries.
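For reference, a minimal sketch of driving this module with Hamilton, requesting the renamed `hold_out_set_predictions` output from the README above. This is not part of the commit; the module import path and any inputs beyond `data_path` are assumptions.

```python
# Hedged sketch: build a Hamilton driver for this module and request the README's suggested outputs.
from hamilton import driver

import fine_tuning  # hypothetical import of this module's __init__.py

dr = (
    driver.Builder()
    .with_modules(fine_tuning)
    .with_config({"start": "presaved"})  # per the README: load an already fine-tuned model and eval it
    .build()
)
result = dr.execute(
    [
        "save_best_models",
        "hold_out_set_predictions",
        "training_and_validation_set_metrics",
        "finetuned_model_on_validation_set",
    ],
    inputs={"data_path": "PATH_TO_YOUR_DATA.json"},  # plus any other inputs your configuration requires
)
```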

contrib/hamilton/contrib/user/skrawcz/fine_tuning/__init__.py

+91 -52
@@ -37,7 +37,7 @@
 from hamilton.function_modifiers.configuration import config


-@extract_fields({"train_set": Dataset, "validation_set": Dataset, "inference_set": Dataset})
+@extract_fields({"train_set": Dataset, "validation_set": Dataset, "hold_out_set": Dataset})
 def raw_dataset(
     data_path: str,
     random_state: int = 42,
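For reference, `@extract_fields` exposes each key of the dictionary returned by `raw_dataset` as its own Hamilton node, so downstream functions can depend on `train_set`, `validation_set`, and `hold_out_set` by name. A toy sketch of the pattern (stand-in data, not this repo's loading logic):

```python
# Toy sketch of the @extract_fields pattern used above (illustrative only).
from datasets import Dataset
from hamilton.function_modifiers import extract_fields


@extract_fields({"train_set": Dataset, "validation_set": Dataset, "hold_out_set": Dataset})
def toy_raw_dataset(data_path: str) -> dict:
    """Builds three tiny splits instead of reading `data_path`, purely for illustration."""
    ds = Dataset.from_dict(
        {"input_text": ["q1", "q2", "q3"], "output_text": ["a1", "a2", "a3"]}
    )
    return {
        "train_set": ds.select([0]),
        "validation_set": ds.select([1]),
        "hold_out_set": ds.select([2]),
    }
```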
@@ -78,7 +78,7 @@ def raw_dataset(
     return {
         "train_set": dataset_train,
         "validation_set": dataset_validation,
-        "inference_set": dataset_inference,
+        "hold_out_set": dataset_inference,
     }

@@ -87,7 +87,7 @@ def tokenizer(
 ) -> PreTrainedTokenizerBase:
     """The tokenizer we're going to use to tokenize text.

-    :param model_id:
+    :param model_id: the model id that corresponds to what huggingface knows about.
     :return: the tokenizer to use.
     """
     tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -97,11 +97,11 @@
 def tokenized_inputs(
     train_set: Dataset,
     validation_set: Dataset,
-    inference_set: Dataset,
+    hold_out_set: Dataset,
     tokenizer: PreTrainedTokenizerBase,
 ) -> DatasetType:
-    """Tokenizes"""
-    return concatenate_datasets([train_set, validation_set, inference_set]).map(
+    """Tokenizes the inputs from all the datasets and creates a single data set with them."""
+    return concatenate_datasets([train_set, validation_set, hold_out_set]).map(
         lambda x: tokenizer(x["input_text"], truncation=True),
         batched=True,
         remove_columns=["input_text", "output_text"],
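For reference, the concatenate-then-map pattern above tokenizes every split at once so that maximum sequence lengths can be computed over all of the data. A self-contained sketch with toy splits; the `t5-small` tokenizer is an assumption, not a value from this commit:

```python
# Hedged sketch of the concatenate-and-tokenize pattern (toy data, assumed tokenizer).
from datasets import Dataset, concatenate_datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
splits = [
    Dataset.from_dict({"input_text": ["what is hamilton?"], "output_text": ["a dataflow library"]}),
    Dataset.from_dict({"input_text": ["what is lora?"], "output_text": ["low-rank adaptation"]}),
]
tokenized = concatenate_datasets(splits).map(
    lambda x: tokenizer(x["input_text"], truncation=True),
    batched=True,
    remove_columns=["input_text", "output_text"],
)
input_lengths = [len(ids) for ids in tokenized["input_ids"]]  # the kind of list max_source_lengths consumes
```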
@@ -121,10 +121,11 @@ def max_source_lengths(input_lengths: list[int]) -> int:
 def tokenized_targets(
     train_set: Dataset,
     validation_set: Dataset,
-    inference_set: Dataset,
+    hold_out_set: Dataset,
     tokenizer: PreTrainedTokenizerBase,
 ) -> DatasetType:
-    return concatenate_datasets([train_set, validation_set, inference_set]).map(
+    """Tokenizes the outputs, i.e. target responses, from all the datasets and creates a single data set with them."""
+    return concatenate_datasets([train_set, validation_set, hold_out_set]).map(
         lambda x: tokenizer(x["output_text"], truncation=True),
         batched=True,
         remove_columns=["input_text", "output_text"],
@@ -215,14 +216,14 @@ def tokenized_validation(
     )


-def tokenized_inference(
-    inference_set: Dataset,
+def tokenized_hold_out(
+    hold_out_set: Dataset,
     max_source_lengths: int,
     max_target_lengths: int,
     tokenizer: PreTrainedTokenizerBase,
 ) -> Dataset:
     """Tokenizes the inference set."""
-    return inference_set.map(
+    return hold_out_set.map(
         partial(
             _preprocess_function,
             max_source_lengths=max_source_lengths,
@@ -237,15 +238,16 @@ def tokenized_inference(
 def saved_datasets(
     tokenized_train: Dataset,
     tokenized_validation: Dataset,
-    tokenized_inference: Dataset,
+    tokenized_hold_out: Dataset,
 ) -> dict:
+    """Function to save tokenized datasets using the datasets library."""
     tokenized_train.save_to_disk("data/train")
     tokenized_validation.save_to_disk("data/validation")
-    tokenized_inference.save_to_disk("data/inference")
+    tokenized_hold_out.save_to_disk("data/inference")
     return {
         "tokenized_train": "data/train",
         "tokenized_validation": "data/validation",
-        "tokenized_inference": "data/inference",
+        "tokenized_hold_out": "data/inference",
     }

@@ -338,14 +340,18 @@ def training_args(
     gradient_accumulation_steps: int = 1,
     num_train_epochs: int = 2,
 ) -> Seq2SeqTrainingArguments:
-    """The arguments to use for fine-tuning.
-
-    :param peft_model_id:
-    :param per_device_eval_batch_size:
-    :param per_device_train_batch_size:
-    :param gradient_accumulation_steps:
-    :param num_train_epochs:
-    :return: Seq2SeqTrainingArguments
+    """Constructs the arguments to use for fine-tuning.
+
+    :param peft_model_id: The ID we're running everything under here.
+    :param per_device_eval_batch_size: The batch size for evaluation. This is the number of samples that will be fed to
+        the model at once during evaluation.
+    :param per_device_train_batch_size: The batch size for training. This is the number of samples that will be fed to
+        the model at once during training.
+    :param gradient_accumulation_steps: The number of steps to accumulate gradients before performing an optimization
+        step. This can be useful when training on multiple GPUs to effectively increase the batch size.
+    :param num_train_epochs: The number of epochs to train the model. An epoch is one pass through the entire training
+        dataset.
+    :return: Seq2SeqTrainingArguments object, the arguments for training such as batch size, learning rate, etc.
     """
     training_args = Seq2SeqTrainingArguments(
         do_train=True,
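For reference, the effective batch size implied by the arguments documented above is the per-device batch size times the gradient accumulation steps times the number of devices. A sketch with illustrative numbers; `output_dir` and `learning_rate` are assumptions, not values from this commit:

```python
# Hedged sketch: effective batch size arithmetic and the corresponding training arguments.
from transformers import Seq2SeqTrainingArguments

per_device_train_batch_size = 8
gradient_accumulation_steps = 4
num_devices = 1  # e.g. a single CPU or GPU
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices  # 32

args = Seq2SeqTrainingArguments(
    output_dir="lora-finetune-example",  # illustrative output path
    do_train=True,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=2,
    learning_rate=1e-3,  # illustrative value
)
```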
@@ -374,14 +380,16 @@ def trainer(
     tokenized_train: Dataset,
     tokenized_validation: Dataset,
 ) -> Seq2SeqTrainer:
-    """The trainer we'll use for fine-tuning.
-
-    :param base_model:
-    :param training_args:
-    :param data_collator:
-    :param tokenized_train:
-    :param tokenized_test:
-    :return: Seq2SeqTrainer
+    """Constructs a Seq2SeqTrainer for fine-tuning.
+
+    :param base_model: torch.nn.Module object, the base model to be fine-tuned.
+    :param training_args: Seq2SeqTrainingArguments object, the arguments for training such as batch size, learning rate,
+        etc.
+    :param data_collator: DataCollatorForSeq2Seq object, the data collator that will be used to form batches for
+        training.
+    :param tokenized_train: Dataset object, the training set that has been tokenized.
+    :param tokenized_validation: Dataset object, the validation set that has been tokenized.
+    :return: Seq2SeqTrainer object, the trainer that will be used for fine-tuning.
     """
     trainer = Seq2SeqTrainer(
         model=base_model,
@@ -407,7 +415,7 @@ def fitted_and_evaluated_trainer(trainer: Seq2SeqTrainer, device: torch.device)

     :param trainer: the trainer to use.
     :param device: device to place the device on to.
-    :return:
+    :return: dictionary with the trainer, the evaluation metrics, and the fine-tuned model.
     """
     trainer.train()
     eval_metrics = trainer.evaluate()
@@ -472,12 +480,15 @@ def metric() -> evaluate.EvaluationModule:
 def _evaluate_peft_model(sample, model, tokenizer, device, max_target_length=512) -> tuple:
     """Helper function to evaluate the model on a sample.

-    :param sample:
-    :param model:
-    :param tokenizer:
-    :param device:
-    :param max_target_length:
-    :return: the prediction, and the ground truth label.
+    :param sample: The sample on which the model is to be evaluated. This is typically a single instance from the
+        dataset.
+    :param model: The model to be used for evaluation. This is typically the fine-tuned model.
+    :param tokenizer: The tokenizer that was used during the preprocessing of the data.
+        This will be used to decode the model's predictions from token ids back to text.
+    :param device: The device where the computations will be performed. This is typically a CPU or a specific GPU.
+    :param max_target_length: The maximum length of the target sequence. If the predicted sequence is longer than this,
+        it will be truncated to this length.
+    :return: A tuple containing the prediction and the ground truth label.
     """
     # generate summary
     # outputs = model.generate(
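For reference, the generate-and-decode step this helper's docstring describes looks roughly like the sketch below. The sampling settings (`do_sample`, `top_p`) and the -100 label-masking convention are assumptions based on common Hugging Face practice, not lines from this commit.

```python
# Hedged sketch of evaluating one sample: generate a prediction, then decode prediction and label.
import torch


def evaluate_sample_sketch(sample, model, tokenizer, max_target_length: int = 512) -> tuple:
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=sample["input_ids"].unsqueeze(0),
            do_sample=True,          # illustrative decoding settings
            top_p=0.9,
            max_new_tokens=max_target_length,
        )
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # labels set to -100 are ignored by the loss; drop them before decoding the reference text
    label_ids = [int(t) for t in sample["labels"] if int(t) != -100]
    reference = tokenizer.decode(label_ids, skip_special_tokens=True)
    return prediction, reference
```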
@@ -508,14 +519,29 @@ def finetuned_model_on_validation_set(
     tokenizer: PreTrainedTokenizerBase,
     device: torch.device,
 ) -> dict[str, list]:
-    """Evaluates our fine-tuned model on the validation set.
+    """
+    Evaluates the fine-tuned model on the validation set.

-    If you run this on a large model this can take a while.
-    :param tokenized_validation:
-    :param finetuned_model:
-    :param tokenizer:
-    :param device:
-    :return:
+    This function iterates over the validation set and generates predictions for each sample.
+    The predictions and the ground truth labels are then returned as a dictionary.
+
+    Note: If you run this on a large model this can take a while.
+
+    :param tokenized_validation: Dataset object, the validation set that has been tokenized.
+        This is the data that the model has seen during training and will be used for testing the model's performance.
+
+    :param finetuned_model: torch.nn.Module object, the fine-tuned model that will be used to generate predictions
+        on the validation set. This model has been trained on the training set.
+
+    :param tokenizer: PreTrainedTokenizerBase object, the tokenizer that was used during the preprocessing
+        of the data. This will be used to decode the model's predictions from token ids back to text.
+
+    :param device: torch.device object, the device where the computations will be performed.
+        This is typically a CPU or a specific GPU.
+
+    :return: Dictionary containing two lists - 'validation_predictions' and 'validation_references'.
+        'validation_predictions' is a list of model predictions for each sample in the validation set.
+        'validation_references' is a list of ground truth labels for each sample in the validation set.
     """
     predictions, references = [], []
     with torch.inference_mode():
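For reference, the iterate-and-collect loop described in this docstring pairs each prediction with its reference via the module's `_evaluate_peft_model` helper. The sketch assumes the names from the function signature above (`tokenized_validation`, `finetuned_model`, `tokenizer`, `device`) are in scope.

```python
# Hedged sketch of the loop body: collect predictions and references over the validation set.
import torch

predictions, references = [], []
with torch.inference_mode():
    for sample in tokenized_validation.with_format("torch", device=device):
        pred, ref = _evaluate_peft_model(sample, finetuned_model, tokenizer, device)
        predictions.append(pred)
        references.append(ref)
result = {"validation_predictions": predictions, "validation_references": references}
```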
@@ -540,6 +566,14 @@ def validation_set_metrics(
     validation_references: list,
     metric: evaluate.EvaluationModule,
 ) -> dict:
+    """Computes the Rouge metric on the validation set.
+
+    :param validation_predictions: List of model predictions for each sample in the validation set.
+    :param validation_references: List of ground truth labels for each sample in the validation set.
+    :param metric: EvaluationModule object, the metric used to evaluate the model's performance.
+        In this case, it's Rouge.
+    :return: Dictionary containing the Rouge scores for the model's performance on the validation set.
+    """
     rogue = metric.compute(
         predictions=validation_predictions,
         references=validation_references,
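For reference, the Rouge computation named in this docstring uses the `evaluate` library; a tiny self-contained sketch with toy strings (the predictions and references here are made up for illustration):

```python
# Hedged sketch: computing Rouge with the evaluate library on toy predictions/references.
import evaluate

rouge = evaluate.load("rouge")
scores = rouge.compute(
    predictions=["the fine-tuned model answered the question"],
    references=["the fine-tuned model answers the question"],
    use_stemmer=True,
)
print(scores)  # dictionary with rouge1, rouge2, rougeL, rougeLsum scores
```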
@@ -561,24 +595,29 @@ def training_and_validation_set_metrics(
     return {**fit_trainer_eval_metrics, **validation_set_metrics}


-def inference_set_predictions(
-    tokenized_inference: Dataset,
+def hold_out_set_predictions(
+    tokenized_hold_out: Dataset,
     finetuned_model: torch.nn.Module,
     tokenizer: PreTrainedTokenizerBase,
     device: torch.device,
 ) -> list[tuple[str, str]]:
     """Runs model on the inference set.

-    :param tokenized_inference:
-    :param finetuned_model:
-    :param tokenizer:
-    :param device:
+    :param tokenized_hold_out: Dataset object, the hold-out set that has been tokenized.This is the data that the
+        model has not seen during training or validation and will be used for testing the model's performance.
+    :param finetuned_model: torch.nn.Module object, the fine-tuned model that will be used to generate predictions
+        on the hold-out set.
+        This model has been trained on the training set and validated on the validation set.
+    :param tokenizer: PreTrainedTokenizerBase object, the tokenizer that was used during the preprocessing
+        of the data. This will be used to decode the model's predictions from token ids back to text.
+    :param device: torch.device object, the device where the computations will be performed.
+        This is typically a CPU or a specific GPU.
     :return: generate responses for the inference set
     """
     predictions = []
     questions = []
     with torch.inference_mode():
-        for sample in tokenized_inference.with_format("torch", device=device):
+        for sample in tokenized_hold_out.with_format("torch", device=device):
             max_target_length = 512
             outputs = finetuned_model.generate(
                 input_ids=sample["input_ids"].unsqueeze(0).cpu(),
636675
result = dr.execute(
637676
[
638677
"save_best_models",
639-
"inference_set_predictions",
678+
"hold_out_set_predictions",
640679
"training_and_validation_set_metrics",
641680
"finetuned_model_on_validation_set",
642681
],

0 commit comments