Skip to content

Commit 0d8b01a

Browse files
Fix Torch Dynamo failure in SuryaOCR Model
1 parent f9b4f8e commit 0d8b01a

File tree

3 files changed

+610
-17
lines changed

3 files changed

+610
-17
lines changed

suryaocr/pytorch/loader.py

Lines changed: 46 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,44 @@ def load_model(self, *, dtype_override=None, **kwargs) -> nn.Module:
8989

9090
from surya.detection import DetectionPredictor
9191
from surya.recognition import RecognitionPredictor
92+
from surya.common.surya.processor import (
93+
SuryaOCRProcessor, # type: ignore[reportMissingImports]
94+
)
95+
from surya.detection import (
96+
DetectionPredictor,
97+
)
98+
from surya.detection import (
99+
heatmap as _surya_heatmap, # type: ignore[reportMissingImports]
100+
)
101+
from surya.detection import heatmap as _surya_heatmap2
102+
from surya.detection.processor import (
103+
SegformerImageProcessor, # type: ignore[reportMissingImports]
104+
)
105+
from surya.foundation.cache.dynamic_ops import (
106+
DynamicOpsCache, # type: ignore[reportMissingImports]
107+
)
108+
from surya.foundation.cache.static_ops import (
109+
StaticOpsCache, # type: ignore[reportMissingImports]
110+
)
111+
from surya.settings import settings # type: ignore[reportMissingImports]
112+
from surya.common.surya import SuryaModel # type: ignore[reportMissingImports]
113+
114+
from .src.utils import (
115+
_detect_boxes_torch,
116+
_get_dynamic_thresholds_torch,
117+
_patched_dynamic_ops_cache_init,
118+
_patched_image_processor,
119+
_patched_process_and_tile_no_xla,
120+
_patched_static_ops_cache_init,
121+
_patched_get_image_embeddings,
122+
_prepare_image,
123+
_segformer_preprocess,
124+
)
125+
126+
DetectionPredictor.prepare_image = _prepare_image
127+
SegformerImageProcessor._preprocess = _segformer_preprocess
128+
_surya_heatmap.get_dynamic_thresholds = _get_dynamic_thresholds_torch
129+
_surya_heatmap2.detect_boxes = _detect_boxes_torch
92130

93131
if DetectionPredictor is None or RecognitionPredictor is None:
94132
raise ImportError(
@@ -97,13 +135,19 @@ def load_model(self, *, dtype_override=None, **kwargs) -> nn.Module:
97135
if self.image_tensor is None:
98136
self.load_inputs()
99137
if self._variant == ModelVariant.OCR_TEXT:
138+
StaticOpsCache.__init__ = _patched_static_ops_cache_init
139+
DynamicOpsCache.__init__ = _patched_dynamic_ops_cache_init
140+
SuryaOCRProcessor._image_processor = _patched_image_processor
141+
SuryaOCRProcessor._process_and_tile = _patched_process_and_tile_no_xla
142+
# Align Surya image embeddings and positional encodings to avoid assertion mismatches
143+
SuryaModel.get_image_embeddings = _patched_get_image_embeddings # type: ignore[assignment]
100144
model = SuryaOCRWrapper(image_tensor=self.image_tensor)
101145
elif self._variant == ModelVariant.OCR_DETECTION:
102146
model = SuryaOCRDetectionWrapper()
103147
else:
104148
raise ValueError(f"Invalid variant: {self._variant}")
105149
model.eval()
106-
150+
dtype_override = torch.float32
107151
if dtype_override is not None:
108152
model = model.to(dtype_override)
109153

@@ -124,6 +168,7 @@ def load_inputs(self, dtype_override: Optional[torch.dtype] = torch.float32):
124168

125169
images: List[Image.Image] = [image]
126170
self.images = images
171+
dtype_override = torch.float32
127172
if dtype_override is not None:
128173
image_tensor = image_tensor.to(dtype_override)
129174
self.image_tensor = image_tensor

suryaocr/pytorch/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
surya-ocr==0.15.4
1+
surya-ocr==0.17.0

0 commit comments

Comments
 (0)