Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion suryaocr/pytorch/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,44 @@ def load_model(self, *, dtype_override=None, **kwargs) -> nn.Module:

from surya.detection import DetectionPredictor
from surya.recognition import RecognitionPredictor
from surya.common.surya.processor import (
SuryaOCRProcessor, # type: ignore[reportMissingImports]
)
from surya.detection import (
DetectionPredictor,
)
from surya.detection import (
heatmap as _surya_heatmap, # type: ignore[reportMissingImports]
)
from surya.detection import heatmap as _surya_heatmap2
from surya.detection.processor import (
SegformerImageProcessor, # type: ignore[reportMissingImports]
)
from surya.foundation.cache.dynamic_ops import (
DynamicOpsCache, # type: ignore[reportMissingImports]
)
from surya.foundation.cache.static_ops import (
StaticOpsCache, # type: ignore[reportMissingImports]
)
from surya.settings import settings # type: ignore[reportMissingImports]
from surya.common.surya import SuryaModel # type: ignore[reportMissingImports]

from .src.utils import (
_detect_boxes_torch,
_get_dynamic_thresholds_torch,
_patched_dynamic_ops_cache_init,
_patched_image_processor,
_patched_process_and_tile_no_xla,
_patched_static_ops_cache_init,
_patched_get_image_embeddings,
_prepare_image,
_segformer_preprocess,
)

DetectionPredictor.prepare_image = _prepare_image
SegformerImageProcessor._preprocess = _segformer_preprocess
_surya_heatmap.get_dynamic_thresholds = _get_dynamic_thresholds_torch
_surya_heatmap2.detect_boxes = _detect_boxes_torch

if DetectionPredictor is None or RecognitionPredictor is None:
raise ImportError(
Expand All @@ -97,13 +135,19 @@ def load_model(self, *, dtype_override=None, **kwargs) -> nn.Module:
if self.image_tensor is None:
self.load_inputs()
if self._variant == ModelVariant.OCR_TEXT:
StaticOpsCache.__init__ = _patched_static_ops_cache_init
DynamicOpsCache.__init__ = _patched_dynamic_ops_cache_init
SuryaOCRProcessor._image_processor = _patched_image_processor
SuryaOCRProcessor._process_and_tile = _patched_process_and_tile_no_xla
# Align Surya image embeddings and positional encodings to avoid assertion mismatches
SuryaModel.get_image_embeddings = _patched_get_image_embeddings # type: ignore[assignment]
model = SuryaOCRWrapper(image_tensor=self.image_tensor)
elif self._variant == ModelVariant.OCR_DETECTION:
model = SuryaOCRDetectionWrapper()
else:
raise ValueError(f"Invalid variant: {self._variant}")
model.eval()

dtype_override = torch.float32
if dtype_override is not None:
model = model.to(dtype_override)

Expand All @@ -124,6 +168,7 @@ def load_inputs(self, dtype_override: Optional[torch.dtype] = torch.float32):

images: List[Image.Image] = [image]
self.images = images
dtype_override = torch.float32
if dtype_override is not None:
image_tensor = image_tensor.to(dtype_override)
self.image_tensor = image_tensor
Expand Down
2 changes: 1 addition & 1 deletion suryaocr/pytorch/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
surya-ocr==0.15.4
surya-ocr==0.17.0
Loading