Skip to content

Commit ecf0267

Browse files
fix: add language to OCRAgentGoogleVision constructor (#3696)
This PR addresses issue #3659 by adding an optional `language` parameter to the `OCRAgentGoogleVision` class constructor. This parameter serves as a "language hint" for the `document_text_detection` method in the `ImageAnnotatorClient`. For more information on language hints, refer to the [Google Cloud Vision documentation](https://cloud.google.com/vision/docs/languages).

**Default Behavior**: The `language` parameter defaults to `None`, in which case no language hint is sent and Google Cloud Vision auto-detects the language, as recommended in their documentation.

**Purpose**: This change is necessary because the `OCRAgent`'s `get_instance` method expects every `OCRAgent` subclass to accept a `language` parameter in its constructor.

**Context on Issue**: When trying to parse a PDF with `OCR_AGENT=unstructured.partition.utils.ocr_models.google_vision_ocr.OCRAgentGoogleVision`, an error occurs in the `get_instance` method. The method expects a `language` parameter, which the current `OCRAgentGoogleVision` constructor does not support, leading to a positional argument error.

---------

Co-authored-by: Christine Straub <[email protected]>
1 parent 6ba376a commit ecf0267

File tree

4 files changed

+32
-19
lines changed

4 files changed

+32
-19
lines changed

Diff for: CHANGELOG.md

+10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
## 0.15.15-dev0
2+
3+
### Enhancements
4+
5+
### Features
6+
7+
### Fixes
8+
9+
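* **Add language parameter to `OCRAgentGoogleVision`.** Introduces an optional `language` parameter in the `OCRAgentGoogleVision` constructor to serve as a language hint for `document_text_detection`. This ensures compatibility with the `OCRAgent`'s `get_instance` method and resolves errors when parsing PDFs with Google Cloud Vision as the OCR agent. The `ocr_languages` argument, which the Google Vision agent previously accepted but ignored, is removed from `get_text_from_image`, `get_layout_from_image`, and `get_layout_elements_from_image`.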
10+
111
## 0.15.14
212

313
### Enhancements

Diff for: test_unstructured/partition/pdf_image/test_ocr.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import namedtuple
2+
from typing import Optional
23
from unittest.mock import patch
34

45
import numpy as np
@@ -226,12 +227,13 @@ def google_vision_client(google_vision_text_annotation):
226227
Response = namedtuple("Response", "full_text_annotation")
227228

228229
class FakeGoogleVisionClient:
229-
def document_text_detection(self, image):
230+
def document_text_detection(self, image, image_context):
230231
return Response(full_text_annotation=google_vision_text_annotation)
231232

232233
class OCRAgentFakeGoogleVision(OCRAgentGoogleVision):
233-
def __init__(self):
234+
def __init__(self, language: Optional[str] = None):
234235
self.client = FakeGoogleVisionClient()
236+
self.language = language
235237

236238
return OCRAgentFakeGoogleVision()
237239

@@ -249,7 +251,7 @@ def test_get_layout_from_image_google_vision(google_vision_client):
249251
image = Image.new("RGB", (100, 100))
250252

251253
ocr_agent = google_vision_client
252-
regions = ocr_agent.get_layout_from_image(image, ocr_languages="eng")
254+
regions = ocr_agent.get_layout_from_image(image)
253255
assert len(regions) == 1
254256
assert regions[0].text == "Hello World!"
255257
assert regions[0].source == Source.OCR_GOOGLEVISION
@@ -263,7 +265,7 @@ def test_get_layout_elements_from_image_google_vision(google_vision_client):
263265
image = Image.new("RGB", (100, 100))
264266

265267
ocr_agent = google_vision_client
266-
layout_elements = ocr_agent.get_layout_elements_from_image(image, ocr_languages="eng")
268+
layout_elements = ocr_agent.get_layout_elements_from_image(image)
267269
assert len(layout_elements) == 1
268270

269271

Diff for: unstructured/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.15.14" # pragma: no cover
1+
__version__ = "0.15.15-dev0" # pragma: no cover

Diff for: unstructured/partition/utils/ocr_models/google_vision_ocr.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
from io import BytesIO
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Optional
55

6-
from google.cloud.vision import Image, ImageAnnotatorClient, Paragraph, TextAnnotation
6+
from google.cloud.vision import Image, ImageAnnotatorClient, ImageContext, Paragraph, TextAnnotation
77

88
from unstructured.logger import logger, trace_logger
99
from unstructured.partition.utils.config import env_config
@@ -19,7 +19,8 @@
1919
class OCRAgentGoogleVision(OCRAgent):
2020
"""OCR service implementation for Google Vision API."""
2121

22-
def __init__(self) -> None:
22+
def __init__(self, language: Optional[str] = None) -> None:
23+
self.language = language
2324
client_options = {}
2425
api_endpoint = env_config.GOOGLEVISION_API_ENDPOINT
2526
if api_endpoint:
@@ -32,40 +33,40 @@ def __init__(self) -> None:
3233
def is_text_sorted(self) -> bool:
3334
return True
3435

35-
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
36+
def get_text_from_image(self, image: PILImage.Image) -> str:
37+
image_context = ImageContext(language_hints=[self.language]) if self.language else None
3638
with BytesIO() as buffer:
3739
image.save(buffer, format="PNG")
38-
response = self.client.document_text_detection(image=Image(content=buffer.getvalue()))
40+
response = self.client.document_text_detection(
41+
image=Image(content=buffer.getvalue()), image_context=image_context
42+
)
3943
document = response.full_text_annotation
4044
assert isinstance(document, TextAnnotation)
4145
return document.text
4246

43-
def get_layout_from_image(
44-
self, image: PILImage.Image, ocr_languages: str = "eng"
45-
) -> list[TextRegion]:
47+
def get_layout_from_image(self, image: PILImage.Image) -> list[TextRegion]:
4648
trace_logger.detail("Processing entire page OCR with Google Vision API...")
49+
image_context = ImageContext(language_hints=[self.language]) if self.language else None
4750
with BytesIO() as buffer:
4851
image.save(buffer, format="PNG")
49-
response = self.client.document_text_detection(image=Image(content=buffer.getvalue()))
52+
response = self.client.document_text_detection(
53+
image=Image(content=buffer.getvalue()), image_context=image_context
54+
)
5055
document = response.full_text_annotation
5156
assert isinstance(document, TextAnnotation)
5257
regions = self._parse_regions(document)
5358
return regions
5459

55-
def get_layout_elements_from_image(
56-
self, image: PILImage.Image, ocr_languages: str = "eng"
57-
) -> list[LayoutElement]:
60+
def get_layout_elements_from_image(self, image: PILImage.Image) -> list[LayoutElement]:
5861
from unstructured.partition.pdf_image.inference_utils import (
5962
build_layout_elements_from_ocr_regions,
6063
)
6164

6265
ocr_regions = self.get_layout_from_image(
6366
image,
64-
ocr_languages=ocr_languages,
6567
)
6668
ocr_text = self.get_text_from_image(
6769
image,
68-
ocr_languages=ocr_languages,
6970
)
7071
layout_elements = build_layout_elements_from_ocr_regions(
7172
ocr_regions=ocr_regions,

0 commit comments

Comments
 (0)