Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,10 @@ def extract_from_stream(
def classify_from_image(
self, image: Any, classifications: List[Classification]
):
# requires no content extraction from loader
# Encode the image using the utility function
encoded_image = encode_image(image)
content = {
"image": image,
"image": encoded_image,
}
return self._classify(content, classifications, image)

Expand Down Expand Up @@ -315,23 +316,24 @@ def _classify(
},
]

# Common classification structure for both image and non-image cases
classification_info = "\n".join(
f"{c.name}: {c.description} \n{self._add_classification_structure(c)}"
for c in classifications
)

if self.is_classify_image:
input_data = (
f"##Take the first image, and compare to the several images provided. Then classify according to the classification attached to the image\n"
f"##Take the last image, and compare to the several images provided. Then classify according to the classification attached to the image\n"
f"##Classifications\n{classification_info}\n"
+ "Output Example: \n"
+ "{\r\n\t\"name\": \"DMV Form\",\r\n\t\"confidence\": 8\r\n}"
+ "\n\n##ClassificationResponse JSON Output\n"
)

else:
input_data = (
f"##Content\n{content}\n##Classifications\n#if contract present, each field present increase confidence level\n"
+ "\n".join(
[
f"{c.name}: {c.description} \n{self._add_classification_structure(c)}"
for c in classifications
]
)
f"{classification_info}\n"
+ "#Don't use contract structure, just to help on the ClassificationResponse\nOutput Example: \n"
+ "{\r\n\t\"name\": \"DMV Form\",\r\n\t\"confidence\": 8\r\n}"
+ "\n\n##ClassificationResponse JSON Output\n"
Expand Down Expand Up @@ -378,6 +380,24 @@ def _classify(
f"Image required for classification '{classification.name}' but not found."
)

# Add the first image to be classified with context
if 'image' in content:
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "##classify",
},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64," + content['image']
},
},
],
})

response = self.llm.request(messages, ClassificationResponse)
else:
messages.append({"role": "user", "content": input_data})
Expand All @@ -391,12 +411,12 @@ def classify(
classifications: List[Classification],
image: bool = False,
):
document_loader = self.get_document_loader_for_file(input)
self.is_classify_image = image

if image:
document_loader.set_vision_mode(True)
return self.classify_from_image(input, classifications)

document_loader = self.get_document_loader_for_file(input)
if document_loader is None:
raise ValueError("No suitable document loader found for the input.")

Expand Down
2 changes: 1 addition & 1 deletion extract_thinker/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def _classify_tree_async(

if classification.confidence < threshold:
raise ValueError(
f"Classification confidence {classification.confidence}"
f"Classification confidence {classification.confidence} "
f"for '{classification.name}' is below the threshold of {threshold}."
)

Expand Down
14 changes: 7 additions & 7 deletions tests/test_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,17 @@ def test_with_image():
"""Test classification using both consensus and higher order strategies with a threshold."""
process = setup_process_with_gpt4_extractor()

COMMON_CLASSIFICATIONS[0].contract = InvoiceContract
COMMON_CLASSIFICATIONS[1].contract = DriverLicense
COMMON_CLASSIFICATIONS[0].contract = DriverLicense
COMMON_CLASSIFICATIONS[1].contract = InvoiceContract

COMMON_CLASSIFICATIONS[0].image = INVOICE_FILE_PATH
COMMON_CLASSIFICATIONS[1].image = DRIVER_LICENSE_FILE_PATH
COMMON_CLASSIFICATIONS[0].image = DRIVER_LICENSE_FILE_PATH
COMMON_CLASSIFICATIONS[1].image = INVOICE_FILE_PATH

result = process.classify(INVOICE_FILE_PATH, COMMON_CLASSIFICATIONS, strategy=ClassificationStrategy.CONSENSUS, image=True)
result = process.classify(DRIVER_LICENSE_FILE_PATH, COMMON_CLASSIFICATIONS, strategy=ClassificationStrategy.CONSENSUS, image=True)

assert result is not None
assert isinstance(result, ClassificationResponse)
assert result.name == "Invoice"
assert result.name == COMMON_CLASSIFICATIONS[0].name


def test_with_tree():
Expand Down Expand Up @@ -295,4 +295,4 @@ def test_mom_classification_layers():
# Assertions
assert final_result is not None, "MoM should produce a result"
assert final_result.name == "Credit Note", "Final classification should be Credit Note"
assert final_result.confidence >= 8, "Final confidence should be high"
assert final_result.confidence >= 8, "Final confidence should be high"
Loading