Skip to content

Commit 1eda0e6

Browse files
authored
Merge pull request #123 from enoch3712/114-standard-documentloader-output
114 standard documentloader output
2 parents a727e09 + 5b0823e commit 1eda0e6

30 files changed

+1485 -1298 lines changed

extract_thinker/document_loader/document_loader.py

Lines changed: 4 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -58,67 +58,21 @@ def _can_handle_stream(self, stream: BytesIO) -> bool:
5858
return False
5959

6060
@abstractmethod
61-
def load_content_from_file(self, file_path: str) -> Union[str, object]:
62-
pass
63-
64-
@abstractmethod
65-
def load_content_from_stream(self, stream: BytesIO) -> Union[str, object]:
66-
pass
67-
6861
def load(self, source: Union[str, BytesIO]) -> Any:
6962
"""Enhanced load method that handles vision mode."""
70-
if not self.can_handle(source):
71-
raise ValueError("Unsupported file type or stream.")
72-
73-
response = {}
74-
75-
# Always process text content
76-
content = self.load_content_from_file(source) if isinstance(source, str) else self.load_content_from_stream(source)
77-
78-
# Merge content with response
79-
if content is not None:
80-
if isinstance(content, dict):
81-
response.update(content)
82-
else:
83-
response['content'] = content
84-
85-
# If vision mode is enabled, add images
86-
if self.vision_mode:
87-
if not self.can_handle_vision(source):
88-
raise ValueError("Source cannot be processed in vision mode. Only PDFs and images are supported.")
89-
90-
# Convert to images and add to response
91-
response['images'] = self.convert_to_images(source)
92-
93-
return response
63+
pass
9464

9565
def getContent(self) -> Any:
9666
return self.content
9767

98-
def load_content_list(self, input_data: Union[str, BytesIO, List[Union[str, BytesIO]]]) -> Union[str, List[str]]:
99-
if isinstance(input_data, (str, BytesIO)):
100-
return self.load_content_from_stream_list(input_data)
101-
elif isinstance(input_data, list):
102-
return self.load_content_from_file_list(input_data)
103-
else:
104-
raise Exception(f"Unsupported input type: {type(input_data)}")
105-
106-
@abstractmethod
107-
def load_content_from_stream_list(self, stream: BytesIO) -> List[Any]:
108-
pass
109-
110-
@abstractmethod
111-
def load_content_from_file_list(self, file_path: str) -> List[Any]:
112-
pass
113-
114-
def convert_to_images(self, file: Union[str, io.BytesIO], scale: float = 300 / 72) -> Dict[int, bytes]:
68+
def convert_to_images(self, file: Union[str, io.BytesIO, io.BufferedReader], scale: float = 300 / 72) -> Dict[int, bytes]:
11569
# Determine if the input is a file path or a stream
11670
if isinstance(file, str):
11771
return self._convert_file_to_images(file, scale)
118-
elif isinstance(file, io.BytesIO):
72+
elif isinstance(file, (io.BytesIO, io.BufferedReader)): # Accept both BytesIO and BufferedReader
11973
return self._convert_stream_to_images(file, scale)
12074
else:
121-
raise TypeError("file must be a file path (str) or a BytesIO stream")
75+
raise TypeError("file must be a file path (str) or a file-like stream")
12276

12377
def _convert_file_to_images(self, file_path: str, scale: float) -> Dict[int, bytes]:
12478
# Check if the file is already an image
Lines changed: 74 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,19 @@
1-
import asyncio
21
from io import BytesIO
3-
from operator import attrgetter
4-
import os
5-
import threading
6-
from typing import Any, List, Union
7-
from PIL import Image
2+
from typing import Any, Dict, List, Union
83
import boto3
94
import pypdfium2 as pdfium
105

11-
from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
6+
from extract_thinker.document_loader.document_loader import DocumentLoader
127
from extract_thinker.utils import get_file_extension, get_image_type, is_pdf_stream
138

14-
from cachetools import cachedmethod
15-
from cachetools.keys import hashkey
169

