Skip to content

Commit cecc2a6

Browse files
claudevdm and Claude authored
Add image embedding support to ml/rag (#37628)
* image embeddings. * comments. * lint. * Add pillow to default requires. * update images * lint * mypy --------- Co-authored-by: Claude <cvandermerwe@google.com>
1 parent f73bd6a commit cecc2a6

26 files changed

+1138
-545
lines changed

sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any
2323
from typing import Optional
2424
from typing import Union
25+
from typing import cast
2526

2627
from google import genai
2728
from google.genai import errors
@@ -73,7 +74,7 @@ def generate_from_string(
7374
call.
7475
"""
7576
return model.models.generate_content(
76-
model=model_name, contents=batch, **inference_args)
77+
model=model_name, contents=cast(Any, batch), **inference_args)
7778

7879

7980
def generate_image_from_strings_and_images(
@@ -96,7 +97,7 @@ def generate_image_from_strings_and_images(
9697
call.
9798
"""
9899
return model.models.generate_content(
99-
model=model_name, contents=batch, **inference_args)
100+
model=model_name, contents=cast(Any, batch), **inference_args)
100101

101102

102103
class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,

sdks/python/apache_beam/ml/rag/embeddings/base_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from apache_beam.ml.rag.embeddings.base import create_text_adapter
2020
from apache_beam.ml.rag.types import Chunk
2121
from apache_beam.ml.rag.types import Content
22+
from apache_beam.ml.rag.types import EmbeddableItem
2223
from apache_beam.ml.rag.types import Embedding
2324

2425

@@ -89,5 +90,41 @@ def test_adapter_output_conversion(self):
8990
self.assertListEqual(embeddings, expected)
9091

9192

93+
class ImageEmbeddableItemTest(unittest.TestCase):
  """Tests for EmbeddableItem.from_image construction."""
  def test_from_image_str(self):
    # A string argument is stored as-is (URI/path), with no text content.
    result = EmbeddableItem.from_image('gs://bucket/img.jpg', id='img1')
    self.assertEqual(result.id, 'img1')
    self.assertIsNone(result.content.text)
    self.assertEqual(result.content.image, 'gs://bucket/img.jpg')

  def test_from_image_bytes(self):
    # Raw bytes are stored unmodified, with no text content.
    png_bytes = b'\x89PNG\r\n'
    result = EmbeddableItem.from_image(png_bytes, id='img2')
    self.assertIsNone(result.content.text)
    self.assertEqual(result.content.image, png_bytes)

  def test_from_image_with_metadata(self):
    # Metadata passed at construction is carried through unchanged.
    result = EmbeddableItem.from_image(
        'path/to/img.jpg', id='img3', metadata={'source': 'camera'})
    self.assertEqual(result.content.image, 'path/to/img.jpg')
    self.assertEqual(result.metadata, {'source': 'camera'})
111+
112+
113+
class ContentStringTest(unittest.TestCase):
  """Tests for the EmbeddableItem.content_string property."""
  def test_text_content(self):
    # Text content is returned directly.
    text_item = EmbeddableItem(content=Content(text="hello"), id="1")
    self.assertEqual(text_item.content_string, "hello")

  def test_image_uri_content(self):
    # A string image reference (URI/path) is a valid content string.
    uri_item = EmbeddableItem.from_image('gs://bucket/img.jpg', id='img1')
    self.assertEqual(uri_item.content_string, 'gs://bucket/img.jpg')

  def test_image_bytes_raises(self):
    # Raw image bytes have no string representation, so access raises.
    bytes_item = EmbeddableItem.from_image(b'\x89PNG\r\n', id='img2')
    with self.assertRaisesRegex(ValueError,
                                "EmbeddableItem does not contain.*"):
      bytes_item.content_string
127+
128+
92129
if __name__ == '__main__':
93130
unittest.main()

sdks/python/apache_beam/ml/rag/embeddings/huggingface.py

Lines changed: 131 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616

1717
"""RAG-specific embedding implementations using HuggingFace models."""
1818

19+
import io
20+
from collections.abc import Sequence
1921
from typing import Optional
2022

2123
import apache_beam as beam
24+
from apache_beam.io.filesystems import FileSystems
2225
from apache_beam.ml.inference.base import RunInference
26+
from apache_beam.ml.rag.embeddings.base import _add_embedding_fn
2327
from apache_beam.ml.rag.embeddings.base import create_text_adapter
2428
from apache_beam.ml.rag.types import EmbeddableItem
2529
from apache_beam.ml.transforms.base import EmbeddingsManager
30+
from apache_beam.ml.transforms.base import EmbeddingTypeAdapter
31+
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
2632
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
2733
from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler
2834

@@ -31,19 +37,35 @@
3137
except ImportError:
3238
SentenceTransformer = None
3339

40+
try:
41+
from PIL import Image as PILImage
42+
except ImportError:
43+
PILImage = None # type: ignore[assignment]
44+
3445

3546
class HuggingfaceTextEmbeddings(EmbeddingsManager):
3647
def __init__(
3748
self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
38-
"""Utilizes huggingface SentenceTransformer embeddings for RAG pipeline.
39-
40-
Args:
41-
model_name: Name of the sentence-transformers model to use
42-
max_seq_length: Maximum sequence length for the model
43-
**kwargs: Additional arguments passed to
44-
:class:`~apache_beam.ml.transforms.base.EmbeddingsManager`
45-
constructor including ModelHandler arguments
46-
"""
49+
"""HuggingFace text embeddings for RAG pipelines.
50+
51+
Args:
52+
model_name: Name of the sentence-transformers model to use.
53+
max_seq_length: Maximum sequence length for the model.
54+
**kwargs: Additional arguments passed to
55+
:class:`~apache_beam.ml.transforms.base.EmbeddingsManager`,
56+
including:
57+
58+
- ``load_model_args``: dict passed to
59+
``SentenceTransformer()`` constructor
60+
(e.g. ``device``, ``cache_folder``).
61+
- ``min_batch_size`` / ``max_batch_size``:
62+
Control batching for inference.
63+
- ``large_model``: If True, share the model
64+
across processes to reduce memory usage.
65+
- ``inference_args``: dict passed to
66+
``model.encode()``
67+
(e.g. ``normalize_embeddings``).
68+
"""
4769
if not SentenceTransformer:
4870
raise ImportError(
4971
"sentence-transformers is required to use "
@@ -73,3 +95,103 @@ def get_ptransform_for_processing(
7395
return RunInference(
7496
model_handler=_TextEmbeddingHandler(self),
7597
inference_args=self.inference_args).with_output_types(EmbeddableItem)
98+
99+
100+
def _extract_images(items: Sequence[EmbeddableItem]) -> list:
  """Convert each item's image content into a PIL.Image.

  Accepts raw bytes as well as local paths and remote URIs
  (e.g. gs://, s3://) readable through Beam's FileSystems.
  """
  def _to_pil(source):
    # Bytes are decoded in-memory; anything else is treated as a path/URI.
    if isinstance(source, bytes):
      return PILImage.open(io.BytesIO(source))
    with FileSystems.open(source, 'rb') as handle:
      pil_img = PILImage.open(handle)
      # Force a full read while the file handle is still open; PIL is
      # otherwise lazy and would fail once the handle is closed.
      pil_img.load()
    return pil_img

  converted = []
  for item in items:
    source = item.content.image
    if not source:
      raise ValueError(
          "Expected image content in "
          f"{type(item).__name__} {item.id}, "
          "got None")
    # Normalize mode so downstream models see consistent 3-channel input.
    converted.append(_to_pil(source).convert('RGB'))
  return converted
122+
123+
124+
def _create_hf_image_adapter(
125+
) -> EmbeddingTypeAdapter[EmbeddableItem, EmbeddableItem]:
126+
"""Creates adapter for HuggingFace image embedding.
127+
128+
Extracts content.image from EmbeddableItems and converts
129+
to PIL.Image objects. Supports both raw bytes and file paths.
130+
131+
Returns:
132+
EmbeddingTypeAdapter for HuggingFace image embedding.
133+
"""
134+
return EmbeddingTypeAdapter(
135+
input_fn=_extract_images, output_fn=_add_embedding_fn)
136+
137+
138+
class HuggingfaceImageEmbeddings(EmbeddingsManager):
139+
def __init__(
140+
self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
141+
"""HuggingFace image embeddings for RAG pipelines.
142+
143+
Generates embeddings for images using sentence-transformers
144+
models that support image input (e.g. clip-ViT-B-32).
145+
146+
Args:
147+
model_name: Name of the sentence-transformers model.
148+
Must be an image-text model. See
149+
https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models
150+
max_seq_length: Maximum sequence length for the model
151+
if applicable.
152+
**kwargs: Additional arguments passed to
153+
:class:`~apache_beam.ml.transforms.base.EmbeddingsManager`,
154+
including:
155+
156+
- ``load_model_args``: dict passed to
157+
``SentenceTransformer()`` constructor
158+
(e.g. ``device``, ``cache_folder``,
159+
``trust_remote_code``).
160+
- ``min_batch_size`` / ``max_batch_size``:
161+
Control batching for inference.
162+
- ``large_model``: If True, share the model
163+
across processes to reduce memory usage.
164+
- ``inference_args``: dict passed to
165+
``model.encode()``
166+
(e.g. ``normalize_embeddings``).
167+
"""
168+
if not SentenceTransformer:
169+
raise ImportError(
170+
"sentence-transformers is required to use "
171+
"HuggingfaceImageEmbeddings. "
172+
"Please install it with `pip install sentence-transformers`.")
173+
if not PILImage:
174+
raise ImportError(
175+
"Pillow is required to use HuggingfaceImageEmbeddings. "
176+
"Please install it with `pip install pillow`.")
177+
super().__init__(type_adapter=_create_hf_image_adapter(), **kwargs)
178+
self.model_name = model_name
179+
self.max_seq_length = max_seq_length
180+
self.model_class = SentenceTransformer
181+
182+
def get_model_handler(self):
183+
"""Returns model handler configured with RAG adapter."""
184+
return _SentenceTransformerModelHandler(
185+
model_class=self.model_class,
186+
max_seq_length=self.max_seq_length,
187+
model_name=self.model_name,
188+
load_model_args=self.load_model_args,
189+
min_batch_size=self.min_batch_size,
190+
max_batch_size=self.max_batch_size,
191+
large_model=self.large_model)
192+
193+
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
194+
"""Returns PTransform for image embedding."""
195+
return RunInference(
196+
model_handler=_ImageEmbeddingHandler(self),
197+
inference_args=self.inference_args).with_output_types(EmbeddableItem)

0 commit comments

Comments
 (0)