1- import asyncio
21from io import BytesIO
3- from operator import attrgetter
4- import os
5- import threading
6- from typing import Any , List , Union
7- from PIL import Image
2+ from typing import Any , Dict , List , Union
83import boto3
94import pypdfium2 as pdfium
105
11- from extract_thinker .document_loader .cached_document_loader import CachedDocumentLoader
6+ from extract_thinker .document_loader .document_loader import DocumentLoader
127from extract_thinker .utils import get_file_extension , get_image_type , is_pdf_stream
138
14- from cachetools import cachedmethod
15- from cachetools .keys import hashkey
169
17- class DocumentLoaderAWSTextract (CachedDocumentLoader ):
10+ class DocumentLoaderAWSTextract (DocumentLoader ):
11+ """Loader for documents using AWS Textract."""
12+
1813 SUPPORTED_FORMATS = ["jpeg" , "png" , "pdf" , "tiff" ]
19- def __init__ (self , aws_access_key_id = None , aws_secret_access_key = None , region_name = None , textract_client = None , content = None , cache_ttl = 300 ):
14+
15+ def __init__ (self , aws_access_key_id = None , aws_secret_access_key = None , region_name = None ,
16+ textract_client = None , content = None , cache_ttl = 300 ):
2017 super ().__init__ (content , cache_ttl )
2118 if textract_client :
2219 self .textract_client = textract_client
@@ -33,39 +30,61 @@ def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_na
3330 @classmethod
3431 def from_client (cls , textract_client , content = None , cache_ttl = 300 ):
3532 return cls (textract_client = textract_client , content = content , cache_ttl = cache_ttl )
33+
34+ def load (self , source : Union [str , BytesIO ]) -> List [Dict [str , Any ]]:
35+ """
36+ Load and analyze a document using AWS Textract.
37+ Returns a list of pages, each containing:
38+ - content: The text content of the page
39+ - tables: Any tables found on the page
40+ - image: The page image (if vision_mode is True)
3641
37- @cachedmethod (cache = attrgetter ('cache' ), key = lambda self , stream : hashkey (id (stream )))
38- def load_content_from_stream (self , stream : Union [BytesIO , str ]) -> Union [dict , object ]:
39- try :
40- if is_pdf_stream (stream ):
41- file_bytes = stream .getvalue () if isinstance (stream , BytesIO ) else stream
42- return self .process_pdf (file_bytes )
43- elif get_image_type (stream ) in self .SUPPORTED_FORMATS :
44- file_bytes = stream .getvalue () if isinstance (stream , BytesIO ) else stream
45- return self .process_image (file_bytes )
46- else :
47- raise Exception (f"Unsupported stream type: { get_file_extension (stream ) if isinstance (stream , str ) else 'unknown' } " )
48- except Exception as e :
49- raise Exception (f"Error processing stream: { e } " ) from e
42+ Args:
43+ source: Either a file path or BytesIO stream
44+
45+ Returns:
46+ List[Dict[str, Any]]: List of pages with content and optional images
47+ """
48+ if not self .can_handle (source ):
49+ raise ValueError (f"Cannot handle source: { source } " )
5050
51- @cachedmethod (cache = attrgetter ('cache' ), key = lambda self , file_path : hashkey (file_path ))
52- def load_content_from_file (self , file_path : str ) -> Union [dict , object ]:
5351 try :
54- file_type = get_file_extension (file_path )
55- if file_type == 'pdf' :
56- with open (file_path , 'rb' ) as file :
57- file_bytes = file .read ()
58- return self .process_pdf (file_bytes )
59- elif file_type in self .SUPPORTED_FORMATS :
60- with open (file_path , 'rb' ) as file :
52+ # Get the file bytes based on source type
53+ if isinstance (source , str ):
54+ with open (source , 'rb' ) as file :
6155 file_bytes = file .read ()
62- return self .process_image (file_bytes )
6356 else :
64- raise Exception (f"Unsupported file type: { file_path } " )
57+ file_bytes = source .getvalue ()
58+
59+ # Process with Textract based on file type
60+ if is_pdf_stream (source ) or (isinstance (source , str ) and source .lower ().endswith ('.pdf' )):
61+ result = self .process_pdf (file_bytes )
62+ else :
63+ result = self .process_image (file_bytes )
64+
65+ # Convert to our standard page-based format
66+ pages = []
67+ for page_num , page_data in enumerate (result .get ("pages" , [])):
68+ page_dict = {
69+ "content" : "\n " .join (page_data .get ("lines" , [])),
70+ "tables" : result .get ("tables" , []) # For now, attach all tables to each page
71+ }
72+
73+ # If vision mode is enabled, add page image
74+ if self .vision_mode :
75+ images_dict = self .convert_to_images (source )
76+ if page_num in images_dict :
77+ page_dict ["image" ] = images_dict [page_num ]
78+
79+ pages .append (page_dict )
80+
81+ return pages
82+
6583 except Exception as e :
66- raise Exception (f"Error processing file : { e } " ) from e
84+ raise ValueError (f"Error processing document : { str ( e ) } " )
6785
6886 def process_pdf (self , pdf_bytes : bytes ) -> dict :
87+ """Process a PDF document with Textract."""
6988 for attempt in range (3 ):
7089 try :
7190 response = self .textract_client .analyze_document (
@@ -75,58 +94,53 @@ def process_pdf(self, pdf_bytes: bytes) -> dict:
7594 return self ._parse_analyze_document_response (response )
7695 except Exception as e :
7796 if attempt == 2 :
78- raise Exception (f"Failed to process PDF after 3 attempts: { e } " )
97+ raise ValueError (f"Failed to process PDF after 3 attempts: { e } " )
7998 return {}
8099
81100 def process_image (self , image_bytes : bytes ) -> dict :
101+ """Process an image with Textract."""
82102 for attempt in range (3 ):
83103 try :
84104 response = self .textract_client .analyze_document (
85105 Document = {'Bytes' : image_bytes },
86- FeatureTypes = ['TABLES' ] # Only extract tables
106+ FeatureTypes = ['TABLES' ]
87107 )
88108 return self ._parse_analyze_document_response (response )
89109 except Exception as e :
90110 if attempt == 2 :
91- raise Exception (f"Failed to process image after 3 attempts: { e } " )
111+ raise ValueError (f"Failed to process image after 3 attempts: { e } " )
92112 return {}
93113
94114 def _parse_analyze_document_response (self , response : dict ) -> dict :
115+ """Parse Textract response into our format."""
95116 result = {
96117 "pages" : [],
97- "tables" : [],
98- "forms" : [],
99- "layout" : {}
118+ "tables" : []
100119 }
101120
102- current_page = {"paragraphs" : [], " lines" : [], "words" : []}
121+ current_page = {"lines" : [], "words" : []}
103122
104123 for block in response ['Blocks' ]:
105124 if block ['BlockType' ] == 'PAGE' :
106- if current_page ["paragraphs" ] or current_page [ " lines" ] or current_page [ "words " ]:
125+ if current_page ["lines" ]:
107126 result ["pages" ].append (current_page )
108- current_page = {"paragraphs" : [], " lines" : [], "words" : []}
127+ current_page = {"lines" : [], "words" : []}
109128 elif block ['BlockType' ] == 'LINE' :
110129 current_page ["lines" ].append (block ['Text' ])
111130 elif block ['BlockType' ] == 'WORD' :
112131 current_page ["words" ].append (block ['Text' ])
113132 elif block ['BlockType' ] == 'TABLE' :
114133 result ["tables" ].append (self ._parse_table (block , response ['Blocks' ]))
115- elif block ['BlockType' ] == 'KEY_VALUE_SET' :
116- if 'KEY' in block ['EntityTypes' ]:
117- key = block ['Text' ]
118- value = self ._find_value_for_key (block , response ['Blocks' ])
119- result ["forms" ].append ({"key" : key , "value" : value })
120- elif block ['BlockType' ] in ['CELL' , 'SELECTION_ELEMENT' ]:
121- self ._add_to_layout (result ["layout" ], block )
122134
123- if current_page ["paragraphs" ] or current_page [ " lines" ] or current_page [ "words " ]:
135+ if current_page ["lines" ]:
124136 result ["pages" ].append (current_page )
125137
126138 return result
127139
128- def _parse_table (self , table_block , blocks ):
129- cells = [block for block in blocks if block ['BlockType' ] == 'CELL' and block ['Id' ] in table_block ['Relationships' ][0 ]['Ids' ]]
140+ def _parse_table (self , table_block : dict , blocks : List [dict ]) -> List [List [str ]]:
141+ """Parse a table from Textract response."""
142+ cells = [block for block in blocks if block ['BlockType' ] == 'CELL'
143+ and block ['Id' ] in table_block ['Relationships' ][0 ]['Ids' ]]
130144 rows = max (cell ['RowIndex' ] for cell in cells )
131145 cols = max (cell ['ColumnIndex' ] for cell in cells )
132146
@@ -136,55 +150,12 @@ def _parse_table(self, table_block, blocks):
136150 row = cell ['RowIndex' ] - 1
137151 col = cell ['ColumnIndex' ] - 1
138152 if 'Relationships' in cell :
139- words = [block ['Text' ] for block in blocks if block ['Id' ] in cell ['Relationships' ][0 ]['Ids' ]]
153+ words = [block ['Text' ] for block in blocks
154+ if block ['Id' ] in cell ['Relationships' ][0 ]['Ids' ]]
140155 table [row ][col ] = ' ' .join (words )
141156
142157 return table
143158
144- def _find_value_for_key (self , key_block , blocks ):
145- for relationship in key_block ['Relationships' ]:
146- if relationship ['Type' ] == 'VALUE' :
147- value_block = next (block for block in blocks if block ['Id' ] == relationship ['Ids' ][0 ])
148- if 'Relationships' in value_block :
149- words = [block ['Text' ] for block in blocks if block ['Id' ] in value_block ['Relationships' ][0 ]['Ids' ]]
150- return ' ' .join (words )
151- return ''
152-
153- def _add_to_layout (self , layout , block ):
154- block_type = block ['BlockType' ]
155- if block_type not in layout :
156- layout [block_type ] = []
157-
158- layout_item = {
159- 'id' : block ['Id' ],
160- 'text' : block .get ('Text' , '' ),
161- 'confidence' : block ['Confidence' ],
162- 'geometry' : block ['Geometry' ]
163- }
164-
165- if 'RowIndex' in block :
166- layout_item ['row_index' ] = block ['RowIndex' ]
167- if 'ColumnIndex' in block :
168- layout_item ['column_index' ] = block ['ColumnIndex' ]
169- if 'SelectionStatus' in block :
170- layout_item ['selection_status' ] = block ['SelectionStatus' ]
171-
172- layout [block_type ].append (layout_item )
173-
174- def load_content_from_stream_list (self , stream : BytesIO ) -> List [Any ]:
175- images = self .convert_to_images (stream )
176- return self ._process_images (images )
177-
178- def load_content_from_file_list (self , input : List [Union [str , BytesIO ]]) -> List [Any ]:
179- images = self .convert_to_images (input )
180- return self ._process_images (images )
181-
182- async def _process_images (self , images : dict ) -> List [Any ]:
183- tasks = [self .process_image (img ) for img in images .values ()]
184- results = await asyncio .gather (* tasks )
185-
186- contents = []
187- for (image_name , image ), content in zip (images .items (), results ):
188- contents .append ({"image" : Image .open (BytesIO (image )) if isinstance (image , bytes ) else image , "content" : content })
189-
190- return contents
159+ def can_handle_vision (self , source : Union [str , BytesIO ]) -> bool :
160+ """Check if this loader can handle the source in vision mode."""
161+ return self .can_handle (source )
0 commit comments