17-
class DocumentLoaderAWSTextract(CachedDocumentLoader):
10+
class DocumentLoaderAWSTextract(DocumentLoader):
11+
"""Loader for documents using AWS Textract."""
12+
1813
SUPPORTED_FORMATS = ["jpeg", "png", "pdf", "tiff"]
19-
def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_name=None, textract_client=None, content=None, cache_ttl=300):
14+
15+
def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_name=None,
16+
textract_client=None, content=None, cache_ttl=300):
2017
super().__init__(content, cache_ttl)
2118
if textract_client:
2219
self.textract_client = textract_client
@@ -33,39 +30,61 @@ def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_na
3330
@classmethod
3431
def from_client(cls, textract_client, content=None, cache_ttl=300):
3532
return cls(textract_client=textract_client, content=content, cache_ttl=cache_ttl)
33+
34+
def load(self, source: Union[str, BytesIO]) -> List[Dict[str, Any]]:
35+
"""
36+
Load and analyze a document using AWS Textract.
37+
Returns a list of pages, each containing:
38+
- content: The text content of the page
39+
- tables: Any tables found on the page
40+
- image: The page image (if vision_mode is True)
3641
37-
@cachedmethod(cache=attrgetter('cache'), key=lambda self, stream: hashkey(id(stream)))
38-
def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[dict, object]:
39-
try:
40-
if is_pdf_stream(stream):
41-
file_bytes = stream.getvalue() if isinstance(stream, BytesIO) else stream
42-
return self.process_pdf(file_bytes)
43-
elif get_image_type(stream) in self.SUPPORTED_FORMATS:
44-
file_bytes = stream.getvalue() if isinstance(stream, BytesIO) else stream
45-
return self.process_image(file_bytes)
46-
else:
47-
raise Exception(f"Unsupported stream type: {get_file_extension(stream) if isinstance(stream, str) else 'unknown'}")
48-
except Exception as e:
49-
raise Exception(f"Error processing stream: {e}") from e
42+
Args:
43+
source: Either a file path or BytesIO stream
44+
45+
Returns:
46+
List[Dict[str, Any]]: List of pages with content and optional images
47+
"""
48+
if not self.can_handle(source):
49+
raise ValueError(f"Cannot handle source: {source}")
5050

51-
@cachedmethod(cache=attrgetter('cache'), key=lambda self, file_path: hashkey(file_path))
52-
def load_content_from_file(self, file_path: str) -> Union[dict, object]:
5351
try:
54-
file_type = get_file_extension(file_path)
55-
if file_type == 'pdf':
56-
with open(file_path, 'rb') as file:
57-
file_bytes = file.read()
58-
return self.process_pdf(file_bytes)
59-
elif file_type in self.SUPPORTED_FORMATS:
60-
with open(file_path, 'rb') as file:
52+
# Get the file bytes based on source type
53+
if isinstance(source, str):
54+
with open(source, 'rb') as file:
6155
file_bytes = file.read()
62-
return self.process_image(file_bytes)
6356
else:
64-
raise Exception(f"Unsupported file type: {file_path}")
57+
file_bytes = source.getvalue()
58+
59+
# Process with Textract based on file type
60+
if is_pdf_stream(source) or (isinstance(source, str) and source.lower().endswith('.pdf')):
61+
result = self.process_pdf(file_bytes)
62+
else:
63+
result = self.process_image(file_bytes)
64+
65+
# Convert to our standard page-based format
66+
pages = []
67+
for page_num, page_data in enumerate(result.get("pages", [])):
68+
page_dict = {
69+
"content": "\n".join(page_data.get("lines", [])),
70+
"tables": result.get("tables", []) # For now, attach all tables to each page
71+
}
72+
73+
# If vision mode is enabled, add page image
74+
if self.vision_mode:
75+
images_dict = self.convert_to_images(source)
76+
if page_num in images_dict:
77+
page_dict["image"] = images_dict[page_num]
78+
79+
pages.append(page_dict)
80+
81+
return pages
82+
6583
except Exception as e:
66-
raise Exception(f"Error processing file: {e}") from e
84+
raise ValueError(f"Error processing document: {str(e)}")
6785

