Skip to content

Commit d6236ab

Browse files
committed
fix: transformers import
1 parent b78c95a commit d6236ab

1 file changed

Lines changed: 25 additions & 23 deletions

File tree

marie/components/document_classifier/transformers.py

Lines changed: 25 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,18 @@
88
import torch
99
import torch.nn as nn
1010
import torch.nn.functional as F
11-
import transformers
1211
from docarray import DocList
1312
from PIL import Image
1413
from torch.nn import Module
1514
from torch.utils.data import DataLoader
1615
from tqdm import tqdm
16+
from transformers import (
17+
AutoModelForSequenceClassification,
18+
AutoTokenizer,
19+
LayoutLMv3ImageProcessor,
20+
LayoutLMv3Processor,
21+
pipeline,
22+
)
1723

1824
from marie.api.docs import BatchableMarieDoc, MarieDoc
1925
from marie.components.document_classifier import BaseDocumentClassifier
@@ -242,7 +248,7 @@ def __init__(
242248
tokenizer = model_name_or_path
243249

244250
if task == "zero-shot-classification":
245-
self.model = transformers.pipeline(
251+
self.model = pipeline(
246252
task=task,
247253
model=model_name_or_path,
248254
tokenizer=tokenizer,
@@ -251,7 +257,7 @@ def __init__(
251257
device=self.device,
252258
)
253259
elif task == "text-classification":
254-
self.model = transformers.pipeline(
260+
self.model = pipeline(
255261
task=task,
256262
model=model_name_or_path,
257263
tokenizer=tokenizer,
@@ -273,13 +279,13 @@ def __init__(
273279
model_class = (
274280
resolve_architecture(architecture)
275281
if architecture
276-
else transformers.AutoModelForSequenceClassification
282+
else AutoModelForSequenceClassification
277283
)
278284

279285
self.model = model_class.from_pretrained(model_name_or_path)
280286
self.model = optimize_model(self.model, self.logger)
281287
self.model = self.model.eval().to(self.device)
282-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
288+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
283289

284290
if os.path.exists(
285291
os.path.join(model_name_or_path, "preprocessor_config.json")
@@ -288,14 +294,14 @@ def __init__(
288294
"Found preprocessor_config.json, loading processor from %s",
289295
model_name_or_path,
290296
)
291-
self.processor = transformers.LayoutLMv3Processor.from_pretrained(
297+
self.processor = LayoutLMv3Processor.from_pretrained(
292298
model_name_or_path, tokenizer=self.tokenizer
293299
)
294300
else:
295-
feature_extractor = transformers.LayoutLMv3ImageProcessor(
301+
feature_extractor = LayoutLMv3ImageProcessor(
296302
apply_ocr=False, do_resize=True, resample=Image.BILINEAR
297303
)
298-
self.processor = transformers.LayoutLMv3Processor(
304+
self.processor = LayoutLMv3Processor(
299305
feature_extractor, tokenizer=self.tokenizer
300306
)
301307

@@ -528,8 +534,6 @@ def __init__(
528534
"""
529535
Load a text classification model from ModelRepository or HuggingFace model hub.
530536
531-
TODO: ADD EXAMPLE AND CODE SNIPPET
532-
533537
See https://huggingface.co/models for full list of available models.
534538
Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
535539
Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
@@ -615,27 +619,27 @@ def __init__(
615619
model_class = (
616620
resolve_architecture(architecture)
617621
if architecture
618-
else transformers.AutoModelForSequenceClassification
622+
else AutoModelForSequenceClassification
619623
)
620624

621625
self.model = model_class.from_pretrained(model_name_or_path)
622626
self.model = optimize_model(self.model, self.logger)
623627
self.model = self.model.eval().to(self.device)
624-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
628+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
625629

626630
if os.path.exists(os.path.join(model_name_or_path, "preprocessor_config.json")):
627631
self.logger.info(
628632
"Found preprocessor_config.json, loading processor from %s",
629633
model_name_or_path,
630634
)
631-
self.processor = transformers.LayoutLMv3Processor.from_pretrained(
635+
self.processor = LayoutLMv3Processor.from_pretrained(
632636
model_name_or_path, tokenizer=self.tokenizer
633637
)
634638
else:
635-
feature_extractor = transformers.LayoutLMv3ImageProcessor(
639+
feature_extractor = LayoutLMv3ImageProcessor(
636640
apply_ocr=False, do_resize=True, resample=Image.BILINEAR
637641
)
638-
self.processor = transformers.LayoutLMv3Processor(
642+
self.processor = LayoutLMv3Processor(
639643
feature_extractor, tokenizer=self.tokenizer
640644
)
641645

@@ -758,8 +762,6 @@ def __init__(
758762
"""
759763
Load a text classification model from ModelRepository or HuggingFace model hub.
760764
761-
TODO: ADD EXAMPLE AND CODE SNIPPET
762-
763765
See https://huggingface.co/models for full list of available models.
764766
Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
765767
Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
@@ -845,13 +847,13 @@ def __init__(
845847
model_class = (
846848
resolve_architecture(architecture)
847849
if architecture
848-
else transformers.AutoModelForSequenceClassification
850+
else AutoModelForSequenceClassification
849851
)
850852

851853
self.model = model_class.from_pretrained(model_name_or_path)
852854
self.model = optimize_model(self.model, self.logger)
853855
self.model = self.model.eval().to(self.device)
854-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
856+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
855857

856858
self.context_pages_num = int(
857859
self.model_config.get("custom_model_parameters", {}).get(
@@ -864,14 +866,14 @@ def __init__(
864866
"Found preprocessor_config.json, loading processor from %s",
865867
model_name_or_path,
866868
)
867-
self.processor = transformers.LayoutLMv3Processor.from_pretrained(
869+
self.processor = LayoutLMv3Processor.from_pretrained(
868870
model_name_or_path, tokenizer=self.tokenizer
869871
)
870872
else:
871-
feature_extractor = transformers.LayoutLMv3ImageProcessor(
873+
feature_extractor = LayoutLMv3ImageProcessor(
872874
apply_ocr=False, do_resize=True, resample=Image.BILINEAR
873875
)
874-
self.processor = transformers.LayoutLMv3Processor(
876+
self.processor = LayoutLMv3Processor(
875877
feature_extractor, tokenizer=self.tokenizer
876878
)
877879

@@ -971,7 +973,7 @@ def predict(
971973

972974
predictions = self.predict_document_boundaries(data_loader)
973975

974-
# TODO fixme
976+
# We cannot split on the last page
975977
predictions.append(
976978
{
977979
"current_page_num": page_count,

0 commit comments

Comments
 (0)