|
26 | 26 | "- After finishing, push to HuggingFace Datasets.\n", |
27 | 27 | "- For labeling:\n", |
28 | 28 | " - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n", |
29 | | - " - Use function calling by API. Start the gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection)." |
| 29 | + " - Use function calling by API. Start the gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n", |
| 30 | + "\n", |
| 31 | + "### Menu Type\n", |
| 32 | + "- **h**: horizontal menu\n", |
| 33 | + "- **v**: vertical menu\n", |
| 34 | + "- **d**: document-style menu\n", |
| 35 | + "- **s**: in-scene menu (non-document style)\n", |
| 36 | + "- **i**: irregular menu (menu with irregular text layout)\n", |
| 37 | + "\n", |
| 38 | + "> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details." |
30 | 39 | ] |
31 | 40 | }, |
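|  | + { |
|  | + "cell_type": "markdown", |
|  | + "metadata": {}, |
|  | + "source": [ |
|  | + "A minimal sketch of what one labeled entry in `datasets/menu-zh-TW/metadata.jsonl` could look like; `file_name` is the standard imagefolder key, while the other field names and all values here are illustrative, not copied from the real dataset:\n", |
|  | + "\n", |
|  | + "```json\n", |
|  | + "{\"file_name\": \"menu-001.jpg\", \"type\": \"h\", \"menu\": {\"store_name\": \"Example Cafe\", \"items\": [{\"name\": \"Latte\", \"price\": \"120\"}]}}\n", |
|  | + "```" |
|  | + ] |
|  | + }, |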
32 | 41 | { |
|
37 | 46 | "source": [ |
38 | 47 | "from datasets import load_dataset\n", |
39 | 48 | "\n", |
40 | | - "dataset = load_dataset(path=\"datasets/menu-zh-TW\")\n", |
41 | | - "dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\")" |
| 49 | + "dataset = load_dataset(path=\"datasets/menu-zh-TW\") # load dataset from the local directory including the metadata.jsonl, images files.\n", |
| 50 | + "dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\") # push to the huggingface dataset hub" |
42 | 51 | ] |
43 | 52 | }, |
44 | 53 | { |
|
56 | 65 | "source": [ |
57 | 66 | "from datasets import load_dataset\n", |
58 | 67 | "from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig\n", |
| 68 | + "\n", |
59 | 69 | "from menu.donut import DonutDatasets\n", |
60 | 70 | "\n", |
61 | | - "DATASETS_REPO_ID = \"ryanlinjui/menu-zh-TW\"\n", |
62 | | - "PRETRAINED_MODEL_REPO_ID = \"naver-clova-ix/donut-base\"\n", |
63 | | - "TASK_PROMPT_NAME = \"<s_menu>\"\n", |
64 | | - "MAX_LENGTH = 768\n", |
65 | | - "IMAGE_SIZE = [1280, 960]\n", |
| 71 | + "DATASETS_REPO_ID = \"ryanlinjui/menu-zh-TW\" # set your dataset repo id for training\n", |
| 72 | + "PRETRAINED_MODEL_REPO_ID = \"naver-clova-ix/donut-base\" # set your pretrained model repo id for fine-tuning\n", |
| 73 | + "TASK_PROMPT_NAME = \"<s_menu>\" # set your task prompt name for training\n", |
| 74 | + "MAX_LENGTH = 768 # set your max length for maximum output length\n", |
| 75 | + "IMAGE_SIZE = [1280, 960] # set your image size for training\n", |
66 | 76 | "\n", |
67 | 77 | "raw_datasets = load_dataset(DATASETS_REPO_ID)\n", |
68 | 78 | "\n", |
69 | | - "# Config: 預訓練模型載入 Encoder–Decoder 的設定\n", |
| 79 | + "# Config: set the model config\n", |
70 | 80 | "config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_REPO_ID)\n", |
71 | 81 | "config.encoder.image_size = IMAGE_SIZE\n", |
72 | 82 | "config.decoder.max_length = MAX_LENGTH\n", |
73 | 83 | "\n", |
74 | | - "# Processor: 影像前處理與文字後處理\n", |
| 84 | + "# Processor: use the processor to process the dataset. \n", |
| 85 | + "# Convert the image to the tensor and the text to the token ids.\n", |
75 | 86 | "processor = DonutProcessor.from_pretrained(PRETRAINED_MODEL_REPO_ID)\n", |
76 | 87 | "processor.feature_extractor.size = IMAGE_SIZE[::-1]\n", |
77 | 88 | "processor.feature_extractor.do_align_long_axis = False\n", |
78 | 89 | "\n", |
79 | | - "# Donut Datasets: \n", |
| 90 | + "# DonutDatasets: use the DonutDatasets to process the dataset.\n", |
| 91 | + "# For model inpit, the image must be converted to the tensor and the json text must be converted to the token with the task prompt string.\n", |
| 92 | + "# This example sets the column name by \"image\" and \"menu\". So that image file is included in the \"image\" column and the json text is included in the \"menu\" column.\n", |
80 | 93 | "datasets = DonutDatasets(\n", |
81 | 94 | " datasets=raw_datasets,\n", |
82 | 95 | " processor=processor,\n", |
|
91 | 104 | " seed=42\n", |
92 | 105 | ")\n", |
93 | 106 | "\n", |
94 | | - "# Model: 載入預訓練模型\n", |
| 107 | + "# Model: load the pretrained model and set the config.\n", |
95 | 108 | "model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_REPO_ID, config=config)\n", |
96 | 109 | "model.decoder.resize_token_embeddings(len(processor.tokenizer))\n", |
97 | 110 | "model.config.pad_token_id = processor.tokenizer.pad_token_id\n", |
|
114 | 127 | "import torch\n", |
115 | 128 | "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", |
116 | 129 | "\n", |
117 | | - "HUGGINGFACE_MODEL_ID = \"ryanlinjui/donut-base-finetuned-menu\"\n", |
118 | | - "EPOCHS = 100\n", |
119 | | - "TRAIN_BATCH_SIZE = 4\n", |
| 130 | + "HUGGINGFACE_MODEL_ID = \"ryanlinjui/donut-base-finetuned-menu\" # set your huggingface model repo id for saving / pushing to the hub\n", |
| 131 | + "EPOCHS = 100 # set your training epochs\n", |
| 132 | + "TRAIN_BATCH_SIZE = 4 # set your training batch size\n", |
120 | 133 | "\n", |
121 | | - "if torch.cuda.is_available():\n", |
122 | | - " print(\"Using GPU\")\n", |
123 | | - " model.to(\"cuda\")\n", |
124 | | - "else:\n", |
125 | | - " print(\"Using default device\")\n", |
| 134 | + "device = (\n", |
| 135 | + " \"cuda\"\n", |
| 136 | + " if torch.cuda.is_available()\n", |
| 137 | + " else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", |
| 138 | + ")\n", |
| 139 | + "print(f\"Using {device} device\")\n", |
| 140 | + "model.to(device)\n", |
126 | 141 | "\n", |
127 | 142 | "training_args = Seq2SeqTrainingArguments(\n", |
128 | 143 | " num_train_epochs=EPOCHS,\n", |
|
189 | 204 | "outputs = ocr_pipeline(image)\n", |
190 | 205 | "\n", |
191 | 206 | "# 5. 印出辨識文字\n", |
192 | | - "print(outputs[0][\"generated_text\"])\n" |
193 | | - ] |
194 | | - }, |
195 | | - { |
196 | | - "cell_type": "markdown", |
197 | | - "metadata": {}, |
198 | | - "source": [ |
199 | | - "# Plot the results" |
200 | | - ] |
201 | | - }, |
202 | | - { |
203 | | - "cell_type": "code", |
204 | | - "execution_count": null, |
205 | | - "metadata": {}, |
206 | | - "outputs": [], |
207 | | - "source": [ |
| 207 | + "print(outputs[0][\"generated_text\"])\n", |
| 208 | + "\n", |
| 209 | + "'''\n", |
208 | 210 | "# test model\n", |
209 | 211 | "import re\n", |
210 | 212 | "\n", |
211 | 213 | "from transformers import VisionEncoderDecoderModel\n", |
| 214 | + "from transformers import DonutProcessor\n", |
| 215 | + "import torch\n", |
212 | 216 | "from PIL import Image\n", |
213 | 217 | "\n", |
214 | 218 | "image = Image.open(\"./examples/menu-hd.jpg\").convert(\"RGB\")\n", |
215 | 219 | "\n", |
216 | 220 | "processor = DonutProcessor.from_pretrained(\"ryanlinjui/donut-base-finetuned-menu\")\n", |
217 | 221 | "model = VisionEncoderDecoderModel.from_pretrained(\"ryanlinjui/donut-base-finetuned-menu\")\n", |
218 | | - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", |
| 222 | + "device = \"cuda\" if torch.cuda.is_available() else \"mps\"\n", |
219 | 223 | "\n", |
220 | 224 | "model.eval()\n", |
221 | 225 | "model.to(device)\n", |
|
241 | 245 | "\n", |
242 | 246 | "seq = processor.batch_decode(outputs.sequences)[0]\n", |
243 | 247 | "seq = seq.replace(processor.tokenizer.eos_token, \"\").replace(processor.tokenizer.pad_token, \"\")\n", |
244 | | - "seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip() # remove first task start token\n", |
| 248 | + "# seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip() # remove first task start token\n", |
245 | 249 | "seq = processor.token2json(seq)\n", |
246 | | - "print(seq)" |
| 250 | + "print(seq)\n", |
| 251 | + "'''\n" |
| 252 | + ] |
| 253 | + }, |
| 254 | + { |
| 255 | + "cell_type": "markdown", |
| 256 | + "metadata": {}, |
| 257 | + "source": [ |
| 258 | + "# Plot the results" |
| 259 | + ] |
| 260 | + }, |
| 261 | + { |
| 262 | + "cell_type": "code", |
| 263 | + "execution_count": null, |
| 264 | + "metadata": {}, |
| 265 | + "outputs": [], |
| 266 | + "source": [ |
| 267 | + "# Training Loss\n", |
| 268 | + "# Validation Normal ED per each epoch 1~0, 1 -> 0.22\n", |
| 269 | + "# Test Accuracy TED Accuracy, F1 Score Accuracy 0.687058, 0.51119 " |
247 | 270 | ] |
248 | 271 | } |
249 | 272 | ], |
|
263 | 286 | "name": "python", |
264 | 287 | "nbconvert_exporter": "python", |
265 | 288 | "pygments_lexer": "ipython3", |
266 | | - "version": "3.11.10" |
| 289 | + "version": "3.11.12" |
267 | 290 | } |
268 | 291 | }, |
269 | 292 | "nbformat": 4, |
|