Skip to content

Commit 21e3871

Browse files
authored
Merge pull request #110 from enoch3712/107-mom-of-classification-group_classifications-value-it-is-replaced-every-time-the-layer-is-changed
classification layers fixed
2 parents cf082e2 + d0df385 commit 21e3871

File tree

2 files changed

+110
-26
lines changed

2 files changed

+110
-26
lines changed

extract_thinker/process.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from typing import IO, Any, Dict, List, Optional, Union
3+
from extract_thinker.models.classification_response import ClassificationResponse
34
from extract_thinker.models.classification_strategy import ClassificationStrategy
45
from extract_thinker.models.doc_groups2 import DocGroups2
56
from extract_thinker.models.splitting_strategy import SplittingStrategy
@@ -52,8 +53,10 @@ async def _classify_async(self, extractor: Extractor, file: str, classifications
5253
return await loop.run_in_executor(None, extractor.classify, file, classifications, image)
5354

5455
def classify(self, file: str, classifications, strategy: ClassificationStrategy = ClassificationStrategy.CONSENSUS, threshold: int = 9, image: bool = False) -> Optional[Classification]:
56+
if not isinstance(threshold, int) or threshold < 1 or threshold > 10:
57+
raise ValueError("Threshold must be an integer between 1 and 10")
58+
5559
result = asyncio.run(self.classify_async(file, classifications, strategy, threshold, image))
56-
5760
return result
5861

5962
async def classify_async(
@@ -64,28 +67,43 @@ async def classify_async(
6467
threshold: int = 9,
6568
image: str = False
6669
) -> Optional[Classification]:
70+
if not isinstance(threshold, int) or threshold < 1 or threshold > 10:
71+
raise ValueError("Threshold must be an integer between 1 and 10")
6772

6873
if isinstance(classifications, ClassificationTree):
6974
return await self._classify_tree_async(file, classifications, threshold, image)
7075

76+
# Try each layer of extractors until we get a valid result
7177
for extractor_group in self.extractor_groups:
72-
group_classifications = await asyncio.gather(*(self._classify_async(extractor, file, classifications, image) for extractor in extractor_group))
73-
74-
# Implement different strategies
75-
if strategy == ClassificationStrategy.CONSENSUS:
76-
# Check if all classifications in the group are the same
77-
if len(set(group_classifications)) == 1:
78-
return group_classifications[0]
79-
elif strategy == ClassificationStrategy.HIGHER_ORDER:
80-
# Pick the result with the highest confidence
81-
return max(group_classifications, key=lambda c: c.confidence)
82-
elif strategy == ClassificationStrategy.CONSENSUS_WITH_THRESHOLD:
83-
if len(set(group_classifications)) == 1:
84-
maxResult = max(group_classifications, key=lambda c: c.confidence)
85-
if maxResult.confidence >= threshold:
86-
return maxResult
87-
88-
raise ValueError("No consensus could be reached on the classification of the document. Please try again with a different strategy or threshold.")
78+
group_classifications = await asyncio.gather(*(
79+
self._classify_async(extractor, file, classifications, image)
80+
for extractor in extractor_group
81+
))
82+
83+
try:
84+
# Attempt to get result based on strategy
85+
if strategy == ClassificationStrategy.CONSENSUS:
86+
if len(set(c.name for c in group_classifications)) == 1:
87+
return group_classifications[0]
88+
89+
elif strategy == ClassificationStrategy.HIGHER_ORDER:
90+
return max(group_classifications, key=lambda c: c.confidence)
91+
92+
elif strategy == ClassificationStrategy.CONSENSUS_WITH_THRESHOLD:
93+
if len(set(c.name for c in group_classifications)) == 1:
94+
if all(c.confidence >= threshold for c in group_classifications):
95+
return group_classifications[0]
96+
97+
# If we get here, current layer didn't meet criteria - continue to next layer
98+
continue
99+
100+
except Exception as e:
101+
# If there's an error processing this layer, try the next one
102+
print(f"Layer failed with error: {str(e)}")
103+
continue
104+
105+
# If we've tried all layers and none worked
106+
raise ValueError("No consensus could be reached on the classification of the document across any layer. Please try again with a different strategy or threshold.")
89107

90108
async def _classify_tree_async(
91109
self,
@@ -94,6 +112,9 @@ async def _classify_tree_async(
94112
threshold: float,
95113
image: bool
96114
) -> Optional[Classification]:
115+
if not isinstance(threshold, (int, float)) or threshold < 1 or threshold > 10:
116+
raise ValueError("Threshold must be a number between 1 and 10")
117+
97118
"""
98119
Perform classification in a hierarchical, level-by-level approach.
99120
"""
@@ -114,23 +135,23 @@ async def _classify_tree_async(
114135

115136
if classification.confidence < threshold:
116137
raise ValueError(
117-
f"Classification confidence {classification.confidence} "
118-
f"for '{classification.classification}' is below the threshold of {threshold}."
138+
f"Classification confidence {classification.confidence}"
139+
f"for '{classification.name}' is below the threshold of {threshold}."
119140
)
120141

121-
best_classification = classification
142+
best_classification: ClassificationResponse = classification
122143

123144
matching_node = next(
124145
(
125-
node for node in current_nodes
146+
node for node in current_nodes
126147
if node.classification.name == best_classification.name
127148
),
128149
None
129150
)
130151

131152
if matching_node is None:
132153
raise ValueError(
133-
f"No matching node found for classification '{classification.classification}'."
154+
f"No matching node found for classification '{classification.name}'."
134155
)
135156

136157
if matching_node.children:

tests/test_classify.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import asyncio
33
from dotenv import load_dotenv
44
from extract_thinker.document_loader.document_loader_aws_textract import DocumentLoaderAWSTextract
5+
from extract_thinker.document_loader.document_loader_txt import DocumentLoaderTxt
56
from extract_thinker.extractor import Extractor
67
from extract_thinker.models.classification_node import ClassificationNode
78
from extract_thinker.models.classification_tree import ClassificationTree
@@ -116,7 +117,7 @@ def test_classify_consensus():
116117
assert result is not None
117118
assert isinstance(result, ClassificationResponse)
118119
assert result.name == "Invoice"
119-
120+
120121

121122
def test_classify_higher_order():
122123
"""Test classification using higher order strategy."""
@@ -229,7 +230,69 @@ def test_with_tree():
229230
current_dir = os.path.dirname(os.path.abspath(__file__))
230231
pdf_path = os.path.join(current_dir, 'files','invoice.pdf')
231232

232-
result = process.classify(pdf_path, classification_tree, threshold=0.8)
233+
result = process.classify(pdf_path, classification_tree, threshold=7)
233234

234235
assert result is not None
235-
assert result.name == "Invoice"
236+
assert result.name == "Invoice"
237+
238+
def test_mom_classification_layers():
239+
"""Test Mixture of Models (MoM) classification with multiple layers."""
240+
# Arrange
241+
document_loader = DocumentLoaderTxt()
242+
243+
# Get test file path
244+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
245+
CREDIT_NOTE_PATH = os.path.join(CURRENT_DIR, "files", "ambiguous_credit_note.txt")
246+
247+
# Create ambiguous classifications
248+
test_classifications = [
249+
Classification(
250+
name="Receipt",
251+
description="A document showing payment received for goods or services, typically including items purchased, amounts, and payment method",
252+
contract=InvoiceContract
253+
),
254+
Classification(
255+
name="Credit Note",
256+
description="A document issued to reverse a previous transaction, showing returned items and credit amount, usually referencing an original invoice",
257+
contract=CreditNoteContract
258+
)
259+
]
260+
261+
# Initialize extractors with different models
262+
# Layer 1: Small models that might disagree
263+
gpt35_extractor = Extractor(document_loader)
264+
gpt35_extractor.load_llm("gpt-3.5-turbo")
265+
266+
claude_haiku_extractor = Extractor(document_loader)
267+
claude_haiku_extractor.load_llm("claude-3-haiku-20240307")
268+
269+
# Layer 2: More capable models for resolution
270+
gpt4_extractor = Extractor(document_loader)
271+
gpt4_extractor.load_llm("gpt-4o")
272+
sonnet_extractor = Extractor(document_loader)
273+
sonnet_extractor.load_llm("claude-3-5-sonnet-20241022")
274+
275+
# Create process with multiple layers
276+
process = Process()
277+
process.add_classify_extractor([
278+
[gpt35_extractor, claude_haiku_extractor], # Layer 1: Small models
279+
[gpt4_extractor, sonnet_extractor] # Layer 2: Resolution model
280+
])
281+
282+
# Test full MoM process (should resolve using Layer 2)
283+
final_result = process.classify(
284+
CREDIT_NOTE_PATH,
285+
test_classifications,
286+
strategy=ClassificationStrategy.CONSENSUS_WITH_THRESHOLD,
287+
threshold=8
288+
)
289+
290+
# Print results for debugging
291+
print("\nMoM Classification Results:")
292+
print(f"Final Classification: {final_result.name}")
293+
print(f"Confidence: {final_result.confidence}")
294+
295+
# Assertions
296+
assert final_result is not None, "MoM should produce a result"
297+
assert final_result.name == "Credit Note", "Final classification should be Credit Note"
298+
assert final_result.confidence >= 8, "Final confidence should be high"

0 commit comments

Comments
 (0)