11import os
2- os .environ ["PYTORCH_ENABLE_MPS_FALLBACK" ] = "1" # Transformers uses .isin for an op, which is not supported on MPS
2+
3+ os .environ ["PYTORCH_ENABLE_MPS_FALLBACK" ] = (
4+ "1" # Transformers uses .isin for an op, which is not supported on MPS
5+ )
36
47from surya .foundation import FoundationPredictor
58from surya .detection import DetectionPredictor
811from surya .recognition import RecognitionPredictor
912from surya .table_rec import TableRecPredictor
1013
11- def create_model_dict (device = None , dtype = None ) -> dict :
12- foundation_predictor = FoundationPredictor (device = device , dtype = dtype )
14+
15+ def create_model_dict (
16+ device = None , dtype = None , attention_implementation : str | None = None
17+ ) -> dict :
18+ foundation_predictor = FoundationPredictor (
19+ device = device , dtype = dtype , attention_implementation = attention_implementation
20+ )
1321 return {
1422 "foundation_model" : foundation_predictor ,
1523 "layout_model" : LayoutPredictor (device = device , dtype = dtype ),
1624 "recognition_model" : RecognitionPredictor (foundation_predictor ),
1725 "table_rec_model" : TableRecPredictor (device = device , dtype = dtype ),
1826 "detection_model" : DetectionPredictor (device = device , dtype = dtype ),
19- "ocr_error_model" : OCRErrorPredictor (device = device , dtype = dtype )
20- }
27+ "ocr_error_model" : OCRErrorPredictor (device = device , dtype = dtype ),
28+ }
0 commit comments