1313from tests .models .invoice import CreditNoteContract , FinancialContract , InvoiceContract
1414from tests .models .driver_license import DriverLicense , IdentificationContract
1515from extract_thinker .global_models import get_lite_model , get_big_model
16+ import pytest
17+ from pydantic import BaseModel
1618
1719# Setup environment and common paths
1820load_dotenv ()
2729 Classification (name = "Invoice" , description = "This is an invoice" ),
2830]
2931
32+ # Dummy contracts for the large tree example
class BankStatementContract(BaseModel):
    # Field-less placeholder contract attached to the "Bank Statement" node
    # of the large classification tree test.
    ...
class ContractAgreementContract(BaseModel):
    # Field-less placeholder contract attached to the "Contract/Agreement"
    # node of the large classification tree test.
    ...
class LegalNoticeContract(BaseModel):
    # Field-less placeholder contract attached to the "Legal Notice" node
    # of the large classification tree test.
    ...
class PassportContract(BaseModel):
    # Field-less placeholder contract attached to the "Passport" node of
    # the large classification tree test.
    ...
class SalesInvoiceContract(BaseModel):
    # Field-less placeholder contract attached to the "Sales Invoice" node;
    # the large classification tree test asserts this contract is returned.
    ...
class PurchaseInvoiceContract(BaseModel):
    # Field-less placeholder contract attached to the "Purchase Invoice"
    # node of the large classification tree test.
    ...
39+
3040def setup_extractors ():
3141 """Sets up and returns a list of configured extractors."""
3242 tesseract_path = os .getenv ("TESSERACT_PATH" )
@@ -271,6 +281,215 @@ def test_with_tree():
271281 assert result .classification .name == "Invoice"
272282 assert result .classification .description == "This is an invoice"
273283 assert result .classification .contract == InvoiceContract
284+ # Verify UUID matching worked (assuming financial_docs.children[0] is the Invoice node)
285+ expected_invoice_node = next (node for node in financial_docs .children if node .name == "Invoice" )
286+ assert result .classification .uuid == expected_invoice_node .classification .uuid
287+
288+
def test_tree_classification_low_confidence():
    """Classifying against a tree with an unreachable threshold must raise ValueError."""

    def build_node(name, description, contract, **node_kwargs):
        # Every node in this tree uses the same name for the node and for its
        # classification, so one argument feeds both. Extra keyword arguments
        # (only ``children`` here) are forwarded to the node untouched, which
        # keeps leaf nodes on the constructor's own default.
        return ClassificationNode(
            name=name,
            classification=Classification(
                name=name,
                description=description,
                contract=contract,
            ),
            **node_kwargs,
        )

    process = setup_process_with_gpt4_extractor()

    # Intended to mirror the tree built in test_with_tree: one financial
    # branch (invoice, credit note) and one identity branch (driver license).
    financial_branch = build_node(
        "Financial Document",
        "This is a financial document",
        FinancialContract,
        children=[
            build_node("Invoice", "This is an invoice", InvoiceContract),
            build_node("Credit Note", "This is a credit note", CreditNoteContract),
        ],
    )
    identity_branch = build_node(
        "Identity Documents",
        "This is an identity document",
        IdentificationContract,
        children=[
            build_node("Driver License", "This is a driver license", DriverLicense),
        ],
    )
    tree = ClassificationTree(nodes=[financial_branch, identity_branch])

    # Sample invoice stored next to this test module.
    here = os.path.dirname(os.path.abspath(__file__))
    invoice_pdf = os.path.join(here, "files", "invoice.pdf")

    # 10.1 is meant to be a threshold no confidence score can reach
    # (presumably the scale tops out at 10 -- confirm against classify()).
    unreachable_threshold = 10.1

    # NOTE(review): any ValueError satisfies this check; the message is not
    # inspected, so the failure reason itself is not verified here.
    with pytest.raises(ValueError):
        process.classify(invoice_pdf, tree, threshold=unreachable_threshold)
