@@ -89,6 +89,44 @@ def load_model(self, *, dtype_override=None, **kwargs) -> nn.Module:
8989
9090 from surya .detection import DetectionPredictor
9191 from surya .recognition import RecognitionPredictor
92+ from surya .common .surya .processor import (
93+ SuryaOCRProcessor , # type: ignore[reportMissingImports]
94+ )
95+ from surya .detection import (
96+ DetectionPredictor ,
97+ )
98+ from surya .detection import (
99+ heatmap as _surya_heatmap , # type: ignore[reportMissingImports]
100+ )
101+ from surya .detection import heatmap as _surya_heatmap2
102+ from surya .detection .processor import (
103+ SegformerImageProcessor , # type: ignore[reportMissingImports]
104+ )
105+ from surya .foundation .cache .dynamic_ops import (
106+ DynamicOpsCache , # type: ignore[reportMissingImports]
107+ )
108+ from surya .foundation .cache .static_ops import (
109+ StaticOpsCache , # type: ignore[reportMissingImports]
110+ )
111+ from surya .settings import settings # type: ignore[reportMissingImports]
112+ from surya .common .surya import SuryaModel # type: ignore[reportMissingImports]
113+
114+ from .src .utils import (
115+ _detect_boxes_torch ,
116+ _get_dynamic_thresholds_torch ,
117+ _patched_dynamic_ops_cache_init ,
118+ _patched_image_processor ,
119+ _patched_process_and_tile_no_xla ,
120+ _patched_static_ops_cache_init ,
121+ _patched_get_image_embeddings ,
122+ _prepare_image ,
123+ _segformer_preprocess ,
124+ )
125+
126+ DetectionPredictor .prepare_image = _prepare_image
127+ SegformerImageProcessor ._preprocess = _segformer_preprocess
128+ _surya_heatmap .get_dynamic_thresholds = _get_dynamic_thresholds_torch
129+ _surya_heatmap2 .detect_boxes = _detect_boxes_torch
92130
93131 if DetectionPredictor is None or RecognitionPredictor is None :
94132 raise ImportError (
@@ -97,13 +135,19 @@ def load_model(self, *, dtype_override=None, **kwargs) -> nn.Module:
97135 if self .image_tensor is None :
98136 self .load_inputs ()
99137 if self ._variant == ModelVariant .OCR_TEXT :
138+ StaticOpsCache .__init__ = _patched_static_ops_cache_init
139+ DynamicOpsCache .__init__ = _patched_dynamic_ops_cache_init
140+ SuryaOCRProcessor ._image_processor = _patched_image_processor
141+ SuryaOCRProcessor ._process_and_tile = _patched_process_and_tile_no_xla
142+ # Align Surya image embeddings and positional encodings to avoid assertion mismatches
143+ SuryaModel .get_image_embeddings = _patched_get_image_embeddings # type: ignore[assignment]
100144 model = SuryaOCRWrapper (image_tensor = self .image_tensor )
101145 elif self ._variant == ModelVariant .OCR_DETECTION :
102146 model = SuryaOCRDetectionWrapper ()
103147 else :
104148 raise ValueError (f"Invalid variant: { self ._variant } " )
105149 model .eval ()
106-
150+ dtype_override = torch . float32
107151 if dtype_override is not None :
108152 model = model .to (dtype_override )
109153
@@ -124,6 +168,7 @@ def load_inputs(self, dtype_override: Optional[torch.dtype] = torch.float32):
124168
125169 images : List [Image .Image ] = [image ]
126170 self .images = images
171+ dtype_override = torch .float32
127172 if dtype_override is not None :
128173 image_tensor = image_tensor .to (dtype_override )
129174 self .image_tensor = image_tensor
0 commit comments