Skip to content

Commit 8c8913b

Browse files
committed
Update
Signed-off-by: Elron Bandel <elron.bandel@ibm.com>
1 parent 75b05eb commit 8c8913b

File tree

3 files changed

+46
-41
lines changed

3 files changed

+46
-41
lines changed

README.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ dataset = load_dataset("mteb/banking77")
6161
dataset["validation"] = dataset["test"]
6262

6363
# Down sample the train data for 5-shot training
64-
dataset["train"] = sample_dataset(dataset["train"], label_column="label", num_samples_per_label=5)
64+
dataset["train"] = sample_dataset(dataset["train"], label_column="label_text", num_samples_per_label=5)
6565

6666
trainer = FastFitTrainer(
6767
model_name_or_path="roberta-base",
68-
overwrite_output_dir=True,
69-
label_column_name="label",
68+
label_column_name="label_text",
7069
text_column_name="text",
7170
num_train_epochs=40,
7271
per_device_train_batch_size=32,
@@ -82,8 +81,13 @@ trainer = FastFitTrainer(
8281

8382
model = trainer.train()
8483
results = trainer.evaluate()
85-
test_results = trainer.test()
8684

85+
print("Accuracy: {:.1f}".format(results["eval_accuracy"] * 100))
86+
```
87+
Output: `Accuracy: 82.4`
88+
89+
Then the model can be saved:
90+
```python
8791
model.save_pretrained("fast-fit")
8892
```
8993
Then you can use the model for inference
@@ -185,4 +189,4 @@ print(classifier("I love this package!"))
185189
- `--no_logging_nan_inf_filter`: Filter nan and inf losses for logging. (default: False)
186190
- `--save_strategy {no,steps,epoch}`: The checkpoint save strategy to use. (default: steps)
187191
- `--save_steps SAVE_STEPS`: Save checkpoint every X updates steps. (default: 500)
188-
- `--save_total_limit SAVE_TOTAL_LIMIT`
192+

fastfit/train.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
import json
66
import math
7+
import tempfile
78

89
from dataclasses import dataclass, field
910
from collections import Counter, defaultdict
@@ -21,7 +22,6 @@
2122
DataCollatorWithPadding,
2223
EvalPrediction,
2324
HfArgumentParser,
24-
PretrainedConfig,
2525
Trainer,
2626
TrainingArguments,
2727
default_data_collator,
@@ -328,7 +328,7 @@ def has_custom_dataset(self):
328328
return (
329329
self._dataset is not None
330330
or self._train_dataset is not None
331-
or self._test_datast is not None
331+
or self._test_dataset is not None
332332
or self._validation_dataset is not None
333333
)
334334

@@ -351,9 +351,20 @@ def set_args(self, args_dict=None):
351351
training_args,
352352
) = parser.parse_args_into_dataclasses()
353353
else:
354-
args_dict["do_train"] = True
355-
args_dict["do_eval"] = True
356-
args_dict["do_predict"] = True
354+
if self._has_train_dataset:
355+
args_dict["do_train"] = True
356+
else:
357+
raise ValueError(
358+
"You must initialize the FastFitTrainer with train_dataset or train_file."
359+
)
360+
if self._has_validation_dataset:
361+
args_dict["do_eval"] = True
362+
if self._has_test_dataset:
363+
args_dict["do_predict"] = True
364+
if "output_dir" not in args_dict:
365+
args_dict["save_strategy"] = "no"
366+
args_dict["output_dir"] = tempfile.gettempdir()
367+
args_dict["overwrite_output_dir"] = True
357368
if self.has_custom_dataset():
358369
args_dict["task_name"] = "custom"
359370
config_args, model_args, data_args, training_args = parser.parse_dict(
@@ -457,7 +468,7 @@ def set_data(
457468

458469
# Get the test dataset: you can provide your own CSV/JSON test file (see below)
459470
# when you use `do_predict` without specifying a GLUE benchmark task.
460-
if self.training_args.do_predict:
471+
if self.training_args.do_predict and self.is_command_line_mode:
461472
if self.data_args.test_file is not None:
462473
train_extension = self.data_args.train_file.split(".")[-1]
463474
test_extension = self.data_args.test_file.split(".")[-1]
@@ -663,28 +674,7 @@ def preprocess_data(self):
663674

664675
# Some models have set the order of the labels to use, so let's make sure we do use it.
665676
self.label_to_id = None
666-
if (
667-
self.model.config.label2id
668-
!= PretrainedConfig(num_labels=self.num_labels).label2id
669-
and self.data_args.task_name is not None
670-
and not self.is_regression
671-
):
672-
# Some have all caps in their config, some don't.
673-
self.label_name_to_id = {
674-
k.lower(): v for k, v in self.model.config.label2id.items()
675-
}
676-
if list(sorted(self.label_name_to_id.keys())) == list(sorted(self.labels)):
677-
self.label_to_id = {
678-
i: int(self.label_name_to_id[self.labels[i]])
679-
for i in range(self.num_labels)
680-
}
681-
else:
682-
logger.warning(
683-
"Your model seems to have been trained with labels, but they don't match the dataset: ",
684-
f"model labels: {list(sorted(self.label_name_to_id.keys()))}, dataset labels: {list(sorted(self.labels))}."
685-
"\nIgnoring the model labels as a result.",
686-
)
687-
elif self.data_args.task_name is None and not self.is_regression:
677+
if not self.is_regression:
688678
self.label_to_id = {v: i for i, v in enumerate(self.labels)}
689679

690680
if self.label_to_id is not None:
@@ -936,12 +926,28 @@ def __init__(
936926
train_dataset=None,
937927
validation_dataset=None,
938928
test_dataset=None,
929+
is_command_line_mode=False,
939930
**kwargs,
940931
):
932+
self._has_train_dataset = train_dataset is not None or "train_file" in kwargs
933+
if dataset is not None:
934+
self._has_train_dataset = "train" in dataset
935+
936+
self._has_validation_dataset = (
937+
validation_dataset is not None or "validation_file" in kwargs
938+
)
939+
if dataset is not None:
940+
self._has_validation_dataset = "validation" in dataset
941+
942+
self._has_test_dataset = test_dataset is not None or "test_file" in kwargs
943+
if dataset is not None:
944+
self._has_test_dataset = "test" in dataset
945+
941946
self._dataset = dataset
942947
self._train_dataset = train_dataset
943948
self._validation_dataset = validation_dataset
944949
self._test_dataset = test_dataset
950+
self.is_command_line_mode = is_command_line_mode
945951
self.set_args(kwargs)
946952
self.set_logger()
947953
self.set_last_checkpoint()
@@ -1049,20 +1055,15 @@ def test(self):
10491055
def push_to_hub(self):
10501056
kwargs = {
10511057
"finetuned_from": self.model_args.model_name_or_path,
1052-
"tasks": "text-classification",
1058+
"tasks": self.data_args.task_name,
10531059
}
1054-
if self.data_args.task_name is not None:
1055-
kwargs["language"] = "en"
1056-
kwargs["dataset_tags"] = "glue"
1057-
kwargs["dataset_args"] = self.data_args.task_name
1058-
kwargs["dataset"] = f"GLUE {self.data_args.task_name.upper()}"
10591060

10601061
if self.training_args.push_to_hub:
10611062
self.trainer.push_to_hub(**kwargs)
10621063

10631064

10641065
def main():
1065-
trainer = FastFitTrainer()
1066+
trainer = FastFitTrainer(is_command_line_mode=True)
10661067
trainer.train()
10671068
trainer.evaluate()
10681069
trainer.test()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name="fast-fit",
9-
version="1.0.1",
9+
version="1.1.0",
1010
description="Fast and effective approach for few shot with many classes",
1111
long_description=open("README.md").read(),
1212
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)