
Commit d55f622

Add a transformers example (#322)
1 parent d54bf4f commit d55f622

File tree

.github/workflows/transformers.yml
transformers/requirements.txt
transformers/transformers_simple.py

3 files changed: +139 -0 lines changed

.github/workflows/transformers.yml

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
name: transformers

on:
  schedule:
    - cron: '0 15 * * *'
  pull_request:
    paths:
      - 'transformers/**'
      - '.github/workflows/transformers.yml'
  workflow_dispatch:

jobs:
  examples:
    if: (github.event_name == 'schedule' && github.repository == 'optuna/optuna-examples') || (github.event_name != 'schedule')
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11', '3.12']

    steps:
      - uses: actions/checkout@v4
      - name: setup-python${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install (Python)
        run: |
          python -m pip install --upgrade pip
          pip install --progress-bar off -U setuptools
          pip install git+https://github.com/optuna/optuna.git
          python -c 'import optuna'

          pip install -r transformers/requirements.txt
      - name: Run examples
        run: |
          python transformers/transformers_simple.py
        env:
          OMP_NUM_THREADS: 1
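The if: condition gates the nightly cron (15:00 UTC) so it only fires in the upstream optuna/optuna-examples repository; pull requests touching the listed paths and manual workflow_dispatch runs execute regardless. The job can be approximated locally with a small driver like the sketch below (not part of this commit; it assumes pip and network access and is run from the repository root):

import os
import subprocess
import sys

# Mirror the workflow's install step: upgrade pip and setuptools, install the
# development version of Optuna, then the example's requirements.
# (Sketch, not part of this commit.)
for args in (
    ["-m", "pip", "install", "--upgrade", "pip"],
    ["-m", "pip", "install", "--progress-bar", "off", "-U", "setuptools"],
    ["-m", "pip", "install", "git+https://github.com/optuna/optuna.git"],
    ["-m", "pip", "install", "-r", "transformers/requirements.txt"],
):
    subprocess.check_call([sys.executable] + args)

# Run the example with the same environment variable the CI job sets.
env = {**os.environ, "OMP_NUM_THREADS": "1"}
subprocess.check_call([sys.executable, "transformers/transformers_simple.py"], env=env)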

transformers/requirements.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
accelerate
datasets
evaluate
scikit-learn
transformers
transformers/transformers_simple.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
"""
Optuna example for fine-tuning a BERT-based text classification model on the IMDb dataset
with hyperparameter optimization using Optuna. In this example, we fine-tune a lightweight
pre-trained BERT model on a small subset of the IMDb dataset to classify movie reviews as
positive or negative. We optimize the validation accuracy by tuning the learning rate
and batch size. To learn more about transformers' hyperparameter search,
you can check the following documentation:
https://huggingface.co/docs/transformers/en/hpo_train.
"""

from datasets import load_dataset
import evaluate

from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import set_seed
from transformers import Trainer
from transformers import TrainingArguments


set_seed(42)


train_dataset = load_dataset("imdb", split="train").shuffle(seed=42).select(range(1000))
valid_dataset = load_dataset("imdb", split="test").shuffle(seed=42).select(range(500))

model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=512)


tokenized_train = train_dataset.map(tokenize, batched=True).select_columns(
    ["input_ids", "attention_mask", "label"]
)
tokenized_valid = valid_dataset.map(tokenize, batched=True).select_columns(
    ["input_ids", "attention_mask", "label"]
)


metric = evaluate.load("accuracy")


def model_init():
    return AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)


def compute_metrics(eval_pred):
    predictions = eval_pred.predictions.argmax(axis=-1)
    labels = eval_pred.label_ids
    return metric.compute(predictions=predictions, references=labels)


def compute_objective(metrics):
    return metrics["eval_accuracy"]


training_args = TrainingArguments(
    eval_strategy="epoch",
    save_strategy="best",
    load_best_model_at_end=True,
    logging_strategy="epoch",
    report_to="none",
)


trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)


def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [16, 32, 64, 128]
        ),
    }


best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    hp_space=optuna_hp_space,
    n_trials=5,
    compute_objective=compute_objective,
)

print(best_run)
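hyperparameter_search with backend="optuna" returns a transformers BestRun, whose run_id, objective, and hyperparameters fields hold the winning Optuna trial's id, its eval_accuracy, and the sampled values. A common follow-up, sketched below (not part of this commit; it reuses the trainer and best_run objects defined above), is to write the best values back into the training arguments and train once more:

# Retrain with the best hyperparameters found by the search
# (sketch; assumes the `trainer` and `best_run` objects from the example above).
for name, value in best_run.hyperparameters.items():
    setattr(trainer.args, name, value)

trainer.train()
print(trainer.evaluate())  # final validation metrics, including eval_accuracy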
