Skip to content

Commit 6dd3cd9

Browse files
authored
Merge pull request #197 from enoch3712/195-processextract-must-have-a-completion-strategy
Add strategy to Process.Extract
2 parents c7f556f + 630bfb1 commit 6dd3cd9

File tree

4 files changed: +166 additions, -14 deletions

extract_thinker/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .document_loader.document_loader_pdfplumber import DocumentLoaderPdfPlumber
1111
from .document_loader.document_loader_beautiful_soup import DocumentLoaderBeautifulSoup
1212
from .document_loader.document_loader_markitdown import DocumentLoaderMarkItDown
13+
from .document_loader.document_loader_docling import DocumentLoaderDocling
1314
from .models.classification import Classification
1415
from .models.classification_response import ClassificationResponse
1516
from .process import Process
@@ -18,6 +19,7 @@
1819
from .text_splitter import TextSplitter
1920
from .models.contract import Contract
2021
from .models.splitting_strategy import SplittingStrategy
22+
from .models.completion_strategy import CompletionStrategy
2123
from .batch_job import BatchJob
2224
from .document_loader.document_loader_txt import DocumentLoaderTxt
2325
from .document_loader.document_loader_doc2txt import DocumentLoaderDoc2txt
@@ -47,6 +49,8 @@
4749
'DocumentLoaderDocumentAI',
4850
'DocumentLoaderMarkItDown',
4951
'Classification',
52+
'CompletionStrategy',
53+
'DocumentLoaderDocling',
5054
'ClassificationResponse',
5155
'Process',
5256
'ClassificationStrategy',

extract_thinker/extractor.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,9 @@ async def extract_async(
243243
source: Union[str, IO, list],
244244
response_model: type[BaseModel],
245245
vision: bool = False,
246+
completion_strategy: Optional[CompletionStrategy] = CompletionStrategy.FORBIDDEN
246247
) -> Any:
247-
return await asyncio.to_thread(self.extract, source, response_model, vision)
248+
return await asyncio.to_thread(self.extract, source, response_model, vision, "", completion_strategy)
248249

