Skip to content

Commit 7c3e6f9

Browse files
authored
Merge pull request #318 from enoch3712/282-bug-classification-tree---use-match-index-for-classification-tree
[HOT FIX] Classification Tree with Id
2 parents 681e235 + c26cc7d commit 7c3e6f9

File tree

3 files changed

+235
-8
lines changed

3 files changed

+235
-8
lines changed

extract_thinker/models/classification.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Optional, Type
2-
from pydantic import BaseModel, field_validator
3-
from extract_thinker.models.contract import Contract
2+
from pydantic import BaseModel, field_validator, Field
3+
from uuid import UUID, uuid4
44
import os
55

66
class Classification(BaseModel):
@@ -10,23 +10,20 @@ class Classification(BaseModel):
1010
extraction_contract: Optional[Type] = None
1111
image: Optional[str] = None
1212
extractor: Optional[Any] = None
13+
uuid: UUID = Field(default_factory=uuid4)
1314

1415
@field_validator('contract', mode='before')
1516
def validate_contract(cls, v):
1617
if v is not None:
1718
if not isinstance(v, type):
1819
raise ValueError('contract must be a type')
19-
if not issubclass(v, Contract):
20-
raise ValueError('contract must be a subclass of Contract')
2120
return v
2221

2322
@field_validator('extraction_contract', mode='before')
2423
def validate_extraction_contract(cls, v):
2524
if v is not None:
2625
if not isinstance(v, type):
2726
raise ValueError('extraction_contract must be a type')
28-
if not issubclass(v, Contract):
29-
raise ValueError('extraction_contract must be a subclass of Contract')
3027
return v
3128

3229
def set_image(self, image_path: str):

extract_thinker/process.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ async def _classify_tree_async(
152152
image=image
153153
)
154154

