11import asyncio
22from typing import IO , Any , Dict , List , Optional , Union
3+ from extract_thinker .models .classification_response import ClassificationResponse
34from extract_thinker .models .classification_strategy import ClassificationStrategy
45from extract_thinker .models .doc_groups2 import DocGroups2
56from extract_thinker .models .splitting_strategy import SplittingStrategy
@@ -52,8 +53,10 @@ async def _classify_async(self, extractor: Extractor, file: str, classifications
5253 return await loop .run_in_executor (None , extractor .classify , file , classifications , image )
5354
5455 def classify (self , file : str , classifications , strategy : ClassificationStrategy = ClassificationStrategy .CONSENSUS , threshold : int = 9 , image : bool = False ) -> Optional [Classification ]:
56+ if not isinstance (threshold , int ) or threshold < 1 or threshold > 10 :
57+ raise ValueError ("Threshold must be an integer between 1 and 10" )
58+
5559 result = asyncio .run (self .classify_async (file , classifications , strategy , threshold , image ))
56-
5760 return result
5861
5962 async def classify_async (
@@ -64,28 +67,43 @@ async def classify_async(
6467 threshold : int = 9 ,
6568 image : str = False
6669 ) -> Optional [Classification ]:
70+ if not isinstance (threshold , int ) or threshold < 1 or threshold > 10 :
71+ raise ValueError ("Threshold must be an integer between 1 and 10" )
6772
6873 if isinstance (classifications , ClassificationTree ):
6974 return await self ._classify_tree_async (file , classifications , threshold , image )
7075
76+ # Try each layer of extractors until we get a valid result
7177 for extractor_group in self .extractor_groups :
72- group_classifications = await asyncio .gather (* (self ._classify_async (extractor , file , classifications , image ) for extractor in extractor_group ))
73-
74- # Implement different strategies
75- if strategy == ClassificationStrategy .CONSENSUS :
76- # Check if all classifications in the group are the same
77- if len (set (group_classifications )) == 1 :
78- return group_classifications [0 ]
79- elif strategy == ClassificationStrategy .HIGHER_ORDER :
80- # Pick the result with the highest confidence
81- return max (group_classifications , key = lambda c : c .confidence )
82- elif strategy == ClassificationStrategy .CONSENSUS_WITH_THRESHOLD :
83- if len (set (group_classifications )) == 1 :
84- maxResult = max (group_classifications , key = lambda c : c .confidence )
85- if maxResult .confidence >= threshold :
86- return maxResult
87-
88- raise ValueError ("No consensus could be reached on the classification of the document. Please try again with a different strategy or threshold." )
78+ group_classifications = await asyncio .gather (* (
79+ self ._classify_async (extractor , file , classifications , image )
80+ for extractor in extractor_group
81+ ))
82+
83+ try :
84+ # Attempt to get result based on strategy
85+ if strategy == ClassificationStrategy .CONSENSUS :
86+ if len (set (c .name for c in group_classifications )) == 1 :
87+ return group_classifications [0 ]
88+
89+ elif strategy == ClassificationStrategy .HIGHER_ORDER :
90+ return max (group_classifications , key = lambda c : c .confidence )
91+
92+ elif strategy == ClassificationStrategy .CONSENSUS_WITH_THRESHOLD :
93+ if len (set (c .name for c in group_classifications )) == 1 :
94+ if all (c .confidence >= threshold for c in group_classifications ):
95+ return group_classifications [0 ]
96+
97+ # If we get here, current layer didn't meet criteria - continue to next layer
98+ continue
99+
100+ except Exception as e :
101+ # If there's an error processing this layer, try the next one
102+ print (f"Layer failed with error: { str (e )} " )
103+ continue
104+
105+ # If we've tried all layers and none worked
106+ raise ValueError ("No consensus could be reached on the classification of the document across any layer. Please try again with a different strategy or threshold." )
89107
90108 async def _classify_tree_async (
91109 self ,
@@ -94,6 +112,9 @@ async def _classify_tree_async(
94112 threshold : float ,
95113 image : bool
96114 ) -> Optional [Classification ]:
115+ if not isinstance (threshold , (int , float )) or threshold < 1 or threshold > 10 :
116+ raise ValueError ("Threshold must be a number between 1 and 10" )
117+
97118 """
98119 Perform classification in a hierarchical, level-by-level approach.
99120 """
@@ -114,23 +135,23 @@ async def _classify_tree_async(
114135
115136 if classification .confidence < threshold :
116137 raise ValueError (
117- f"Classification confidence { classification .confidence } "
118- f"for '{ classification .classification } ' is below the threshold of { threshold } ."
138+ f"Classification confidence { classification .confidence } "
139+ f"for '{ classification .name } ' is below the threshold of { threshold } ."
119140 )
120141
121- best_classification = classification
142+ best_classification : ClassificationResponse = classification
122143
123144 matching_node = next (
124145 (
125- node for node in current_nodes
146+ node for node in current_nodes
126147 if node .classification .name == best_classification .name
127148 ),
128149 None
129150 )
130151
131152 if matching_node is None :
132153 raise ValueError (
133- f"No matching node found for classification '{ classification .classification } '."
154+ f"No matching node found for classification '{ classification .name } '."
134155 )
135156
136157 if matching_node .children :
0 commit comments