Skip to content

Commit 3966113

Browse files
authored
Fix errors in train.py & modelling.py; add ability to track multiple metrics at once (#24)
* Use updated evaluate.load for metrics, add ability to compute multiple metrics at once Signed-off-by: ShreyBiswas <[email protected]> * Fix ignore_keys_for_eval throwing error due to being a set, not a list Signed-off-by: ShreyBiswas <[email protected]> * Add **kwargs to the FastFit.forward() call Forward may get called with different arguments depending on the model; extra arguments like token_type_ids could cause it to crash. This is a simpler way to ignore those. Signed-off-by: ShreyBiswas <[email protected]> * add trust_remote_code so more models can be used Signed-off-by: ShreyBiswas <[email protected]> * adding evaluate to requirements Signed-off-by: Shrey Biswas <[email protected]> --------- Signed-off-by: ShreyBiswas <[email protected]> Signed-off-by: Shrey Biswas <[email protected]>
1 parent eb03b15 commit 3966113

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

fastfit/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def mask_tokens(self, inputs, special_tokens_mask=None):
835835

836836

837837
class FastFit(FastFitTrainable):
838-
def forward(self, input_ids, attention_mask, labels=None):
838+
def forward(self, input_ids, attention_mask, labels=None, **kwargs):
839839
return SequenceClassifierOutput(
840840
logits=self.inference_forward(input_ids, attention_mask),
841841
)

fastfit/train.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
import torch
1515
import datasets
1616
import numpy as np
17-
from datasets import load_dataset, load_metric
17+
from datasets import load_dataset
18+
from evaluate import load
1819

1920
import transformers
2021
from transformers import (
@@ -622,6 +623,7 @@ def set_model(self):
622623
else:
623624
config = AutoConfig.from_pretrained(
624625
pretrained_model_name_or_path=self.model_args.model_name_or_path,
626+
trust_remote_code=True,
625627
)
626628
config = FastFitConfig.from_encoder_config(
627629
config,
@@ -872,7 +874,21 @@ def preprocess_function(examples):
872874
)
873875

874876
def set_trainer(self):
875-
metric = load_metric(self.data_args.metric_name, experiment_id=uuid.uuid4())
877+
878+
if type(self.data_args.metric_name) == str: # single metric name
879+
metrics = [load(self.data_args.metric_name, experiment_id=uuid.uuid4())]
880+
elif type(self.data_args.metric_name) == list: # compute multiple metrics
881+
metrics = []
882+
for metric in self.data_args.metric_name:
883+
try:
884+
metrics.append(
885+
load(metric, experiment_id=uuid.uuid4())
886+
)
887+
except:
888+
logger.error(f"Metric {metric} not found. Skipping...")
889+
continue
890+
891+
876892

877893
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
878894
# predictions and label_ids field) and has to return a dictionary string to float.
@@ -886,7 +902,17 @@ def compute_metrics(p: EvalPrediction):
886902
else np.argmax(predictions, axis=1)
887903
)
888904
references = p.label_ids
889-
return metric.compute(predictions=predictions, references=references)
905+
906+
results = {}
907+
908+
for metric in metrics:
909+
if metric.name != 'accuracy':
910+
results.update(metric.compute(predictions=predictions, references=references,average='macro'))
911+
else:
912+
results.update(metric.compute(predictions=predictions, references=references))
913+
914+
return results
915+
890916

891917
# Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
892918
# we already did the padding.
@@ -966,7 +992,7 @@ def train(self):
966992
if self.training_args.do_train:
967993
train_result = self.trainer.train(
968994
resume_from_checkpoint=self.checkpoint,
969-
ignore_keys_for_eval={"doc_input_ids", "doc_attention_mask", "labels"},
995+
ignore_keys_for_eval=list({"doc_input_ids", "doc_attention_mask", "labels"}),
970996
)
971997
metrics = train_result.metrics
972998
max_train_samples = (

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
torch
22
transformers[torch]
33
scikit-learn
4-
datasets
4+
datasets
5+
evaluate

0 commit comments

Comments
 (0)