155+
# Handle cases where classification fails at this level
156+
if classification_response is None:
157+
raise ValueError(
158+
"Classification failed at the current level. Could not determine a match."
159+
)
160+
155161
if classification_response.confidence < threshold:
156162
raise ValueError(
157163
f"Classification confidence {classification_response.confidence} "
@@ -160,10 +166,11 @@ async def _classify_tree_async(
160166

161167
best_classification = classification_response
162168

169+
# Use UUID for robust matching instead of name
163170
matching_node = next(
164171
(
165172
node for node in current_nodes
166-
if node.classification.name == best_classification.name
173+
if node.classification.uuid == best_classification.classification.uuid
167174
),
168175
None
169176
)

tests/test_classify.py

Lines changed: 224 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from tests.models.invoice import CreditNoteContract, FinancialContract, InvoiceContract
1414
from tests.models.driver_license import DriverLicense, IdentificationContract
1515
from extract_thinker.global_models import get_lite_model, get_big_model
16+
import pytest
17+
from pydantic import BaseModel
1618

1719
# Setup environment and common paths
1820
load_dotenv()
@@ -27,6 +29,14 @@
2729
Classification(name="Invoice", description="This is an invoice"),
2830
]
2931

32+
# Dummy contracts for the large tree example
33+
class BankStatementContract(BaseModel): pass
34+
class ContractAgreementContract(BaseModel): pass
35+
class LegalNoticeContract(BaseModel): pass
36+
class PassportContract(BaseModel): pass
37+
class SalesInvoiceContract(BaseModel): pass
38+
class PurchaseInvoiceContract(BaseModel): pass
39+
3040
def setup_extractors():
3141
"""Sets up and returns a list of configured extractors."""
3242
tesseract_path = os.getenv("TESSERACT_PATH")
@@ -271,6 +281,215 @@ def test_with_tree():
271281
assert result.classification.name == "Invoice"
272282
assert result.classification.description == "This is an invoice"
273283
assert result.classification.contract == InvoiceContract
284+
# Verify UUID matching worked (assuming financial_docs.children[0] is the Invoice node)
285+
expected_invoice_node = next(node for node in financial_docs.children if node.name == "Invoice")
286+
assert result.classification.uuid == expected_invoice_node.classification.uuid
287+
288+
289+
def test_tree_classification_low_confidence():
290+
"""Test tree classification raises error when confidence is below threshold."""
291+
process = setup_process_with_gpt4_extractor()
292+
293+
# Reuse the same tree structure from test_with_tree
294+
financial_docs = ClassificationNode(
295+
name="Financial Document",
296+
classification=Classification(
297+
name="Financial Document",
298+
description="This is a financial document",
299+
contract=FinancialContract,
300+
),
301+
children=[
302+
ClassificationNode(
303+
name="Invoice",
304+
classification=Classification(
305+
name="Invoice",
306+
description="This is an invoice",
307+
contract=InvoiceContract,
308+
)
309+
),
310+
ClassificationNode(
311+
name="Credit Note",
312+
classification=Classification(
313+
name="Credit Note",
314+
description="This is a credit note",
315+
contract=CreditNoteContract,
316+
)
317+
)
318+
]
319+
)
320+
legal_docs = ClassificationNode(
321+
name="Identity Documents",
322+
classification=Classification(
323+
name="Identity Documents",
324+
description="This is an identity document",
325+
contract=IdentificationContract,
326+
),
327+
children=[
328+
ClassificationNode(
329+
name="Driver License",
330+
classification=Classification(
331+
name="Driver License",
332+
description="This is a driver license",
333+
contract=DriverLicense,
334+
)
335+
)
336+
]
337+
)
338+
classification_tree = ClassificationTree(
339+
nodes=[financial_docs, legal_docs]
340+
)
341+
342+
current_dir = os.path.dirname(os.path.abspath(__file__))
343+
pdf_path = os.path.join(current_dir, 'files','invoice.pdf')
344+
345+
# Set an impossibly high threshold
346+
high_threshold = 10.1 # Force confidence < threshold
347+
348+
# Assert that the correct ValueError is raised due to low confidence
349+
with pytest.raises(ValueError):
350+
process.classify(pdf_path, classification_tree, threshold=high_threshold)
351+
352+
353+
def test_large_classification_tree():
354+
"""Test classification with a larger, multi-level tree."""
355+
process = setup_process_with_gpt4_extractor()
356+
357+
# --- Define Tree Structure ---
358+
359+
# Level 3: Invoice Types
360+
sales_invoice_node = ClassificationNode(
361+
name="Sales Invoice",
362+
classification=Classification(
363+
name="Sales Invoice",
364+
description="An invoice sent to a customer detailing products/services sold and amount due.",
365+
contract=SalesInvoiceContract, # Specific contract for Sales Invoice
366+
)
367+
)
368+
purchase_invoice_node = ClassificationNode(
369+
name="Purchase Invoice",
370+
classification=Classification(
371+
name="Purchase Invoice",
372+
description="An invoice received from a supplier detailing products/services bought.",
373+
contract=PurchaseInvoiceContract, # Specific contract for Purchase Invoice
374+
)
375+
)
376+
377+
# Level 2: Financial Subtypes (Invoice node becomes a parent)
378+
invoice_node = ClassificationNode(
379+
name="Invoice/Bill", # More general name
380+
classification=Classification(
381+
name="Invoice/Bill",
382+
description="A general bill requesting payment for goods or services.",
383+
contract=InvoiceContract, # General invoice contract remains here
384+
),
385+
children=[sales_invoice_node, purchase_invoice_node] # Add Level 3 children
386+
)
387+
credit_note_node = ClassificationNode(
388+
name="Credit Note",
389+
classification=Classification(
390+
name="Credit Note",
391+
description="A document correcting a previous invoice or returning funds.",
392+
contract=CreditNoteContract,
393+
)
394+
)
395+
bank_statement_node = ClassificationNode(
396+
name="Bank Statement",
397+
classification=Classification(
398+
name="Bank Statement",
399+
description="A summary of financial transactions occurring over a given period.",
400+
contract=BankStatementContract,
401+
)
402+
)
403+
404+
# Level 1: Financial Documents Root
405+
financial_docs = ClassificationNode(
406+
name="Financial Document",
407+
classification=Classification(
408+
name="Financial Document",
409+
description="Documents related to financial transactions or status.",
410+
contract=FinancialContract,
411+
),
412+
children=[invoice_node, credit_note_node, bank_statement_node]
413+
)
414+
415+
# Level 2: Legal Subtypes
416+
contract_agreement_node = ClassificationNode(
417+
name="Contract/Agreement",
418+
classification=Classification(
419+
name="Contract/Agreement",
420+
description="A legally binding agreement between parties.",
421+
contract=ContractAgreementContract,
422+
)
423+
)
424+
legal_notice_node = ClassificationNode(
425+
name="Legal Notice",
426+
classification=Classification(
427+
name="Legal Notice",
428+
description="A formal notification required or permitted by law.",
429+
contract=LegalNoticeContract,
430+
)
431+
)
432+
433+
# Level 1: Legal Documents Root
434+
legal_docs = ClassificationNode(
435+
name="Legal Document",
436+
classification=Classification(
437+
name="Legal Document",
438+
description="Documents with legal significance or implications.",
439+
contract=None, # Example: Root might not have a specific contract
440+
),
441+
children=[contract_agreement_node, legal_notice_node]
442+
)
443+
444+
# Level 2: Identification Subtypes
445+
driver_license_node = ClassificationNode(
446+
name="Driver License",
447+
classification=Classification(
448+
name="Driver License",
449+
description="Official document permitting an individual to operate a motor vehicle.",
450+
contract=DriverLicense,
451+
)
452+
)
453+
passport_node = ClassificationNode(
454+
name="Passport",
455+
classification=Classification(
456+
name="Passport",
457+
description="Official government document certifying identity and citizenship for travel.",
458+
contract=PassportContract,
459+
)
460+
)
461+
462+
# Level 1: Identity Documents Root
463+
identity_docs = ClassificationNode(
464+
name="Identification Document",
465+
classification=Classification(
466+
name="Identification Document",
467+
description="Documents used to verify a person's identity.",
468+
contract=IdentificationContract,
469+
),
470+
children=[driver_license_node, passport_node]
471+
)
472+
473+
# Top Level Tree
474+
classification_tree = ClassificationTree(
475+
nodes=[financial_docs, legal_docs, identity_docs]
476+
)
477+
478+
# --- Perform Classification ---
479+
current_dir = os.path.dirname(os.path.abspath(__file__))
480+
# Using the same invoice PDF as before
481+
pdf_path = os.path.join(current_dir, 'files', 'invoice.pdf')
482+
483+
# Classify with a reasonable threshold
484+
result = process.classify(pdf_path, classification_tree, threshold=7)
485+
486+
# --- Assert Results ---
487+
assert result is not None, "Classification should return a result."
488+
assert result.name == "Sales Invoice", "The document should be classified as a Sales Invoice."
489+
assert result.classification is not None, "Result should contain classification details."
490+
# Verify it picked the correct node using UUID
491+
assert result.classification.uuid == sales_invoice_node.classification.uuid, "Result UUID should match the Sales Invoice node UUID."
492+
assert result.classification.contract == SalesInvoiceContract, "Result contract should be SalesInvoiceContract."
274493

275494

276495
def test_mom_classification_layers():
@@ -337,4 +556,8 @@ def test_mom_classification_layers():
337556
assert final_result.classification is not None
338557
assert final_result.classification.name == "Credit Note"
339558
assert final_result.classification.description == "A document issued to reverse a previous transaction, showing returned items and credit amount, usually referencing an original invoice"
340-
assert final_result.classification.contract == CreditNoteContract
559+
assert final_result.classification.contract == CreditNoteContract
560+
561+
562+
if __name__ == "__main__":
563+
test_tree_classification_low_confidence()

0 commit comments

Comments
 (0)