
Commit fd24ba3

Merge pull request #278 from enoch3712/253-merge-action---allow-multiple-sources-to-be-added
extractor extract multiple sources added with tests
2 parents c764896 + 68ed539 commit fd24ba3
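
In practice, the change lets a single extract() call take a list of sources and treat them as one merged document. Below is a minimal sketch of the new call shape, assuming the same loader/LLM setup the new test uses; the response model here is a hypothetical placeholder, not part of this commit:

from pydantic import BaseModel
from extract_thinker.extractor import Extractor
from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling
from extract_thinker.global_models import get_big_model

class MergedSummary(BaseModel):  # hypothetical response model for illustration
    invoice_number: str
    handbook_title: str

extractor = Extractor()
extractor.load_document_loader(DocumentLoaderDocling())
extractor.load_llm(get_big_model())

# Passing a list makes the extractor load each source, map it to the
# universal format, and merge the results into one document before extraction.
result = extractor.extract(
    ["tests/files/invoice.pdf", "https://www.handbook.fca.org.uk/handbook/BCOBS/2A/?view=chapter"],
    MergedSummary,
)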

File tree: extract_thinker/extractor.py, tests/test_extractor.py

2 files changed: +101 −18 lines changed


extract_thinker/extractor.py

Lines changed: 63 additions & 14 deletions
@@ -150,14 +150,28 @@ def set_skip_loading(self, skip: bool = True) -> None:
 
     def extract(
         self,
-        source: Union[str, IO, list],
-        response_model: type[BaseModel],
+        source: Union[str, IO, List[Union[str, IO]]],
+        response_model: Type[BaseModel],
         vision: bool = False,
         content: Optional[str] = None,
         completion_strategy: Optional[CompletionStrategy] = CompletionStrategy.FORBIDDEN
     ) -> Any:
         """
-        Extract information from the provided source.
+        Extract information from one or more sources.
+
+        If source is a list, it loads each one, converts each to a universal format, and
+        merges them as if they were a single document. This merged content is then passed
+        to the extraction logic to produce a final result.
+
+        Args:
+            source: A single file path/stream or a list of them.
+            response_model: A Pydantic model class for validating the extracted data.
+            vision: Whether to use vision mode (affecting how content is processed).
+            content: Optional extra content to prepend to the merged content.
+            completion_strategy: Strategy for handling completions.
+
+        Returns:
+            The parsed result from the LLM as validated by response_model.
         """
         self._validate_dependencies(response_model, vision)
         self.extra_content = content
@@ -173,18 +187,53 @@ def extract(
             return self.extract_with_strategy(source, response_model, vision, completion_strategy)
 
         try:
-            if self._skip_loading:
-                # Skip loading if flag is set (content from splitting)
-                unified_content = self._map_to_universal_format(source, vision)
+            if isinstance(source, list):
+                all_contents = []
+                for src in source:
+                    loader = self.get_document_loader(src)
+                    if loader is None:
+                        raise ValueError(f"No suitable document loader found for source: {src}")
+                    # Load the content (e.g. text, images, metadata)
+                    loaded = loader.load(src)
+                    # Map to a universal format that your extraction logic understands.
+                    universal = self._map_to_universal_format(loaded, vision)
+                    all_contents.append(universal)
+
+                # Merge the text contents with a clear separator.
+                merged_text = "\n\n--- Document Separator ---\n\n".join(
+                    item.get("content", "") for item in all_contents
+                )
+                # Merge all image lists into one.
+                merged_images = []
+                for item in all_contents:
+                    merged_images.extend(item.get("images", []))
+
+                merged_content = {
+                    "content": merged_text,
+                    "images": merged_images,
+                    "metadata": {"num_documents": len(all_contents)}
+                }
+
+                # Optionally, prepend any extra content provided by the caller.
+                if content:
+                    merged_content["content"] = content + "\n\n" + merged_content["content"]
+
+                return self._extract(merged_content, response_model, vision)
             else:
-                # Normal loading path
-                loader = self.get_document_loader(source)
-                if not loader:
-                    raise ValueError("No suitable document loader found for the input.")
-                loaded_content = loader.load(source)
-                unified_content = self._map_to_universal_format(loaded_content, vision)
-
-            return self._extract(unified_content, response_model, vision)
+                # Single source; use existing behavior.
+                if self._skip_loading:
+                    # Skip loading if flag is set (content from splitting)
+                    unified_content = self._map_to_universal_format(source, vision)
+                else:
+                    # Normal loading path
+                    loader = self.get_document_loader(source)
+                    if not loader:
+                        raise ValueError("No suitable document loader found for the input.")
+                    loaded_content = loader.load(source)
+                    unified_content = self._map_to_universal_format(loaded_content, vision)
+
+                return self._extract(unified_content, response_model, vision)
+
         except IncompleteOutputException as e:
             raise ValueError("Incomplete output received and FORBIDDEN strategy is set") from e
         except Exception as e:
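
The new list branch reduces each loaded source to a universal dict with "content", "images", and "metadata" keys and then concatenates them. A standalone sketch of just that merge step, using two made-up dicts standing in for _map_to_universal_format output:

# Hypothetical stand-ins for the dicts _map_to_universal_format would return.
doc_a = {"content": "Invoice 00012, total 1125.00", "images": []}
doc_b = {"content": "BCOBS 2A.1 Restriction on marketing ...", "images": []}
all_contents = [doc_a, doc_b]

# Same merge logic as the new isinstance(source, list) branch above.
merged_text = "\n\n--- Document Separator ---\n\n".join(
    item.get("content", "") for item in all_contents
)
merged_images = []
for item in all_contents:
    merged_images.extend(item.get("images", []))

merged_content = {
    "content": merged_text,
    "images": merged_images,
    "metadata": {"num_documents": len(all_contents)},
}
print(merged_content["metadata"])  # -> {'num_documents': 2}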

tests/test_extractor.py

Lines changed: 38 additions & 4 deletions
@@ -1,13 +1,11 @@
 import asyncio
 import os
 from dotenv import load_dotenv
-from extract_thinker.document_loader.document_loader_pdfplumber import DocumentLoaderPdfPlumber
 from extract_thinker.extractor import Extractor
 from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract
 from extract_thinker.document_loader.document_loader_pypdf import DocumentLoaderPyPdf
 from extract_thinker.llm import LLM, LLMEngine
 from extract_thinker.models.completion_strategy import CompletionStrategy
-from extract_thinker.models.contract import Contract
 from tests.models.invoice import InvoiceContract
 from tests.models.ChartWithContent import ChartWithContent
 from tests.models.page_contract import ReportContract
@@ -16,9 +14,10 @@
 import pytest
 import numpy as np
 from litellm import embedding
-from extract_thinker.document_loader.document_loader_docling import DoclingConfig, DocumentLoaderDocling
+from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling
 from tests.models.handbook_contract import HandbookContract
 from extract_thinker.global_models import get_lite_model, get_big_model
+from pydantic import BaseModel, Field
 
 
 load_dotenv()
@@ -446,4 +445,39 @@ def test_extract_from_url_docling_and_gpt4o_mini():
 
     # Assert: Verify that the extracted title matches the expected value.
     expected_title = "BCOBS 2A.1 Restriction on marketing or providing an optional product for which a fee is payable"
-    assert result.title == expected_title
+    assert result.title == expected_title
+
+def test_extract_from_multiple_sources():
+    """
+    Test extracting from multiple sources (PDF and URL) in a single call.
+    Combines invoice data with handbook data using DocumentLoaderDocling.
+    """
+    # Arrange
+    pdf_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
+    url = "https://www.handbook.fca.org.uk/handbook/BCOBS/2A/?view=chapter"
+
+    extractor = Extractor()
+    docling_loader = DocumentLoaderDocling()
+    extractor.load_document_loader(docling_loader)
+    extractor.load_llm(get_big_model())
+
+    class CombinedData(BaseModel):
+        invoice_number: str
+        invoice_date: str
+        total_amount: float
+        handbook_title: str = Field(alias="title of the url, and not the invoice")
+
+    # Act
+    result: CombinedData = extractor.extract(
+        [pdf_path, url],
+        CombinedData,
+    )
+
+    # Assert
+    # Check invoice data
+    assert result.invoice_number == "00012"
+    assert result.invoice_date == "1/30/23"
+    assert result.total_amount == 1125
+
+    # Check handbook data
+    assert "FCA Handbook" in result.handbook_title, f"Expected title to contain 'FCA Handbook', but got: {result.handbook_title}"