249250
def extract_with_strategy(
250251
self,
@@ -265,13 +266,17 @@ def extract_with_strategy(
265266
Returns:
266267
Parsed response matching response_model
267268
"""
268-
# Get appropriate document loader
269-
document_loader = self.get_document_loader(source)
270-
if document_loader is None:
271-
raise ValueError("No suitable document loader found for the input.")
269+
# If source is already a list, use it directly
270+
if isinstance(source, list):
271+
content = source
272+
else:
273+
# Get appropriate document loader
274+
document_loader = self.get_document_loader(source)
275+
if document_loader is None:
276+
raise ValueError("No suitable document loader found for the input.")
272277

273-
# Load content using list method
274-
content = document_loader.load(source)
278+
# Load content using list method
279+
content = document_loader.load(source)
275280

276281
# Handle based on strategy
277282
if completion_strategy == CompletionStrategy.PAGINATE:

extract_thinker/process.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from extract_thinker.image_splitter import ImageSplitter
44
from extract_thinker.models.classification_response import ClassificationResponse
55
from extract_thinker.models.classification_strategy import ClassificationStrategy
6+
from extract_thinker.models.completion_strategy import CompletionStrategy
67
from extract_thinker.models.doc_groups2 import DocGroups2
78
from extract_thinker.models.splitting_strategy import SplittingStrategy
89
from extract_thinker.extractor import Extractor
@@ -232,15 +233,17 @@ def split(self, classifications: List[Classification], strategy: SplittingStrate
232233

233234
return self
234235

235-
def extract(self, vision: bool = False) -> List[Any]:
236+
def extract(self,
237+
vision: bool = False,
238+
completion_strategy: Optional[CompletionStrategy] = CompletionStrategy.FORBIDDEN) -> List[Any]:
236239
"""Extract information from the document groups."""
237240
if self.doc_groups is None:
238241
raise ValueError("Document groups have not been initialized")
239242

240243
async def _extract(doc_group):
241244
# Find matching classification and extractor
242245
classificationStr = doc_group.classification
243-
extractor = None
246+
extractor: Optional[Extractor] = None
244247
contract = None
245248

246249
for classification in self.split_classifications:
@@ -271,7 +274,12 @@ async def _extract(doc_group):
271274
# Set flag to skip loading since content is already processed
272275
extractor.set_skip_loading(True)
273276
try:
274-
result = await extractor.extract_async(group_pages, contract, vision=vision)
277+
result = await extractor.extract_async(
278+
group_pages,
279+
contract,
280+
vision,
281+
completion_strategy
282+
)
275283
finally:
276284
# Reset flag after extraction
277285
extractor.set_skip_loading(False)

tests/test_ollama.py

Lines changed: 139 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,62 @@
11
import os
2+
from typing import Optional
23
from dotenv import load_dotenv
3-
from extract_thinker.document_loader.document_loader_pypdf import DocumentLoaderPyPdf
4-
from extract_thinker.extractor import Extractor
5-
from tests.models.invoice import InvoiceContract
4+
from extract_thinker import DocumentLoaderPyPdf
5+
from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling, DoclingConfig
6+
from extract_thinker import Extractor
7+
from extract_thinker import Contract
8+
from extract_thinker import Classification
9+
from extract_thinker import DocumentLoaderMarkItDown
10+
from extract_thinker.models.completion_strategy import CompletionStrategy
11+
from extract_thinker import SplittingStrategy
12+
from extract_thinker import Process
13+
from extract_thinker import TextSplitter
14+
from extract_thinker import ImageSplitter
15+
from pydantic import Field
16+
17+
from docling.datamodel.pipeline_options import (
18+
PdfPipelineOptions,
19+
TesseractCliOcrOptions,
20+
TableStructureOptions,
21+
)
22+
from docling.datamodel.base_models import InputFormat
23+
from docling.document_converter import PdfFormatOption
624

725
load_dotenv()
826
cwd = os.getcwd()
927

28+
# Define the contracts as shown in the article
29+
class InvoiceContract(Contract):
30+
invoice_number: str = Field(description="Unique invoice identifier")
31+
invoice_date: str = Field(description="Date of the invoice")
32+
total_amount: float = Field(description="Overall total amount")
33+
34+
class VehicleRegistration(Contract):
35+
name_primary: Optional[str] = Field(
36+
default=None,
37+
description="Primary registrant's name (Last, First, Middle)"
38+
)
39+
name_secondary: Optional[str] = Field(
40+
default=None,
41+
description="Co-registrant's name if applicable"
42+
)
43+
address: Optional[str] = Field(
44+
default=None,
45+
description="Primary registrant's mailing address including street, city, state and zip code"
46+
)
47+
vehicle_type: Optional[str] = Field(
48+
default=None,
49+
description="Type of vehicle (e.g., 2-Door, 4-Door, Pick-up, Van, etc.)"
50+
)
51+
vehicle_color: Optional[str] = Field(
52+
default=None,
53+
description="Primary color of the vehicle"
54+
)
55+
56+
class DriverLicenseContract(Contract):
57+
name: Optional[str] = Field(description="Full name on the license")
58+
age: Optional[int] = Field(description="Age of the license holder")
59+
license_number: Optional[str] = Field(description="License number")
1060

1161
def test_extract_with_ollama():
1262
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
@@ -17,7 +67,7 @@ def test_extract_with_ollama():
1767
)
1868

1969
os.environ["API_BASE"] = "http://localhost:11434"
20-
extractor.load_llm("ollama/phi3.5")
70+
extractor.load_llm("ollama/phi4")
2171

2272
# Act
2373
result = extractor.extract(test_file_path, InvoiceContract)
@@ -26,3 +76,88 @@ def test_extract_with_ollama():
2676
assert result is not None
2777
assert result.invoice_number == "00012"
2878
assert result.invoice_date == "1/30/23"
79+
80+
def test_extract_with_ollama_full_pipeline():
81+
"""Test the complete document processing pipeline as described in the article"""
82+
# Setup test file path
83+
test_file_path = os.path.join(cwd, "tests", "files", "bulk.pdf")
84+
85+
# Create classifications
86+
test_classifications = [
87+
Classification(
88+
name="Vehicle Registration",
89+
description="This is a vehicle registration document",
90+
contract=VehicleRegistration
91+
),
92+
Classification(
93+
name="Driver License",
94+
description="This is a driver license document",
95+
contract=DriverLicenseContract
96+
)
97+
]
98+
99+
# Setup OCR options
100+
ocr_options = TesseractCliOcrOptions(
101+
force_full_page_ocr=True,
102+
tesseract_cmd="/opt/homebrew/bin/tesseract"
103+
)
104+
105+
# Setup pipeline options
106+
pipeline_options = PdfPipelineOptions(
107+
do_table_structure=True,
108+
do_ocr=True,
109+
ocr_options=ocr_options,
110+
table_structure_options=TableStructureOptions(
111+
do_cell_matching=True
112+
)
113+
)
114+
115+
# Create format options
116+
format_options = {
117+
InputFormat.PDF: PdfFormatOption(
118+
pipeline_options=pipeline_options
119+
)
120+
}
121+
122+
# Create docling config with OCR enabled
123+
docling_config = DoclingConfig(
124+
format_options=format_options,
125+
ocr_enabled=True,
126+
force_full_page_ocr=True
127+
)
128+
129+
# Setup extractor with OCR-enabled docling loader
130+
extractor = Extractor()
131+
extractor.load_document_loader(DocumentLoaderDocling(docling_config))
132+
133+
# Configure Ollama
134+
os.environ["API_BASE"] = "http://localhost:11434"
135+
extractor.load_llm("ollama/phi4")
136+
137+
# Attach extractor to classifications
138+
for classification in test_classifications:
139+
classification.extractor = extractor
140+
141+
# Setup process
142+
process = Process()
143+
process.load_document_loader(DocumentLoaderDocling(docling_config))
144+
process.load_splitter(ImageSplitter(model="claude-3-5-sonnet-20241022"))
145+
146+
test_classifications[0].extractor = extractor
147+
test_classifications[1].extractor = extractor
148+
149+
# Run the complete pipeline
150+
result = (
151+
process
152+
.load_file(test_file_path)
153+
.split(test_classifications, strategy=SplittingStrategy.LAZY)
154+
.extract(vision=False, completion_strategy=CompletionStrategy.PAGINATE)
155+
)
156+
157+
# Assert
158+
assert result is not None
159+
assert isinstance(result, list)
160+
161+
# Check each extracted item
162+
for item in result:
163+
assert isinstance(item, (VehicleRegistration, DriverLicenseContract))

0 commit comments

Comments (0)