Skip to content

Commit 610e9ee

Browse files
authored
Merge pull request #876 from datalab-to/vik/quality
Vik/quality
2 parents 51042c1 + 27eeb44 commit 610e9ee

5 files changed

Lines changed: 111 additions & 63 deletions

File tree

marker/converters/pdf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
 import os

+from marker.schema.document import Document
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

 from collections import defaultdict
@@ -171,7 +173,7 @@ def filepath_to_str(self, file_input: Union[str, io.BytesIO]):
171173
if temp_file is not None and os.path.exists(temp_file.name):
172174
os.unlink(temp_file.name)
173175

174-
def build_document(self, filepath: str):
176+
def build_document(self, filepath: str) -> Document:
175177
provider_cls = provider_from_filepath(filepath)
176178
layout_builder = self.resolve_dependencies(self.layout_builder_class)
177179
line_builder = self.resolve_dependencies(LineBuilder)

marker/models.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
 import os
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = (
+    "1" # Transformers uses .isin for an op, which is not supported on MPS
+)

 from surya.foundation import FoundationPredictor
 from surya.detection import DetectionPredictor
@@ -8,13 +11,18 @@
 from surya.recognition import RecognitionPredictor
 from surya.table_rec import TableRecPredictor

-def create_model_dict(device=None, dtype=None) -> dict:
-    foundation_predictor = FoundationPredictor(device=device, dtype=dtype)
+
+def create_model_dict(
+    device=None, dtype=None, attention_implementation: str | None = None
+) -> dict:
+    foundation_predictor = FoundationPredictor(
+        device=device, dtype=dtype, attention_implementation=attention_implementation
+    )
     return {
         "foundation_model": foundation_predictor,
         "layout_model": LayoutPredictor(device=device, dtype=dtype),
         "recognition_model": RecognitionPredictor(foundation_predictor),
         "table_rec_model": TableRecPredictor(device=device, dtype=dtype),
         "detection_model": DetectionPredictor(device=device, dtype=dtype),
-        "ocr_error_model": OCRErrorPredictor(device=device, dtype=dtype)
-    }
+        "ocr_error_model": OCRErrorPredictor(device=device, dtype=dtype),
+    }

marker/schema/blocks/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class Block(BaseModel):
     lowres_image: Image.Image | None = None
     highres_image: Image.Image | None = None
     removed: bool = False # Has block been replaced by new block?
+    _metadata: Optional[dict] = None

     model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -114,6 +115,16 @@ def from_block(cls, block: Block) -> Block:
         block_attrs = block.model_dump(exclude=["id", "block_id", "block_type"])
         return cls(**block_attrs)

+    def set_internal_metadata(self, key, data):
+        if self._metadata is None:
+            self._metadata = {}
+        self._metadata[key] = data
+
+    def get_internal_metadata(self, key):
+        if self._metadata is None:
+            return None
+        return self._metadata.get(key)
+
     def get_image(
         self,
         document: Document,

0 commit comments

Comments
 (0)