6886
def process_pdf(self, pdf_bytes: bytes) -> dict:
87+
"""Process a PDF document with Textract."""
6988
for attempt in range(3):
7089
try:
7190
response = self.textract_client.analyze_document(
@@ -75,58 +94,53 @@ def process_pdf(self, pdf_bytes: bytes) -> dict:
7594
return self._parse_analyze_document_response(response)
7695
except Exception as e:
7796
if attempt == 2:
78-
raise Exception(f"Failed to process PDF after 3 attempts: {e}")
97+
raise ValueError(f"Failed to process PDF after 3 attempts: {e}")
7998
return {}
8099

81100
def process_image(self, image_bytes: bytes) -> dict:
101+
"""Process an image with Textract."""
82102
for attempt in range(3):
83103
try:
84104
response = self.textract_client.analyze_document(
85105
Document={'Bytes': image_bytes},
86-
FeatureTypes=['TABLES'] # Only extract tables
106+
FeatureTypes=['TABLES']
87107
)
88108
return self._parse_analyze_document_response(response)
89109
except Exception as e:
90110
if attempt == 2:
91-
raise Exception(f"Failed to process image after 3 attempts: {e}")
111+
raise ValueError(f"Failed to process image after 3 attempts: {e}")
92112
return {}
93113

94114
def _parse_analyze_document_response(self, response: dict) -> dict:
115+
"""Parse Textract response into our format."""
95116
result = {
96117
"pages": [],
97-
"tables": [],
98-
"forms": [],
99-
"layout": {}
118+
"tables": []
100119
}
101120

102-
current_page = {"paragraphs": [], "lines": [], "words": []}
121+
current_page = {"lines": [], "words": []}
103122

104123
for block in response['Blocks']:
105124
if block['BlockType'] == 'PAGE':
106-
if current_page["paragraphs"] or current_page["lines"] or current_page["words"]:
125+
if current_page["lines"]:
107126
result["pages"].append(current_page)
108-
current_page = {"paragraphs": [], "lines": [], "words": []}
127+
current_page = {"lines": [], "words": []}
109128
elif block['BlockType'] == 'LINE':
110129
current_page["lines"].append(block['Text'])
111130
elif block['BlockType'] == 'WORD':
112131
current_page["words"].append(block['Text'])
113132
elif block['BlockType'] == 'TABLE':
114133
result["tables"].append(self._parse_table(block, response['Blocks']))
115-
elif block['BlockType'] == 'KEY_VALUE_SET':
116-
if 'KEY' in block['EntityTypes']:
117-
key = block['Text']
118-
value = self._find_value_for_key(block, response['Blocks'])
119-
result["forms"].append({"key": key, "value": value})
120-
elif block['BlockType'] in ['CELL', 'SELECTION_ELEMENT']:
121-
self._add_to_layout(result["layout"], block)
122134

123-
if current_page["paragraphs"] or current_page["lines"] or current_page["words"]:
135+
if current_page["lines"]:
124136
result["pages"].append(current_page)
125137

126138
return result
127139

128-
def _parse_table(self, table_block, blocks):
129-
cells = [block for block in blocks if block['BlockType'] == 'CELL' and block['Id'] in table_block['Relationships'][0]['Ids']]
140+
def _parse_table(self, table_block: dict, blocks: List[dict]) -> List[List[str]]:
141+
"""Parse a table from Textract response."""
142+
cells = [block for block in blocks if block['BlockType'] == 'CELL'
143+
and block['Id'] in table_block['Relationships'][0]['Ids']]
130144
rows = max(cell['RowIndex'] for cell in cells)
131145
cols = max(cell['ColumnIndex'] for cell in cells)
132146

@@ -136,55 +150,12 @@ def _parse_table(self, table_block, blocks):
136150
row = cell['RowIndex'] - 1
137151
col = cell['ColumnIndex'] - 1
138152
if 'Relationships' in cell:
139-
words = [block['Text'] for block in blocks if block['Id'] in cell['Relationships'][0]['Ids']]
153+
words = [block['Text'] for block in blocks
154+
if block['Id'] in cell['Relationships'][0]['Ids']]
140155
table[row][col] = ' '.join(words)
141156

142157
return table
143158

144-
def _find_value_for_key(self, key_block, blocks):
145-
for relationship in key_block['Relationships']:
146-
if relationship['Type'] == 'VALUE':
147-
value_block = next(block for block in blocks if block['Id'] == relationship['Ids'][0])
148-
if 'Relationships' in value_block:
149-
words = [block['Text'] for block in blocks if block['Id'] in value_block['Relationships'][0]['Ids']]
150-
return ' '.join(words)
151-
return ''
152-
153-
def _add_to_layout(self, layout, block):
154-
block_type = block['BlockType']
155-
if block_type not in layout:
156-
layout[block_type] = []
157-
158-
layout_item = {
159-
'id': block['Id'],
160-
'text': block.get('Text', ''),
161-
'confidence': block['Confidence'],
162-
'geometry': block['Geometry']
163-
}
164-
165-
if 'RowIndex' in block:
166-
layout_item['row_index'] = block['RowIndex']
167-
if 'ColumnIndex' in block:
168-
layout_item['column_index'] = block['ColumnIndex']
169-
if 'SelectionStatus' in block:
170-
layout_item['selection_status'] = block['SelectionStatus']
171-
172-
layout[block_type].append(layout_item)
173-
174-
def load_content_from_stream_list(self, stream: BytesIO) -> List[Any]:
175-
images = self.convert_to_images(stream)
176-
return self._process_images(images)
177-
178-
def load_content_from_file_list(self, input: List[Union[str, BytesIO]]) -> List[Any]:
179-
images = self.convert_to_images(input)
180-
return self._process_images(images)
181-
182-
async def _process_images(self, images: dict) -> List[Any]:
183-
tasks = [self.process_image(img) for img in images.values()]
184-
results = await asyncio.gather(*tasks)
185-
186-
contents = []
187-
for (image_name, image), content in zip(images.items(), results):
188-
contents.append({"image": Image.open(BytesIO(image)) if isinstance(image, bytes) else image, "content": content})
189-
190-
return contents
159+
def can_handle_vision(self, source: Union[str, BytesIO]) -> bool:
160+
"""Check if this loader can handle the source in vision mode."""
161+
return self.can_handle(source)

0 commit comments

Comments (0)