|
70 | 70 | "\n", |
71 | 71 | "DATASETS_REPO_ID = \"ryanlinjui/menu-zh-TW\" # set your dataset repo id for training\n", |
72 | 72 | "PRETRAINED_MODEL_REPO_ID = \"naver-clova-ix/donut-base\" # set your pretrained model repo id for fine-tuning\n", |
73 | | - "TASK_PROMPT_NAME = \"<s_menu>\" # set your task prompt name for training\n", |
74 | | - "MAX_LENGTH = 768 # set your max length for maximum output length\n", |
| 73 | + "TASK_PROMPT_NAME = \"<s_menu-text-detection>\" # set your task prompt name for training\n", |
| 74 | + "MAX_LENGTH = 1024 # set your max length for maximum output length, max to 1536 for donut-base\n", |
75 | 75 | "IMAGE_SIZE = [1280, 960] # set your image size for training\n", |
76 | 76 | "\n", |
77 | 77 | "raw_datasets = load_dataset(DATASETS_REPO_ID)\n", |
|
97 | 97 | " annotation_column=\"menu\",\n", |
98 | 98 | " task_start_token=TASK_PROMPT_NAME,\n", |
99 | 99 | " prompt_end_token=TASK_PROMPT_NAME,\n", |
| 100 | + " max_length=MAX_LENGTH,\n", |
100 | 101 | " train_split=0.8,\n", |
101 | 102 | " validation_split=0.1,\n", |
102 | 103 | " test_split=0.1,\n", |
103 | | - " sort_json_key=True,\n", |
104 | | - " seed=42\n", |
| 104 | + " sort_json_key=False,\n", |
| 105 | + " seed=42,\n", |
| 106 | + " shuffle=False\n", |
105 | 107 | ")\n", |
106 | 108 | "\n", |
107 | 109 | "# Model: load the pretrained model and set the config.\n", |
|
124 | 126 | "metadata": {}, |
125 | 127 | "outputs": [], |
126 | 128 | "source": [ |
| 129 | + "from functools import reduce\n", |
| 130 | + "\n", |
127 | 131 | "import torch\n", |
| 132 | + "import numpy as np\n", |
| 133 | + "from nltk.metrics import edit_distance\n", |
| 134 | + "from transformers.trainer_utils import EvalPrediction\n", |
128 | 135 | "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", |
129 | 136 | "\n", |
130 | 137 | "HUGGINGFACE_MODEL_ID = \"ryanlinjui/donut-base-finetuned-menu\" # set your huggingface model repo id for saving / pushing to the hub\n", |
131 | 138 | "EPOCHS = 100 # set your training epochs\n", |
132 | | - "TRAIN_BATCH_SIZE = 4 # set your training batch size\n", |
| 139 | + "TRAIN_BATCH_SIZE = 3 # set your training batch size\n", |
133 | 140 | "\n", |
134 | 141 | "device = (\n", |
135 | 142 | " \"cuda\"\n", |
|
139 | 146 | "print(f\"Using {device} device\")\n", |
140 | 147 | "model.to(device)\n", |
141 | 148 | "\n", |
| 149 | + "train_datasets = datasets[\"train\"]\n", |
| 150 | + "validation_datasets = datasets[\"validation\"]\n", |
| 151 | + "filtered_tokens = [\n", |
| 152 | + " processor.tokenizer.bos_token,\n", |
| 153 | + " processor.tokenizer.eos_token,\n", |
| 154 | + " processor.tokenizer.pad_token,\n", |
| 155 | + " processor.tokenizer.unk_token,\n", |
| 156 | + "]\n", |
| 157 | + "def compute_metrics(eval_pred: EvalPrediction) -> dict:\n", |
| 158 | + " decoded_preds = processor.tokenizer.batch_decode(eval_pred.predictions, skip_special_tokens=False)\n", |
| 159 | + "\n", |
| 160 | + " normed_eds = []\n", |
| 161 | + " for idx, pred in enumerate(decoded_preds):\n", |
| 162 | + " prediction_sequence = reduce(lambda s, t: s.replace(t, \"\"), filtered_tokens, pred)\n", |
| 163 | + " target_sequence = reduce(lambda s, t: s.replace(t, \"\"), filtered_tokens, validation_datasets[idx][\"target_sequence\"])\n", |
| 164 | + " ed = edit_distance(prediction_sequence, target_sequence) / max(len(prediction_sequence), len(target_sequence))\n", |
| 165 | + " normed_eds.append(ed)\n", |
| 166 | + "\n", |
| 167 | + " print(f\"[Sample {idx}]\")\n", |
| 168 | + " print(f\" Prediction: {prediction_sequence}\")\n", |
| 169 | + " print(f\" Target: {target_sequence}\")\n", |
| 170 | + " print(f\" Normalized Edit Distance: {ed:.4f}\")\n", |
| 171 | + " print(\"-\" * 40)\n", |
| 172 | + "\n", |
| 173 | + " return {\"normed_edit_distance\": float(np.mean(normed_eds))}\n", |
| 174 | + "\n", |
142 | 175 | "training_args = Seq2SeqTrainingArguments(\n", |
143 | 176 | " num_train_epochs=EPOCHS,\n", |
144 | 177 | " per_device_train_batch_size=TRAIN_BATCH_SIZE,\n", |
145 | 178 | " learning_rate=3e-5,\n", |
146 | 179 | " per_device_eval_batch_size=1,\n", |
147 | 180 | " output_dir=\"./.checkpoints\",\n", |
148 | 181 | " seed=2022,\n", |
149 | | - " warmup_steps=30,\n", |
| 182 | + " warmup_steps=300,\n", |
150 | 183 | " eval_strategy=\"steps\",\n", |
151 | | - " eval_steps=100,\n", |
| 184 | + " eval_steps=200,\n", |
| 185 | + " fp16=(device == \"cuda\"),\n", |
| 186 | + " predict_with_generate=True,\n", |
| 187 | + " generation_max_length=MAX_LENGTH,\n", |
| 188 | + " generation_num_beams=1,\n", |
152 | 189 | " logging_strategy=\"steps\",\n", |
153 | 190 | " logging_steps=50,\n", |
154 | 191 | " save_strategy=\"steps\",\n", |
|
157 | 194 | " hub_model_id=HUGGINGFACE_MODEL_ID,\n", |
158 | 195 | " hub_strategy=\"every_save\",\n", |
159 | 196 | " report_to=\"tensorboard\",\n", |
160 | | - " logging_dir=\"./.checkpoints/logs\",\n", |
| 197 | + " logging_dir=\"./.checkpoints/logs\"\n", |
161 | 198 | ")\n", |
162 | 199 | "trainer = Seq2SeqTrainer(\n", |
163 | 200 | " model=model,\n", |
164 | 201 | " args=training_args,\n", |
165 | | - " train_dataset=datasets[\"train\"],\n", |
166 | | - " eval_dataset=datasets[\"test\"],\n", |
167 | | - " tokenizer=processor\n", |
| 202 | + " train_dataset=train_datasets,\n", |
| 203 | + " eval_dataset=validation_datasets,\n", |
| 204 | + " tokenizer=processor,\n", |
| 205 | + " compute_metrics=compute_metrics\n", |
168 | 206 | ")\n", |
169 | 207 | "\n", |
170 | | - "trainer.train()" |
| 208 | + "trainer.train()\n", |
| 209 | + "trainer.push_to_hub()" |
171 | 210 | ] |
172 | 211 | }, |
173 | 212 | { |
|