Skip to content

Commit f59efe7

Browse files
committed
updated easyocr loader and tests
1 parent e026048 commit f59efe7

File tree

2 files changed

+119
-39
lines changed

2 files changed

+119
-39
lines changed
Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict, List, Union
2+
from io import BytesIO
23
from PIL import Image
34
import numpy as np
45
from dataclasses import dataclass, field
@@ -15,19 +16,23 @@ class EasyOCRConfig:
1516
"""Configuration for EasyOCR loader.
1617
1718
Args:
18-
lang_list: List of languages to use for OCR
19-
gpu: Whether to use GPU acceleration
20-
download_enabled: Whether to download models automatically
21-
cache_ttl: Time-to-live for cache in seconds
19+
lang_list: List of languages to use for OCR. Defaults to ['en'].
20+
gpu: Whether to use GPU acceleration. Defaults to True.
21+
download_enabled: Whether to download models automatically. Defaults to True.
22+
cache_ttl: Time-to-live for cache in seconds. Defaults to 300.
2223
"""
2324
lang_list: List[str] = field(default_factory=lambda: ['en'])
2425
gpu: bool = True
2526
download_enabled: bool = True
2627
cache_ttl: int = 300
2728

2829
def __post_init__(self):
30+
"""Initialize EasyOCR reader with configuration settings and validation."""
2931
if not self.lang_list:
30-
self.lang_list = ['en']
32+
raise ValueError("lang_list must contain at least one language code.")
33+
if self.cache_ttl < 0:
34+
raise ValueError("cache_ttl must be non-negative.")
35+
3136
self.reader = easyocr.Reader(
3237
lang_list=self.lang_list,
3338
gpu=self.gpu,
@@ -39,29 +44,63 @@ class DocumentLoaderEasyOCR(CachedDocumentLoader):
3944
SUPPORTED_FORMATS = ["png", "jpg", "jpeg", "tiff", "tif", "webp"]
4045

4146
def __init__(self, config: EasyOCRConfig):
47+
"""Initialize the EasyOCR document loader.
48+
49+
Args:
50+
config: Configuration object for EasyOCR settings
51+
"""
4252
super().__init__()
4353
self.config = config
4454
self.cache = TTLCache(maxsize=128, ttl=self.config.cache_ttl)
55+
self.vision_mode = False
56+
57+
def can_handle(self, source: Union[str, BytesIO]) -> bool:
58+
"""Check if the loader can handle the given source.
59+
60+
Args:
61+
source: Path to a file or BytesIO stream
4562
46-
def can_handle(self, source: str) -> bool:
47-
if not isinstance(source, str) or '.' not in source:
48-
return False
49-
ext = source.split('.')[-1].lower()
50-
return ext in self.SUPPORTED_FORMATS
63+
Returns:
64+
bool: True if source is supported, False otherwise
65+
"""
66+
# Check if source is a BytesIO stream
67+
if isinstance(source, BytesIO):
68+
return True
69+
# Check if source is a file path and has a valid extension
70+
if isinstance(source, str) and '.' in source:
71+
# Extract the file extension (after the last '.') and convert to lowercase
72+
ext = source.split('.')[-1].lower()
73+
return ext in self.SUPPORTED_FORMATS
74+
return False
5175

52-
@cachedmethod(cache=attrgetter('cache'), key=lambda _, path: hashkey(path))
53-
def load(self, image_path: str) -> List[List[Dict[str, Any]]]:
54-
"""Load and process an image using EasyOCR.
76+
@cachedmethod(cache=attrgetter('cache'), key=lambda self, source: hashkey(source) if isinstance(source, str) else None)
77+
def load(self, source: Union[str, BytesIO]) -> List[List[Dict[str, Any]]]:
78+
"""Load and process an image (file path or BytesIO) using EasyOCR.
5579
5680
Args:
57-
image_path: Path to the image file
81+
source: Image file path or in-memory image stream (BytesIO)
5882
5983
Returns:
6084
List of pages, where each page contains a list of OCR results.
61-
Each OCR result is a dictionary with 'text', 'probability', and 'bbox' keys.
85+
Each OCR result is a dictionary with:
86+
- text: The extracted text
87+
- probability: Confidence score
88+
- bbox: Bounding box coordinates
6289
"""
63-
with Image.open(image_path).convert("RGB") as img:
64-
ocr_result = self.config.reader.readtext(np.array(img))
90+
# Convert image from file path into numpy array
91+
if isinstance(source, str):
92+
with Image.open(source).convert("RGB") as img:
93+
image_array = np.array(img)
94+
# Convert image from bytes stream into numpy array
95+
elif isinstance(source, BytesIO):
96+
source.seek(0)
97+
with Image.open(source).convert("RGB") as img:
98+
image_array = np.array(img)
99+
else:
100+
raise ValueError("Unsupported source type. Expected str or BytesIO.")
101+
102+
ocr_result = self.config.reader.readtext(image_array)
103+
# Loop through OCR results and structure them into a dictionary format
65104
page_data = []
66105
for bbox, text, prob in ocr_result:
67106
page_data.append({
@@ -70,3 +109,13 @@ def load(self, image_path: str) -> List[List[Dict[str, Any]]]:
70109
"probability": prob
71110
})
72111
return [page_data]
112+
113+
def can_handle_vision(self, source: Union[str, BytesIO]) -> bool:
114+
"""EasyOCR currently doesn't support vision mode in this loader."""
115+
return False
116+
117+
def set_vision_mode(self, enabled: bool = True):
118+
"""Disable vision mode, not supported here."""
119+
if enabled:
120+
raise ValueError("Vision mode is not supported in EasyOCR loader.")
121+
Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import pytest
3+
from io import BytesIO
34
import numpy as np
45
from extract_thinker.document_loader.document_loader_easy_ocr import DocumentLoaderEasyOCR, EasyOCRConfig
56
from .test_document_loader_base import BaseDocumentLoaderTest
@@ -25,45 +26,75 @@ def test_file_path(self):
2526
return os.path.join(current_dir, "test_images", "invoice.png")
2627

2728
def test_load_content(self, loader, test_file_path):
29+
"""Tests that the loader can process an image file and return OCR results
30+
in the expected structure"""
2831
content = loader.load(test_file_path)
2932
assert isinstance(content, list) and len(content) > 0
3033
for page in content:
34+
# Each page should be a list of OCR results
3135
assert isinstance(page, list)
3236
for item in page:
37+
# Each OCR result should be a dictionary
3338
assert isinstance(item, dict)
3439
assert all(key in item for key in ['text', 'probability', 'bbox'])
3540
assert isinstance(item['text'], str)
3641
assert isinstance(item['probability'], (float, np.float64))
3742
assert isinstance(item['bbox'], (list, tuple))
3843

39-
def test_can_handle_formats(self, loader, tmp_path):
40-
for fmt in loader.SUPPORTED_FORMATS:
41-
test_file = tmp_path / f"test.{fmt}"
42-
test_file.touch()
43-
assert loader.can_handle(str(test_file))
44+
def test_load_from_bytesio(self, loader, test_file_path):
45+
"""Tests that the loader can process an image provided as a BytesIO stream."""
46+
with open(test_file_path, "rb") as f:
47+
image_bytes = BytesIO(f.read())
48+
content = loader.load(image_bytes)
49+
assert isinstance(content, list) and len(content) > 0
50+
51+
def test_can_handle(self, loader, tmp_path):
52+
"""Tests that the loader correctly identifies supported and unsupported file formats"""
53+
# Supported extensions
54+
for ext in loader.SUPPORTED_FORMATS:
55+
f = tmp_path / f"file.{ext}"
56+
f.touch()
57+
assert loader.can_handle(str(f))
58+
# Unsupported extension
59+
assert not loader.can_handle(str(tmp_path / "file.abc"))
60+
# Missing extension
61+
assert not loader.can_handle(str(tmp_path / "file"))
62+
# BytesIO stream
63+
assert loader.can_handle(BytesIO(b"data"))
4464

45-
bad_file = tmp_path / "test.xyz"
46-
bad_file.touch()
47-
assert not loader.can_handle(str(bad_file))
65+
def test_vision_mode(self, loader):
66+
"""Test that vision mode is not supported"""
67+
# Vision mode should be disabled by default
68+
assert loader.vision_mode is False
69+
70+
# Attempting to enable vision mode should raise an error
71+
with pytest.raises(ValueError, match="Vision mode is not supported"):
72+
loader.set_vision_mode(True)
73+
74+
# Vision mode should still be False after failed attempt
75+
assert loader.vision_mode is False
76+
77+
# can_handle_vision should always return False
78+
assert loader.can_handle_vision("test.txt") is False
4879

4980
def test_language_configuration(self, test_file_path):
81+
"""test that the loader can handle english language"""
5082
loader = DocumentLoaderEasyOCR(EasyOCRConfig(lang_list=['en']))
5183
pages = loader.load(test_file_path)
5284
assert len(pages) > 0
53-
85+
"""test that the loader can handle multiple languages(english and spanish)"""
5486
loader = DocumentLoaderEasyOCR(EasyOCRConfig(lang_list=['en', 'es']))
5587
pages = loader.load(test_file_path)
5688
assert len(pages) > 0
5789

58-
def test_simple_initialization_easyocr(self):
59-
config = EasyOCRConfig(lang_list=["en"])
60-
loader = DocumentLoaderEasyOCR(config)
61-
current_dir = os.path.dirname(os.path.abspath(__file__))
62-
test_file = os.path.join(current_dir, "test_images", "invoice.png")
63-
pages = loader.load(test_file)
64-
assert isinstance(pages, list)
65-
assert len(pages) > 0
66-
assert isinstance(pages[0], list)
67-
assert isinstance(pages[0][0], dict)
68-
assert "text" in pages[0][0]
69-
assert isinstance(pages[0][0]["text"], str)
90+
def test_easyocr_config_validation(self):
91+
"""Test EasyOCRConfig validation"""
92+
# raise error if lang_list is empty
93+
with pytest.raises(ValueError, match="lang_list must contain at least one"):
94+
EasyOCRConfig(lang_list=[])
95+
# raise error if cache_ttl is negative
96+
with pytest.raises(ValueError, match="cache_ttl must be non-negative"):
97+
EasyOCRConfig(cache_ttl=-1)
98+
99+
100+

0 commit comments

Comments
 (0)