351+
352+
def test_large_classification_tree():
    """Classify the sample invoice against a three-root, multi-level tree.

    The tree has financial, legal and identification roots; the financial
    branch nests down to "Sales Invoice" / "Purchase Invoice" leaves. The
    test expects the sample PDF to land on the "Sales Invoice" leaf.
    """

    def build_node(name, description, contract, **node_kwargs):
        # Node name and classification name are identical for every node in
        # this tree, so a single argument feeds both. Extra keyword arguments
        # (only ``children`` here) pass straight through to the node, leaving
        # leaves on the constructor's own default.
        return ClassificationNode(
            name=name,
            classification=Classification(
                name=name,
                description=description,
                contract=contract,
            ),
            **node_kwargs,
        )

    process = setup_process_with_gpt4_extractor()

    # ----- Financial branch -------------------------------------------------
    # Deepest level: concrete invoice kinds. The sales node is kept in a
    # variable because its UUID is checked after classification.
    sales_invoice_node = build_node(
        "Sales Invoice",
        "An invoice sent to a customer detailing products/services sold and amount due.",
        SalesInvoiceContract,
    )
    purchase_invoice_node = build_node(
        "Purchase Invoice",
        "An invoice received from a supplier detailing products/services bought.",
        PurchaseInvoiceContract,
    )

    # Middle level: the general invoice node is the parent of both invoice kinds.
    financial_root = build_node(
        "Financial Document",
        "Documents related to financial transactions or status.",
        FinancialContract,
        children=[
            build_node(
                "Invoice/Bill",
                "A general bill requesting payment for goods or services.",
                InvoiceContract,
                children=[sales_invoice_node, purchase_invoice_node],
            ),
            build_node(
                "Credit Note",
                "A document correcting a previous invoice or returning funds.",
                CreditNoteContract,
            ),
            build_node(
                "Bank Statement",
                "A summary of financial transactions occurring over a given period.",
                BankStatementContract,
            ),
        ],
    )

    # ----- Legal branch -----------------------------------------------------
    # This root deliberately carries no contract (contract=None).
    legal_root = build_node(
        "Legal Document",
        "Documents with legal significance or implications.",
        None,
        children=[
            build_node(
                "Contract/Agreement",
                "A legally binding agreement between parties.",
                ContractAgreementContract,
            ),
            build_node(
                "Legal Notice",
                "A formal notification required or permitted by law.",
                LegalNoticeContract,
            ),
        ],
    )

    # ----- Identification branch --------------------------------------------
    identity_root = build_node(
        "Identification Document",
        "Documents used to verify a person's identity.",
        IdentificationContract,
        children=[
            build_node(
                "Driver License",
                "Official document permitting an individual to operate a motor vehicle.",
                DriverLicense,
            ),
            build_node(
                "Passport",
                "Official government document certifying identity and citizenship for travel.",
                PassportContract,
            ),
        ],
    )

    tree = ClassificationTree(nodes=[financial_root, legal_root, identity_root])

    # ----- Classification ---------------------------------------------------
    # Sample invoice stored next to this test module.
    here = os.path.dirname(os.path.abspath(__file__))
    invoice_pdf = os.path.join(here, "files", "invoice.pdf")

    outcome = process.classify(invoice_pdf, tree, threshold=7)

    # ----- Expectations -----------------------------------------------------
    assert outcome is not None, "Classification should return a result."
    assert outcome.name == "Sales Invoice", "The document should be classified as a Sales Invoice."
    assert outcome.classification is not None, "Result should contain classification details."
    # The UUID comparison ties the result to the exact leaf node built above,
    # not merely to a classification with a matching name.
    assert outcome.classification.uuid == sales_invoice_node.classification.uuid, "Result UUID should match the Sales Invoice node UUID."
    assert outcome.classification.contract == SalesInvoiceContract, "Result contract should be SalesInvoiceContract."
274493
275494
276495def test_mom_classification_layers ():
@@ -337,4 +556,8 @@ def test_mom_classification_layers():
337556 assert final_result .classification is not None
338557 assert final_result .classification .name == "Credit Note"
339558 assert final_result .classification .description == "A document issued to reverse a previous transaction, showing returned items and credit amount, usually referencing an original invoice"
340- assert final_result .classification .contract == CreditNoteContract
559+ assert final_result .classification .contract == CreditNoteContract
560+
561+
if __name__ == "__main__":
    # Direct-execution hook: runs only the low-confidence tree test, calling
    # it as a plain function rather than going through the pytest runner.
    test_tree_classification_low_confidence()
0 commit comments