
Commit 25719fd (parent 5c1c59a)

chore: add compute metrics

3 files changed (+78, -12 lines)

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies = [
     "gradio>=5.29.0",
     "huggingface-hub>=0.31.1",
     "matplotlib>=3.10.1",
+    "nltk>=3.9.1",
     "notebook>=7.4.2",
     "openai>=1.77.0",
     "pillow>=11.2.1",

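The nltk dependency is added for nltk.metrics.edit_distance, the Levenshtein distance used by the new compute_metrics cell in train.ipynb. A minimal illustration of what that function returns (the example strings are arbitrary):

from nltk.metrics import edit_distance

# "kitten" -> "sitting" requires three single-character edits.
print(edit_distance("kitten", "sitting"))  # 3
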
train.ipynb

Lines changed: 51 additions & 12 deletions
@@ -70,8 +70,8 @@
     "\n",
     "DATASETS_REPO_ID = \"ryanlinjui/menu-zh-TW\" # set your dataset repo id for training\n",
     "PRETRAINED_MODEL_REPO_ID = \"naver-clova-ix/donut-base\" # set your pretrained model repo id for fine-tuning\n",
-    "TASK_PROMPT_NAME = \"<s_menu>\" # set your task prompt name for training\n",
-    "MAX_LENGTH = 768 # set your max length for maximum output length\n",
+    "TASK_PROMPT_NAME = \"<s_menu-text-detection>\" # set your task prompt name for training\n",
+    "MAX_LENGTH = 1024 # set your max length for maximum output length, max to 1536 for donut-base\n",
     "IMAGE_SIZE = [1280, 960] # set your image size for training\n",
     "\n",
     "raw_datasets = load_dataset(DATASETS_REPO_ID)\n",
@@ -97,11 +97,13 @@
     "    annotation_column=\"menu\",\n",
     "    task_start_token=TASK_PROMPT_NAME,\n",
     "    prompt_end_token=TASK_PROMPT_NAME,\n",
+    "    max_length=MAX_LENGTH,\n",
     "    train_split=0.8,\n",
     "    validation_split=0.1,\n",
     "    test_split=0.1,\n",
-    "    sort_json_key=True,\n",
-    "    seed=42\n",
+    "    sort_json_key=False,\n",
+    "    seed=42,\n",
+    "    shuffle=False\n",
     ")\n",
     "\n",
     "# Model: load the pretrained model and set the config.\n",
@@ -124,12 +126,17 @@
     "metadata": {},
     "outputs": [],
     "source": [
+    "from functools import reduce\n",
+    "\n",
     "import torch\n",
+    "import numpy as np\n",
+    "from nltk.metrics import edit_distance\n",
+    "from transformers.trainer_utils import EvalPrediction\n",
     "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
     "\n",
     "HUGGINGFACE_MODEL_ID = \"ryanlinjui/donut-base-finetuned-menu\" # set your huggingface model repo id for saving / pushing to the hub\n",
     "EPOCHS = 100 # set your training epochs\n",
-    "TRAIN_BATCH_SIZE = 4 # set your training batch size\n",
+    "TRAIN_BATCH_SIZE = 3 # set your training batch size\n",
     "\n",
     "device = (\n",
     "    \"cuda\"\n",
@@ -139,16 +146,46 @@
     "print(f\"Using {device} device\")\n",
     "model.to(device)\n",
     "\n",
+    "train_datasets = datasets[\"train\"]\n",
+    "validation_datasets = datasets[\"validation\"]\n",
+    "filtered_tokens = [\n",
+    "    processor.tokenizer.bos_token,\n",
+    "    processor.tokenizer.eos_token,\n",
+    "    processor.tokenizer.pad_token,\n",
+    "    processor.tokenizer.unk_token,\n",
+    "]\n",
+    "def compute_metrics(eval_pred: EvalPrediction) -> dict:\n",
+    "    decoded_preds = processor.tokenizer.batch_decode(eval_pred.predictions, skip_special_tokens=False)\n",
+    "\n",
+    "    normed_eds = []\n",
+    "    for idx, pred in enumerate(decoded_preds):\n",
+    "        prediction_sequence = reduce(lambda s, t: s.replace(t, \"\"), filtered_tokens, pred)\n",
+    "        target_sequence = reduce(lambda s, t: s.replace(t, \"\"), filtered_tokens, validation_datasets[idx][\"target_sequence\"])\n",
+    "        ed = edit_distance(prediction_sequence, target_sequence) / max(len(prediction_sequence), len(target_sequence))\n",
+    "        normed_eds.append(ed)\n",
+    "\n",
+    "        print(f\"[Sample {idx}]\")\n",
+    "        print(f\"  Prediction: {prediction_sequence}\")\n",
+    "        print(f\"  Target: {target_sequence}\")\n",
+    "        print(f\"  Normalized Edit Distance: {ed:.4f}\")\n",
+    "        print(\"-\" * 40)\n",
+    "\n",
+    "    return {\"normed_edit_distance\": float(np.mean(normed_eds))}\n",
+    "\n",
     "training_args = Seq2SeqTrainingArguments(\n",
     "    num_train_epochs=EPOCHS,\n",
     "    per_device_train_batch_size=TRAIN_BATCH_SIZE,\n",
     "    learning_rate=3e-5,\n",
     "    per_device_eval_batch_size=1,\n",
     "    output_dir=\"./.checkpoints\",\n",
     "    seed=2022,\n",
-    "    warmup_steps=30,\n",
+    "    warmup_steps=300,\n",
     "    eval_strategy=\"steps\",\n",
-    "    eval_steps=100,\n",
+    "    eval_steps=200,\n",
+    "    fp16=(device == \"cuda\"),\n",
+    "    predict_with_generate=True,\n",
+    "    generation_max_length=MAX_LENGTH,\n",
+    "    generation_num_beams=1,\n",
     "    logging_strategy=\"steps\",\n",
     "    logging_steps=50,\n",
     "    save_strategy=\"steps\",\n",
@@ -157,17 +194,19 @@
     "    hub_model_id=HUGGINGFACE_MODEL_ID,\n",
     "    hub_strategy=\"every_save\",\n",
     "    report_to=\"tensorboard\",\n",
-    "    logging_dir=\"./.checkpoints/logs\",\n",
+    "    logging_dir=\"./.checkpoints/logs\"\n",
     ")\n",
     "trainer = Seq2SeqTrainer(\n",
     "    model=model,\n",
     "    args=training_args,\n",
-    "    train_dataset=datasets[\"train\"],\n",
-    "    eval_dataset=datasets[\"test\"],\n",
-    "    tokenizer=processor\n",
+    "    train_dataset=train_datasets,\n",
+    "    eval_dataset=validation_datasets,\n",
+    "    tokenizer=processor,\n",
+    "    compute_metrics=compute_metrics\n",
     ")\n",
     "\n",
-    "trainer.train()"
+    "trainer.train()\n",
+    "trainer.push_to_hub()"
 ]
 },
 {
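For reference, the metric the new compute_metrics cell reports is a normalized edit distance: special tokens are stripped from the decoded prediction and the target sequence, the Levenshtein distance is divided by the length of the longer string, and the result is averaged over the eval set. A self-contained sketch of that computation (the filtered tokens and sample strings are hypothetical stand-ins for the Donut tokenizer's special tokens and the dataset's target_sequence values):

from functools import reduce

import numpy as np
from nltk.metrics import edit_distance

# Hypothetical stand-ins for the tokenizer's bos/eos/pad/unk tokens
# that the notebook strips before scoring.
filtered_tokens = ["<s>", "</s>", "<pad>", "<unk>"]

def normalized_edit_distance(prediction: str, target: str) -> float:
    # Remove special tokens, then normalize the Levenshtein distance
    # by the length of the longer string (0.0 means an exact match).
    prediction = reduce(lambda s, t: s.replace(t, ""), filtered_tokens, prediction)
    target = reduce(lambda s, t: s.replace(t, ""), filtered_tokens, target)
    return edit_distance(prediction, target) / max(len(prediction), len(target))

# Hypothetical decoded prediction / target pairs.
pairs = [
    ("<s>beef noodles 120</s>", "<s>beef noodles 120</s>"),
    ("<s>beef noodles 120</s>", "<s>beef noodle 150</s>"),
]
scores = [normalized_edit_distance(p, t) for p, t in pairs]
print({"normed_edit_distance": float(np.mean(scores))})

Because predict_with_generate=True, the Seq2SeqTrainer hands compute_metrics the generated token IDs as eval_pred.predictions, which is why the cell can call batch_decode on them directly.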

uv.lock

Lines changed: 26 additions & 0 deletions
Generated lockfile; diff not rendered by default.
