88import torch
99import torch .nn as nn
1010import torch .nn .functional as F
11- import transformers
1211from docarray import DocList
1312from PIL import Image
1413from torch .nn import Module
1514from torch .utils .data import DataLoader
1615from tqdm import tqdm
16+ from transformers import (
17+ AutoModelForSequenceClassification ,
18+ AutoTokenizer ,
19+ LayoutLMv3ImageProcessor ,
20+ LayoutLMv3Processor ,
21+ pipeline ,
22+ )
1723
1824from marie .api .docs import BatchableMarieDoc , MarieDoc
1925from marie .components .document_classifier import BaseDocumentClassifier
@@ -242,7 +248,7 @@ def __init__(
242248 tokenizer = model_name_or_path
243249
244250 if task == "zero-shot-classification" :
245- self .model = transformers . pipeline (
251+ self .model = pipeline (
246252 task = task ,
247253 model = model_name_or_path ,
248254 tokenizer = tokenizer ,
@@ -251,7 +257,7 @@ def __init__(
251257 device = self .device ,
252258 )
253259 elif task == "text-classification" :
254- self .model = transformers . pipeline (
260+ self .model = pipeline (
255261 task = task ,
256262 model = model_name_or_path ,
257263 tokenizer = tokenizer ,
@@ -273,13 +279,13 @@ def __init__(
273279 model_class = (
274280 resolve_architecture (architecture )
275281 if architecture
276- else transformers . AutoModelForSequenceClassification
282+ else AutoModelForSequenceClassification
277283 )
278284
279285 self .model = model_class .from_pretrained (model_name_or_path )
280286 self .model = optimize_model (self .model , self .logger )
281287 self .model = self .model .eval ().to (self .device )
282- self .tokenizer = transformers . AutoTokenizer .from_pretrained (tokenizer )
288+ self .tokenizer = AutoTokenizer .from_pretrained (tokenizer )
283289
284290 if os .path .exists (
285291 os .path .join (model_name_or_path , "preprocessor_config.json" )
@@ -288,14 +294,14 @@ def __init__(
288294 "Found preprocessor_config.json, loading processor from %s" ,
289295 model_name_or_path ,
290296 )
291- self .processor = transformers . LayoutLMv3Processor .from_pretrained (
297+ self .processor = LayoutLMv3Processor .from_pretrained (
292298 model_name_or_path , tokenizer = self .tokenizer
293299 )
294300 else :
295- feature_extractor = transformers . LayoutLMv3ImageProcessor (
301+ feature_extractor = LayoutLMv3ImageProcessor (
296302 apply_ocr = False , do_resize = True , resample = Image .BILINEAR
297303 )
298- self .processor = transformers . LayoutLMv3Processor (
304+ self .processor = LayoutLMv3Processor (
299305 feature_extractor , tokenizer = self .tokenizer
300306 )
301307
@@ -528,8 +534,6 @@ def __init__(
528534 """
529535 Load a text classification model from ModelRepository or HuggingFace model hub.
530536
531- TODO: ADD EXAMPLE AND CODE SNIPPET
532-
533537 See https://huggingface.co/models for full list of available models.
534538 Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
535539 Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
@@ -615,27 +619,27 @@ def __init__(
615619 model_class = (
616620 resolve_architecture (architecture )
617621 if architecture
618- else transformers . AutoModelForSequenceClassification
622+ else AutoModelForSequenceClassification
619623 )
620624
621625 self .model = model_class .from_pretrained (model_name_or_path )
622626 self .model = optimize_model (self .model , self .logger )
623627 self .model = self .model .eval ().to (self .device )
624- self .tokenizer = transformers . AutoTokenizer .from_pretrained (tokenizer )
628+ self .tokenizer = AutoTokenizer .from_pretrained (tokenizer )
625629
626630 if os .path .exists (os .path .join (model_name_or_path , "preprocessor_config.json" )):
627631 self .logger .info (
628632 "Found preprocessor_config.json, loading processor from %s" ,
629633 model_name_or_path ,
630634 )
631- self .processor = transformers . LayoutLMv3Processor .from_pretrained (
635+ self .processor = LayoutLMv3Processor .from_pretrained (
632636 model_name_or_path , tokenizer = self .tokenizer
633637 )
634638 else :
635- feature_extractor = transformers . LayoutLMv3ImageProcessor (
639+ feature_extractor = LayoutLMv3ImageProcessor (
636640 apply_ocr = False , do_resize = True , resample = Image .BILINEAR
637641 )
638- self .processor = transformers . LayoutLMv3Processor (
642+ self .processor = LayoutLMv3Processor (
639643 feature_extractor , tokenizer = self .tokenizer
640644 )
641645
@@ -758,8 +762,6 @@ def __init__(
758762 """
759763 Load a text classification model from ModelRepository or HuggingFace model hub.
760764
761- TODO: ADD EXAMPLE AND CODE SNIPPET
762-
763765 See https://huggingface.co/models for full list of available models.
764766 Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
765767 Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
@@ -845,13 +847,13 @@ def __init__(
845847 model_class = (
846848 resolve_architecture (architecture )
847849 if architecture
848- else transformers . AutoModelForSequenceClassification
850+ else AutoModelForSequenceClassification
849851 )
850852
851853 self .model = model_class .from_pretrained (model_name_or_path )
852854 self .model = optimize_model (self .model , self .logger )
853855 self .model = self .model .eval ().to (self .device )
854- self .tokenizer = transformers . AutoTokenizer .from_pretrained (tokenizer )
856+ self .tokenizer = AutoTokenizer .from_pretrained (tokenizer )
855857
856858 self .context_pages_num = int (
857859 self .model_config .get ("custom_model_parameters" , {}).get (
@@ -864,14 +866,14 @@ def __init__(
864866 "Found preprocessor_config.json, loading processor from %s" ,
865867 model_name_or_path ,
866868 )
867- self .processor = transformers . LayoutLMv3Processor .from_pretrained (
869+ self .processor = LayoutLMv3Processor .from_pretrained (
868870 model_name_or_path , tokenizer = self .tokenizer
869871 )
870872 else :
871- feature_extractor = transformers . LayoutLMv3ImageProcessor (
873+ feature_extractor = LayoutLMv3ImageProcessor (
872874 apply_ocr = False , do_resize = True , resample = Image .BILINEAR
873875 )
874- self .processor = transformers . LayoutLMv3Processor (
876+ self .processor = LayoutLMv3Processor (
875877 feature_extractor , tokenizer = self .tokenizer
876878 )
877879
@@ -971,7 +973,7 @@ def predict(
971973
972974 predictions = self .predict_document_boundaries (data_loader )
973975
974- # TODO fixme
976+ # We cannot split on the last page
975977 predictions .append (
976978 {
977979 "current_page_num" : page_count ,
0 commit comments