Commit 2b797c7

Merge pull request #249 from enoch3712/248-paginate-handler---bad-optional-list-aggregation
248 paginate handler bad optional list aggregation
2 parents c8b69eb + d8655e7 commit 2b797c7

File tree

7 files changed: +104 -61 lines changed

extract_thinker/global_models.py (new file)

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+def get_lite_model():
+    """Return the lite model for cost efficiency."""
+    return "gpt-4o-mini"
+
+
+def get_big_model():
+    """Return the big model for high performance."""
+    return "gpt-4o"

extract_thinker/pagination_handler.py

Lines changed: 30 additions & 23 deletions

@@ -17,6 +17,14 @@ class PaginationHandler(CompletionHandler):
     def __init__(self, llm):
         super().__init__(llm)
 
+    def _make_hashable(self, item: Any) -> Any:
+        """Recursively convert a value to something hashable."""
+        if isinstance(item, dict):
+            return tuple(sorted((k, self._make_hashable(v)) for k, v in item.items()))
+        elif isinstance(item, list):
+            return tuple(self._make_hashable(x) for x in item)
+        return item
+
     def handle(self,
                content: List[Dict[str, Any]],
                response_model: type[BaseModel],

@@ -81,57 +89,56 @@ def _merge_results(self, results: List[Any], response_model: type[BaseModel], pa
         for _, result in pages_data:
             result_dict = result.model_dump()
             for field_name, field_value in result_dict.items():
-                if field_name not in field_values:
-                    field_values[field_name] = []
-                field_values[field_name].append(field_value)
+                field_values.setdefault(field_name, []).append(field_value)
 
         # Merge fields
         merged = {}
         for field_name, values in field_values.items():
+            # Get the annotated type from the response model
             field_type = response_model.model_fields[field_name].annotation if field_name in response_model.model_fields else None
             non_null_values = [v for v in values if v is not None]
 
             if field_type and get_origin(field_type) is list:
-                # Merge lists using a more sophisticated approach
+                # Merge list fields using a more sophisticated approach
                 merged_list = self._merge_list_field(field_name, values, field_type)
                 merged[field_name] = merged_list
             else:
                 # Scalar field handling
                 if len(non_null_values) == 0:
-                    merged[field_name] = None
+                    # If the field is expected to be a string, default to an empty string.
+                    if field_type == str or (get_origin(field_type) is Union and str in get_args(field_type)):
+                        merged[field_name] = ""
+                    else:
+                        continue
                 else:
-                    # Convert unhashable types (e.g., lists) to hashable types
-                    hashable_values = [tuple(item) if isinstance(item, list) else item for item in non_null_values]
+                    # Build a mapping from the hashable version of each candidate to the original candidate.
+                    distinct_map = {}
+                    for candidate in non_null_values:
+                        key = self._make_hashable(candidate)
+                        if key not in distinct_map:
+                            distinct_map[key] = candidate
+                    distinct_values = list(distinct_map.values())
 
-                    try:
-                        distinct_values = list(set(hashable_values))
-                    except TypeError:
-                        # Fallback to order-preserving method if conversion fails
-                        seen = set()
-                        distinct_values = []
-                        for item in hashable_values:
-                            if item not in seen:
-                                seen.add(item)
-                                distinct_values.append(item)
-                    # **Modification Ends Here**
-
                     if len(distinct_values) == 1:
                         merged[field_name] = distinct_values[0]
                     else:
-                        # Store conflicts in special structure
+                        # Store conflicts in a special structure
                         merged[field_name] = {
                             "_conflict": True,
                             "candidates": distinct_values
                         }
 
         # Check for conflicts and resolve if necessary
         if self._has_conflicts(merged, response_model):
             merged = self._resolve_conflicts(merged, response_model, pages_data, field_values)
 
         # Clean merged dictionary to ensure it's compatible with the response model
         merged = self._clean_merged_dict(merged, response_model)
-
-        # Now that conflicts are resolved and cleaned, instantiate the response model
+
+        # Filter out any keys with a None value,
+        # now every required field (e.g., a string like "thinking") will be non-null.
+        merged = {k: v for k, v in merged.items() if v is not None}
+
         return response_model(**merged)
 
     def _merge_list_field(self, field_name: str, values: List[Any], field_type: Any) -> List[Any]:
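
The core of the fix is the _make_hashable helper: per-page candidates are deduplicated by their recursively tuple-ized form, while the original dict/list values are kept for the merged output. A standalone sketch of that dedup step (data values are made up for illustration):

    from typing import Any

    def make_hashable(item: Any) -> Any:
        # Recursively turn dicts and lists into nested tuples so they can be dict keys.
        if isinstance(item, dict):
            return tuple(sorted((k, make_hashable(v)) for k, v in item.items()))
        elif isinstance(item, list):
            return tuple(make_hashable(x) for x in item)
        return item

    values = [
        {"country": "Portugal", "gdp": 218},
        {"country": "Portugal", "gdp": 218},  # duplicate reported by a second page
        {"country": "Portugal", "gdp": 220},  # a genuine conflict
    ]

    distinct_map = {}
    for candidate in values:
        distinct_map.setdefault(make_hashable(candidate), candidate)

    print(list(distinct_map.values()))
    # [{'country': 'Portugal', 'gdp': 218}, {'country': 'Portugal', 'gdp': 220}]
    # More than one distinct value -> _merge_results stores
    # {"_conflict": True, "candidates": [...]} for conflict resolution.

The old code only converted top-level lists to tuples, so a tuple of dicts (the usual shape of page results) still raised TypeError inside set(), and the "fallback" raised the same TypeError from seen.add(item), since set membership also requires hashable items. Hashing recursively removes that failure mode and preserves insertion order.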

tests/models/gdp_contract.py

Lines changed: 8 additions & 2 deletions

@@ -35,7 +35,13 @@ class CountryData(Contract):
     )
 
 class EUData(Contract):
-    thinking: str = Field(None, description="Think step by step. You have 2 pages dont forget to add them.")
+    thinking: str = Field(None, description="Think step by step. You have 2 pages dont forget to add them. Cannot be NULL or empty.")
     eu_total_gdp_million_27: float = Field(None, description="EU27 Total GDP (€ million)")
     eu_total_gdp_million_28: float = Field(None, description="EU28 Total GDP (€ million)")
-    countries: List[CountryData] = Field(None, description="List of countries. Make sure you add all countries of every page, not just the first one.")
+    countries: List[CountryData] = Field(None, description="List of countries. Make sure you add all countries of every page, not just the first one.")
+
+class EUDataOptional(Contract):
+    #thinking: str = Field(None, description="Think step by step. You have 2 pages dont forget to add them.")
+    eu_total_gdp_million_27: float = Field(None, description="EU27 Total GDP (€ million)")
+    eu_total_gdp_million_28: float = Field(None, description="EU28 Total GDP (€ million)")
+    countries: Optional[List[CountryData]] = Field(None, description="List of countries. Make sure you add all countries of every page, not just the first one.")
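
EUDataOptional reproduces issue #248: the thinking field is dropped and countries is Optional, so a page may legitimately contribute nothing for a field. Under the old merge such fields came back as an explicit None, which Pydantic validates and rejects for non-Optional annotations; the new None-filter omits them so the model's defaults apply. A toy sketch of that interaction, with plain Pydantic v2 models standing in for Contract:

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class Country(BaseModel):
        country: str

    class Report(BaseModel):
        total: float = Field(None)
        countries: Optional[List[Country]] = Field(None)

    merged = {"total": None, "countries": [{"country": "Portugal"}]}

    # Report(**merged) would fail: an explicit None is validated against float.
    merged = {k: v for k, v in merged.items() if v is not None}  # the new filter
    report = Report(**merged)  # missing keys fall back to their defaults
    print(report.total, len(report.countries))  # None 1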

tests/test_classify.py

Lines changed: 10 additions & 9 deletions

@@ -12,6 +12,7 @@
 from extract_thinker.models.classification_response import ClassificationResponse
 from tests.models.invoice import CreditNoteContract, FinancialContract, InvoiceContract
 from tests.models.driver_license import DriverLicense, IdentificationContract
+from extract_thinker.global_models import get_lite_model, get_big_model
 
 # Setup environment and common paths
 load_dotenv()

@@ -32,9 +33,9 @@ def setup_extractors():
     document_loader = DocumentLoaderTesseract(tesseract_path)
 
     extractors = [
-        ("gpt-3.5-turbo", "gpt-3.5-turbo"),
-        ("claude-3-haiku-20240307", "claude-3-haiku-20240307"),
-        ("gpt-4o", "gpt-4o")
+        (get_lite_model(), get_lite_model()),
+        (get_big_model(), get_big_model()),
+        (get_big_model(), get_big_model())
     ]
 
     configured_extractors = []

@@ -78,9 +79,9 @@ def setup_process_with_gpt4_extractor():
     print(f"Tesseract path: {tesseract_path}")
     document_loader = DocumentLoaderTesseract(tesseract_path)
 
-    # Initialize the GPT-4 extractor
+    # Initialize the GPT-4 extractor using the big model
     gpt_4_extractor = Extractor(document_loader)
-    gpt_4_extractor.load_llm("gpt-4o")
+    gpt_4_extractor.load_llm(get_big_model())
 
     # Create the process with only the GPT-4 extractor
     process = Process()

@@ -298,16 +299,16 @@ def test_mom_classification_layers():
     # Initialize extractors with different models
     # Layer 1: Small models that might disagree
     gpt35_extractor = Extractor(document_loader)
-    gpt35_extractor.load_llm("claude-3-5-haiku-20241022")
+    gpt35_extractor.load_llm(get_big_model())
 
     claude_haiku_extractor = Extractor(document_loader)
-    claude_haiku_extractor.load_llm("gpt-4o-mini")
+    claude_haiku_extractor.load_llm(get_lite_model())
 
     # Layer 2: More capable models for resolution
     gpt4_extractor = Extractor(document_loader)
-    gpt4_extractor.load_llm("gpt-4o")
+    gpt4_extractor.load_llm(get_big_model())
     sonnet_extractor = Extractor(document_loader)
-    sonnet_extractor.load_llm("claude-3-5-sonnet-20241022")
+    sonnet_extractor.load_llm(get_big_model())
 
     # Create process with multiple layers
     process = Process()

tests/test_document_loader_pypdf.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 import os
 import pytest
 from extract_thinker.document_loader.document_loader_pypdf import DocumentLoaderPyPdf, PyPDFConfig
-from .test_document_loader_base import BaseDocumentLoaderTest
+from tests.test_document_loader_base import BaseDocumentLoaderTest
 
 class TestDocumentLoaderPyPdf(BaseDocumentLoaderTest):
     @pytest.fixture

tests/test_extractor.py

Lines changed: 39 additions & 19 deletions

@@ -11,13 +11,14 @@
 from tests.models.invoice import InvoiceContract
 from tests.models.ChartWithContent import ChartWithContent
 from tests.models.page_contract import ReportContract
-from tests.models.gdp_contract import EUData
+from tests.models.gdp_contract import EUData, EUDataOptional
 from extract_thinker.document_loader.document_loader_azure_document_intelligence import DocumentLoaderAzureForm
 import pytest
 import numpy as np
 from litellm import embedding
-from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling
+from extract_thinker.document_loader.document_loader_docling import DoclingConfig, DocumentLoaderDocling
 from tests.models.handbook_contract import HandbookContract
+from extract_thinker.global_models import get_lite_model, get_big_model
 
 
 load_dotenv()

@@ -32,7 +33,7 @@ def test_extract_with_pypdf_and_gpt4o_mini():
     extractor.load_document_loader(
         DocumentLoaderPyPdf()
     )
-    extractor.load_llm("gpt-4o-mini")
+    extractor.load_llm(get_lite_model())
 
     # Act
     result = extractor.extract(test_file_path, InvoiceContract)

@@ -51,7 +52,7 @@ def test_extract_with_azure_di_and_gpt4o_mini():
     extractor.load_document_loader(
         DocumentLoaderAzureForm(subscription_key, endpoint)
     )
-    extractor.load_llm("gpt-4o-mini")
+    extractor.load_llm(get_lite_model())
     # Act
     result = extractor.extract(test_file_path, InvoiceContract)

@@ -71,7 +72,7 @@ def test_extract_with_pypdf_and_gpt4o_mini():
     extractor.load_llm("gpt-4o-mini")
 
     # Act
-    result = extractor.extract(test_file_path, InvoiceContract)
+    result = extractor.extract(test_file_path, InvoiceContract, vision=True)
 
     # Assert
     assert result is not None

@@ -83,7 +84,7 @@ def test_extract_with_pypdf_and_gpt4o_mini():
 def test_vision_content_pdf():
     # Arrange
     extractor = Extractor()
-    extractor.load_llm("gpt-4o-mini")
+    extractor.load_llm(get_lite_model())
     test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
 
     # Act

@@ -108,7 +109,7 @@ def test_vision_content_pdf():
 def test_chart_with_content():
     # Arrange
     extractor = Extractor()
-    extractor.load_llm("gpt-4o-mini")
+    extractor.load_llm(get_lite_model())
     test_file_path = os.path.join(cwd, "tests", "test_images", "eu_tax_chart.png")
 
     # Act

@@ -131,7 +132,7 @@ def test_extract_with_loader_and_vision():
     extractor = Extractor()
     loader = DocumentLoaderPyPdf()
     extractor.load_document_loader(loader)
-    extractor.load_llm("gpt-4o-mini")
+    extractor.load_llm(get_lite_model())
 
     # Act
     result = extractor.extract(test_file_path, InvoiceContract, vision=True)

@@ -152,7 +153,7 @@ def test_extract_with_loader_and_vision():
 def test_extract_with_invalid_file_path():
     # Arrange
     extractor = Extractor()
-    extractor.load_llm("gpt-4o-mini")
+    extractor.load_llm(get_lite_model())
     invalid_file_path = os.path.join(cwd, "tests", "nonexistent", "fake_file.png")
 
     # Act & Assert

@@ -165,7 +166,7 @@ def test_forbidden_strategy_with_token_limit():
     test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "eu_tax_chart.png")
     tesseract_path = os.getenv("TESSERACT_PATH")
 
-    llm = LLM("gpt-4o-mini", token_limit=10)
+    llm = LLM(get_lite_model(), token_limit=10)
 
     extractor = Extractor()
     extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))

@@ -194,7 +195,7 @@ def test_pagination_handler():
 
     extractor = Extractor()
     extractor.load_document_loader(DocumentLoaderDocling())
-    extractor.load_llm("gpt-4o")
+    extractor.load_llm(get_big_model())
 
     # Create and run both extractions in parallel
     async def run_parallel_extractions():

@@ -204,8 +205,8 @@ async def run_parallel_extractions():
         )
         return result_1, result_2
 
-    # Run the async code
-    results: tuple[EUData, EUData] = asyncio.run(run_parallel_extractions())
+    # Run the async extraction and get the results as instances of OptionalEUData
+    results = asyncio.run(run_parallel_extractions())
     result_1, result_2 = results
 
     # Compare top-level EU data

@@ -252,6 +253,25 @@ async def run_parallel_extractions():
     # assert province1.share_in_eu27_gdp == matching_province.share_in_eu27_gdp
     # assert province1.gdp_per_capita == matching_province.gdp_per_capita
 
+def test_pagination_handler_optional():
+    test_file_path = os.path.join(os.getcwd(), "tests", "files", "Regional_GDP_per_capita_2018_2.pdf")
+
+    extractor = Extractor()
+    extractor.load_document_loader(DocumentLoaderDocling())
+    extractor.load_llm(get_big_model())
+
+    async def extract_async_optional(extractor, file_path, vision, completion_strategy):
+        return extractor.extract(
+            file_path,
+            EUDataOptional,
+            vision=vision,
+            completion_strategy=completion_strategy
+        )
+
+    result = asyncio.run(extract_async_optional(extractor, test_file_path, vision=True, completion_strategy=CompletionStrategy.PAGINATE))
+
+    assert len(result.countries) == 6
+
 def get_embedding(text, model="text-embedding-ada-002"):
     text = text.replace("\n", " ")
     response = embedding(

@@ -284,7 +304,7 @@ def test_concatenation_handler():
     tesseract_path = os.getenv("TESSERACT_PATH")
     extractor = Extractor()
     extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
-    llm_first = LLM("gpt-4o", token_limit=500)
+    llm_first = LLM(get_big_model(), token_limit=500)
     extractor.load_llm(llm_first)
 
     result_1: ReportContract = extractor.extract(

@@ -296,7 +316,7 @@ def test_concatenation_handler():
 
     second_extractor = Extractor()
     second_extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
-    second_extractor.load_llm("gpt-4o")
+    second_extractor.load_llm(get_big_model())
 
     result_2: ReportContract = second_extractor.extract(
         test_file_path,

@@ -324,7 +344,7 @@ def test_llm_timeout():
     extractor.load_document_loader(DocumentLoaderPyPdf())
 
     # Create LLM with very short timeout
-    llm = LLM("gpt-4o-mini")
+    llm = LLM(get_lite_model())
     llm.set_timeout(1)  # Set timeout to 1ms (extremely short to force timeout)
     extractor.load_llm(llm)

@@ -374,7 +394,7 @@ def test_extract_with_default_backend():
 
     extractor = Extractor()
     extractor.load_document_loader(DocumentLoaderPyPdf())
-    extractor.load_llm(LLM("gpt-4o-mini", backend=LLMEngine.DEFAULT))
+    extractor.load_llm(LLM(get_lite_model(), backend=LLMEngine.DEFAULT))
 
     # Act
     result = extractor.extract(test_file_path, InvoiceContract)

@@ -395,7 +415,7 @@ def test_extract_with_pydanticai_backend():
 
     extractor = Extractor()
     extractor.load_document_loader(DocumentLoaderPyPdf())
-    extractor.load_llm(LLM("openai:gpt-4o-mini", backend=LLMEngine.PYDANTIC_AI))
+    extractor.load_llm(LLM(get_lite_model(), backend=LLMEngine.PYDANTIC_AI))
 
     # Act
     result = extractor.extract(test_file_path, InvoiceContract)

@@ -419,7 +439,7 @@ def test_extract_from_url_docling_and_gpt4o_mini():
     # Initialize the extractor, load the Docling loader and the gpt-4o-mini LLM
     extractor = Extractor()
     extractor.load_document_loader(DocumentLoaderDocling())
-    extractor.load_llm("gpt-4o-mini")
+    extractor.load_llm(get_lite_model())
 
     # Act: Extract the document using the specified URL and the HandbookContract
     result = extractor.extract(url, HandbookContract)
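
test_pagination_handler_optional is the regression test for this PR: it paginates the two-page GDP report into EUDataOptional and asserts that all six countries survive the merge. The other half of the fix, the empty-string default for string-typed scalar fields, can be checked in isolation; a standalone sketch of that predicate (the function name is illustrative, not from the codebase):

    from typing import Optional, Union, get_args, get_origin

    def empty_scalar_default(field_type):
        # Mirror of the new rule in _merge_results: str-typed fields get "",
        # everything else is skipped so the contract's default applies.
        if field_type == str or (get_origin(field_type) is Union and str in get_args(field_type)):
            return ""
        return None  # caller drops the field from the merged dict

    print(repr(empty_scalar_default(str)))            # ''
    print(repr(empty_scalar_default(Optional[str])))  # '' (Optional[str] == Union[str, None])
    print(repr(empty_scalar_default(float)))          # None -> field omitted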
