Skip to content

Commit 5d06ed3

Browse files
authored
Merge pull request #279 from enoch3712/245-add-global-models-to-test
refactor of the tests. document_loader multiple choice. Multi image fix
2 parents fd24ba3 + b436c8b commit 5d06ed3

File tree

7 files changed

+107
-100
lines changed

7 files changed

+107
-100
lines changed

extract_thinker/document_loader/document_loader_data.py

+4
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def can_handle(self, source: Any) -> bool:
5151
return True
5252
if isinstance(source, list) and all(isinstance(item, dict) for item in source):
5353
return True
54+
if isinstance(source, dict):
55+
return True
5456
return False
5557

5658
@cachedmethod(cache=attrgetter('cache'),
@@ -80,6 +82,8 @@ def load(self, source: Union[str, IO, List[Dict[str, Any]]]) -> List[Dict[str, A
8082
return self._load_from_string(source)
8183
elif hasattr(source, "read"):
8284
return self._load_from_stream(source)
85+
elif isinstance(source, dict):
86+
return source
8387

8488
except Exception as e:
8589
raise ValueError(f"Error processing content: {str(e)}")

extract_thinker/extractor.py

+54-15
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from instructor.batch import BatchJob
55
import uuid
66
from pydantic import BaseModel
7+
from extract_thinker.document_loader.document_loader_data import DocumentLoaderData
78
from extract_thinker.llm_engine import LLMEngine
89
from extract_thinker.concatenation_handler import ConcatenationHandler
910
from extract_thinker.document_loader.document_loader import DocumentLoader
@@ -53,6 +54,7 @@ def __init__(
5354
self.is_classify_image: bool = False
5455
self._skip_loading: bool = False
5556
self.chunk_height: int = 1500
57+
self.allow_vision: bool = False
5658

5759
def add_interceptor(
5860
self, interceptor: Union[LoaderInterceptor, LlmInterceptor]
@@ -85,7 +87,7 @@ def get_document_loader_for_file(self, source: Union[str, IO]) -> DocumentLoader
8587

8688
raise ValueError("No suitable document loader found for the input.")
8789

88-
def get_document_loader(self, source: Union[str, IO]) -> Optional[DocumentLoader]:
90+
def get_document_loader(self, source: Union[str, IO, List[Union[str, IO]]]) -> Optional[DocumentLoader]:
8991
"""
9092
Retrieve the appropriate document loader for the given source.
9193
@@ -110,6 +112,14 @@ def get_document_loader(self, source: Union[str, IO]) -> Optional[DocumentLoader
110112
for loader in self.document_loaders_by_file_type.values():
111113
if loader.can_handle(source):
112114
return loader
115+
116+
# If the source is a list or a dict (usually coming from a split), return a DocumentLoaderData
117+
if isinstance(source, List) or isinstance(source, dict):
118+
return DocumentLoaderData()
119+
120+
# Last check: if vision is allowed, fall back to DocumentLoaderLLMImage
121+
if self.allow_vision:
122+
return DocumentLoaderLLMImage()
113123

114124
return None
115125

@@ -148,6 +158,36 @@ def set_skip_loading(self, skip: bool = True) -> None:
148158
"""Internal method to control content loading behavior"""
149159
self._skip_loading = skip
150160

161+
def remove_images_from_content(self, content: Union[Dict[str, Any], List[Dict[str, Any]], str]) -> Union[Dict[str, Any], List[Dict[str, Any]], str]:
162+
"""
163+
Remove image-related keys from the content while preserving the original structure.
164+
165+
Args:
166+
content: Input content that can be a dictionary, list of dictionaries, or string
167+
168+
Returns:
169+
Content with image-related keys removed, maintaining the original type
170+
"""
171+
if isinstance(content, dict):
172+
# Build a new (shallow) dict without the image keys so the original is not modified
173+
content_copy = {
174+
k: v for k, v in content.items()
175+
if k not in ('images', 'image')
176+
}
177+
return content_copy
178+
179+
elif isinstance(content, list):
180+
# Recurse into dict and list items; leave any other item unchanged
181+
return [
182+
self.remove_images_from_content(item)
183+
if isinstance(item, (dict, list))
184+
else item
185+
for item in content
186+
]
187+
188+
# Return strings or other types unchanged
189+
return content
190+
151191
def extract(
152192
self,
153193
source: Union[str, IO, List[Union[str, IO]]],
@@ -176,12 +216,16 @@ def extract(
176216
self._validate_dependencies(response_model, vision)
177217
self.extra_content = content
178218
self.completion_strategy = completion_strategy
219+
self.allow_vision = vision
179220

180221
if vision:
181222
try:
182223
self._handle_vision_mode(source)
183224
except ValueError as e:
184225
raise InvalidVisionDocumentLoaderError(str(e))
226+
else:
227+
if isinstance(source, List):
228+
source = self.remove_images_from_content(source)
185229

186230
if completion_strategy is not CompletionStrategy.FORBIDDEN:
187231
return self.extract_with_strategy(source, response_model, vision, completion_strategy)
@@ -1149,23 +1193,18 @@ def _add_images_to_message_content(
11491193
content: Union[Dict[str, Any], List[Any]],
11501194
message_content: List[Dict[str, Any]],
11511195
) -> None:
1152-
"""
1153-
Add images to the message content.
1154-
Handles both legacy format and new page-based format from document loaders.
1155-
1156-
Args:
1157-
content: The content containing images.
1158-
message_content: The message content to append images to.
1159-
"""
11601196
if isinstance(content, list):
1161-
# Handle new page-based format
11621197
for page in content:
1163-
if isinstance(page, dict) and 'image' in page:
1164-
self._append_images(page['image'], message_content)
1198+
if isinstance(page, dict):
1199+
if 'image' in page:
1200+
self._append_images(page['image'], message_content)
1201+
if 'images' in page:
1202+
self._append_images(page['images'], message_content)
11651203
elif isinstance(content, dict):
1166-
# Handle legacy format
1167-
image_data = content.get('image') or content.get('images')
1168-
self._append_images(image_data[0], message_content)
1204+
if 'image' in content:
1205+
self._append_images(content['image'], message_content)
1206+
if 'images' in content:
1207+
self._append_images(content['images'], message_content)
11691208

11701209
def _append_images(
11711210
self,

extract_thinker/global_models.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
def get_lite_model():
22
"""Return the lite model for cost efficiency."""
3+
#return "vertex_ai/gemini-2.0-flash"
34
return "gpt-4o-mini"
45

5-
66
def get_big_model():
77
"""Return the big model for high performance."""
8-
return "gpt-4o"
8+
#return "vertex_ai/gemini-2.0-flash"
9+
return "gpt-4o"

extract_thinker/process.py

-3
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@
44
from extract_thinker.models.classification_response import ClassificationResponse
55
from extract_thinker.models.classification_strategy import ClassificationStrategy
66
from extract_thinker.models.completion_strategy import CompletionStrategy
7-
from extract_thinker.models.doc_groups2 import DocGroups2
87
from extract_thinker.models.splitting_strategy import SplittingStrategy
98
from extract_thinker.extractor import Extractor
109
from extract_thinker.models.classification import Classification
1110
from extract_thinker.document_loader.document_loader import DocumentLoader
1211
from extract_thinker.models.classification_tree import ClassificationTree
13-
from extract_thinker.models.classification_node import ClassificationNode
14-
from extract_thinker.models.doc_group import DocGroup
1512
from extract_thinker.splitter import Splitter
1613
from extract_thinker.models.doc_groups import (
1714
DocGroups,

tests/test_extractor.py

+14-79
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
from tests.models.handbook_contract import HandbookContract
1919
from extract_thinker.global_models import get_lite_model, get_big_model
2020
from pydantic import BaseModel, Field
21+
from extract_thinker.exceptions import ExtractThinkerError
2122

2223

2324
load_dotenv()
2425
cwd = os.getcwd()
2526

26-
def test_extract_with_pypdf_and_gpt4o_mini():
27+
def test_extract_with_pypdf_and_gpt4o_mini_vision():
2728

2829
# Arrange
29-
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
30+
test_file_path = os.path.join(cwd, "tests", "test_images", "invoice.png")
3031

3132
extractor = Extractor()
3233
extractor.load_document_loader(
@@ -35,51 +36,13 @@ def test_extract_with_pypdf_and_gpt4o_mini():
3536
extractor.load_llm(get_lite_model())
3637

3738
# Act
38-
result = extractor.extract(test_file_path, InvoiceContract)
39+
result = extractor.extract(test_file_path, InvoiceContract, vision=True)
3940

4041
# Assert
4142
assert result is not None
4243
assert result.invoice_number == "0000001"
4344
assert result.invoice_date == "2014-05-07"
4445

45-
def test_extract_with_azure_di_and_gpt4o_mini():
46-
subscription_key = os.getenv("AZURE_SUBSCRIPTION_KEY")
47-
endpoint = os.getenv("AZURE_ENDPOINT")
48-
test_file_path = os.path.join(cwd, "tests", "test_images", "invoice.png")
49-
50-
extractor = Extractor()
51-
extractor.load_document_loader(
52-
DocumentLoaderAzureForm(subscription_key, endpoint)
53-
)
54-
extractor.load_llm(get_lite_model())
55-
# Act
56-
result = extractor.extract(test_file_path, InvoiceContract)
57-
58-
# Assert
59-
assert result is not None
60-
assert result.lines[0].description == "Website Redesign"
61-
assert result.lines[0].quantity == 1
62-
assert result.lines[0].unit_price == 2500
63-
assert result.lines[0].amount == 2500
64-
65-
def test_extract_with_pypdf_and_gpt4o_mini():
66-
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
67-
68-
extractor = Extractor()
69-
document_loader = DocumentLoaderPyPdf()
70-
extractor.load_document_loader(document_loader)
71-
extractor.load_llm("gpt-4o-mini")
72-
73-
# Act
74-
result = extractor.extract(test_file_path, InvoiceContract, vision=True)
75-
76-
# Assert
77-
assert result is not None
78-
assert result.lines[0].description == "Consultation services"
79-
assert result.lines[0].quantity == 3
80-
assert result.lines[0].unit_price == 375
81-
assert result.lines[0].amount == 1125
82-
8346
def test_vision_content_pdf():
8447
# Arrange
8548
extractor = Extractor()
@@ -156,10 +119,10 @@ def test_extract_with_invalid_file_path():
156119
invalid_file_path = os.path.join(cwd, "tests", "nonexistent", "fake_file.png")
157120

158121
# Act & Assert
159-
with pytest.raises(ValueError) as exc_info:
122+
with pytest.raises(ExtractThinkerError) as exc_info:
160123
extractor.extract(invalid_file_path, InvoiceContract, vision=True)
161124

162-
assert "Failed to extract from source" in str(exc_info.value.args[0])
125+
assert "Failed to extract from source: Cannot handle source" in str(exc_info.value)
163126

164127
def test_forbidden_strategy_with_token_limit():
165128
test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "eu_tax_chart.png")
@@ -358,34 +321,6 @@ def test_llm_timeout():
358321
result = extractor.extract(test_file_path, InvoiceContract)
359322
assert result is not None
360323

361-
def test_dynamic_json_parsing():
362-
"""Test dynamic JSON parsing with local Ollama model."""
363-
# Initialize components
364-
llm = LLM(model="ollama/deepseek-r1:1.5b")
365-
llm.set_dynamic(True) # Enable dynamic JSON parsing
366-
367-
document_loader = DocumentLoaderPyPdf()
368-
extractor = Extractor(document_loader=document_loader, llm=llm)
369-
370-
# Test content that should produce JSON response
371-
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
372-
373-
# Extract with dynamic parsing
374-
try:
375-
result = extractor.extract(test_file_path, InvoiceContract)
376-
377-
# Verify the result is an InvoiceContract instance
378-
assert isinstance(result, InvoiceContract)
379-
380-
# Verify invoice fields
381-
assert result.invoice_number is not None
382-
assert result.invoice_date is not None
383-
assert result.total_amount is not None
384-
assert isinstance(result.lines, list)
385-
386-
except Exception as e:
387-
pytest.fail(f"Dynamic JSON parsing test failed: {str(e)}")
388-
389324
def test_extract_with_default_backend():
390325
"""Test extraction using default LiteLLM backend"""
391326
# Arrange
@@ -407,8 +342,6 @@ def test_extract_with_default_backend():
407342
def test_extract_with_pydanticai_backend():
408343
"""Test extraction using PydanticAI backend if available"""
409344
try:
410-
import pydantic_ai
411-
412345
# Arrange
413346
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
414347

@@ -439,13 +372,12 @@ def test_extract_from_url_docling_and_gpt4o_mini():
439372
extractor = Extractor()
440373
extractor.load_document_loader(DocumentLoaderDocling())
441374
extractor.load_llm(get_lite_model())
442-
375+
443376
# Act: Extract the document using the specified URL and the HandbookContract
444-
result = extractor.extract(url, HandbookContract)
377+
result: HandbookContract = extractor.extract(url, HandbookContract)
445378

446-
# Assert: Verify that the extracted title matches the expected value.
447-
expected_title = "BCOBS 2A.1 Restriction on marketing or providing an optional product for which a fee is payable"
448-
assert result.title == expected_title
379+
# Check handbook data
380+
assert "FCA Handbook" in result.title, f"Expected title to contain 'FCA Handbook', but got: {result.title}"
449381

450382
def test_extract_from_multiple_sources():
451383
"""
@@ -480,4 +412,7 @@ class CombinedData(BaseModel):
480412
assert result.total_amount == 1125
481413

482414
# Check handbook data
483-
assert "FCA Handbook" in result.handbook_title, f"Expected title to contain 'FCA Handbook', but got: {result.handbook_title}"
415+
assert "FCA Handbook" in result.handbook_title, f"Expected title to contain 'FCA Handbook', but got: {result.handbook_title}"
416+
417+
if __name__ == "__main__":
418+
test_extract_with_invalid_file_path()

tests/test_ollama.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os
22
from typing import Optional
33
from dotenv import load_dotenv
4+
import pytest
45
from extract_thinker import DocumentLoaderPyPdf
56
from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling, DoclingConfig
67
from extract_thinker import Extractor
78
from extract_thinker import Contract
89
from extract_thinker import Classification
910
from extract_thinker import DocumentLoaderMarkItDown
11+
from extract_thinker.llm import LLM
1012
from extract_thinker.models.completion_strategy import CompletionStrategy
1113
from extract_thinker import SplittingStrategy
1214
from extract_thinker import Process
@@ -160,4 +162,32 @@ def test_extract_with_ollama_full_pipeline():
160162

161163
# Check each extracted item
162164
for item in result:
163-
assert isinstance(item, (VehicleRegistration, DriverLicenseContract))
165+
assert isinstance(item, (VehicleRegistration, DriverLicenseContract))
166+
167+
def test_dynamic_json_parsing():
168+
"""Test dynamic JSON parsing with local Ollama model."""
169+
# Initialize components
170+
llm = LLM(model="ollama/deepseek-r1:1.5b")
171+
llm.set_dynamic(True) # Enable dynamic JSON parsing
172+
173+
document_loader = DocumentLoaderPyPdf()
174+
extractor = Extractor(document_loader=document_loader, llm=llm)
175+
176+
# Test content that should produce JSON response
177+
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
178+
179+
# Extract with dynamic parsing
180+
try:
181+
result = extractor.extract(test_file_path, InvoiceContract)
182+
183+
# Verify the result is an InvoiceContract instance
184+
assert isinstance(result, InvoiceContract)
185+
186+
# Verify invoice fields
187+
assert result.invoice_number is not None
188+
assert result.invoice_date is not None
189+
assert result.total_amount is not None
190+
assert isinstance(result.lines, list)
191+
192+
except Exception as e:
193+
pytest.fail(f"Dynamic JSON parsing test failed: {str(e)}")

tests/test_process.py

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def setup_process_and_classifications():
6565

6666
def test_eager_splitting_strategy():
6767
"""Test eager splitting strategy with a multi-page document"""
68+
6869
# Arrange
6970
process, classifications = setup_process_and_classifications()
7071
process.load_splitter(ImageSplitter(get_big_model()))

0 commit comments

Comments
 (0)