|
16 | 16 |
|
17 | 17 | """RAG-specific embedding implementations using HuggingFace models.""" |
18 | 18 |
|
| 19 | +import io |
| 20 | +from collections.abc import Sequence |
19 | 21 | from typing import Optional |
20 | 22 |
|
21 | 23 | import apache_beam as beam |
| 24 | +from apache_beam.io.filesystems import FileSystems |
22 | 25 | from apache_beam.ml.inference.base import RunInference |
| 26 | +from apache_beam.ml.rag.embeddings.base import _add_embedding_fn |
23 | 27 | from apache_beam.ml.rag.embeddings.base import create_text_adapter |
24 | 28 | from apache_beam.ml.rag.types import EmbeddableItem |
25 | 29 | from apache_beam.ml.transforms.base import EmbeddingsManager |
| 30 | +from apache_beam.ml.transforms.base import EmbeddingTypeAdapter |
| 31 | +from apache_beam.ml.transforms.base import _ImageEmbeddingHandler |
26 | 32 | from apache_beam.ml.transforms.base import _TextEmbeddingHandler |
27 | 33 | from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler |
28 | 34 |
|
|
31 | 37 | except ImportError: |
32 | 38 | SentenceTransformer = None |
33 | 39 |
|
| 40 | +try: |
| 41 | + from PIL import Image as PILImage |
| 42 | +except ImportError: |
| 43 | + PILImage = None # type: ignore[assignment] |
| 44 | + |
34 | 45 |
|
35 | 46 | class HuggingfaceTextEmbeddings(EmbeddingsManager): |
36 | 47 | def __init__( |
37 | 48 | self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs): |
38 | | - """Utilizes huggingface SentenceTransformer embeddings for RAG pipeline. |
39 | | -
|
40 | | - Args: |
41 | | - model_name: Name of the sentence-transformers model to use |
42 | | - max_seq_length: Maximum sequence length for the model |
43 | | - **kwargs: Additional arguments passed to |
44 | | - :class:`~apache_beam.ml.transforms.base.EmbeddingsManager` |
45 | | - constructor including ModelHandler arguments |
46 | | - """ |
| 49 | + """HuggingFace text embeddings for RAG pipelines. |
| 50 | +
|
| 51 | + Args: |
| 52 | + model_name: Name of the sentence-transformers model to use. |
| 53 | + max_seq_length: Maximum sequence length for the model. |
| 54 | + **kwargs: Additional arguments passed to |
| 55 | + :class:`~apache_beam.ml.transforms.base.EmbeddingsManager`, |
| 56 | + including: |
| 57 | +
|
| 58 | + - ``load_model_args``: dict passed to |
| 59 | + ``SentenceTransformer()`` constructor |
| 60 | + (e.g. ``device``, ``cache_folder``). |
| 61 | + - ``min_batch_size`` / ``max_batch_size``: |
| 62 | + Control batching for inference. |
| 63 | + - ``large_model``: If True, share the model |
| 64 | + across processes to reduce memory usage. |
| 65 | + - ``inference_args``: dict passed to |
| 66 | + ``model.encode()`` |
| 67 | + (e.g. ``normalize_embeddings``). |
| 68 | + """ |
47 | 69 | if not SentenceTransformer: |
48 | 70 | raise ImportError( |
49 | 71 | "sentence-transformers is required to use " |
@@ -73,3 +95,103 @@ def get_ptransform_for_processing( |
73 | 95 | return RunInference( |
74 | 96 | model_handler=_TextEmbeddingHandler(self), |
75 | 97 | inference_args=self.inference_args).with_output_types(EmbeddableItem) |
| 98 | + |
| 99 | + |
def _extract_images(items: Sequence[EmbeddableItem]) -> list:
  """Extract images from items and convert them to RGB PIL.Image objects.

  Each item's ``content.image`` may be raw encoded image bytes, or a
  location string (local file path or remote URI such as gs:// or s3://)
  that is opened through Beam's FileSystems.

  Args:
    items: Items whose ``content.image`` holds the image payload or
      location.

  Returns:
    A list with one RGB-converted image per input item, in input order.

  Raises:
    ValueError: If an item's ``content.image`` is missing or empty.
  """
  images = []
  for item in items:
    img_data = item.content.image
    if not img_data:
      raise ValueError(
          "Expected image content in "
          f"{type(item).__name__} {item.id}, "
          "got None")
    if isinstance(img_data, (bytes, bytearray)):
      # In-memory payload: decode straight from a bytes buffer.
      with PILImage.open(io.BytesIO(img_data)) as img:
        images.append(img.convert('RGB'))
    else:
      # FileSystems.open always opens for reading; its second positional
      # parameter is the MIME type, so no mode string is passed here.
      # convert() forces the pixel data to be decoded while the underlying
      # stream is still open and returns an independent image, so both the
      # stream and the source image can be released on exit.
      with FileSystems.open(img_data) as f, PILImage.open(f) as img:
        images.append(img.convert('RGB'))
  return images
| 122 | + |
| 123 | + |
def _create_hf_image_adapter(
) -> EmbeddingTypeAdapter[EmbeddableItem, EmbeddableItem]:
  """Build the type adapter used for HuggingFace image embedding.

  The adapter's input function is ``_extract_images``, which turns each
  item's ``content.image`` into a PIL image; its output function is
  ``_add_embedding_fn``.

  Returns:
    An EmbeddingTypeAdapter wired with the image input/output functions.
  """
  adapter = EmbeddingTypeAdapter(
      input_fn=_extract_images,
      output_fn=_add_embedding_fn,
  )
  return adapter
| 136 | + |
| 137 | + |
class HuggingfaceImageEmbeddings(EmbeddingsManager):
  def __init__(
      self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
    """Image embeddings for RAG pipelines backed by sentence-transformers.

    Produces embeddings for images with a sentence-transformers model that
    accepts image input (e.g. clip-ViT-B-32).

    Args:
      model_name: Name of the sentence-transformers model. It must be an
        image-text model; see
        https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models
      max_seq_length: Maximum sequence length for the model, if applicable.
      **kwargs: Forwarded to
        :class:`~apache_beam.ml.transforms.base.EmbeddingsManager`,
        including:

        - ``load_model_args``: dict passed to the ``SentenceTransformer()``
          constructor (e.g. ``device``, ``cache_folder``,
          ``trust_remote_code``).
        - ``min_batch_size`` / ``max_batch_size``: control batching for
          inference.
        - ``large_model``: if True, share the model across processes to
          reduce memory usage.
        - ``inference_args``: dict passed to ``model.encode()``
          (e.g. ``normalize_embeddings``).

    Raises:
      ImportError: If sentence-transformers or Pillow is not installed.
    """
    # Both optional dependencies are bound to None at import time when
    # missing, so fail early with an actionable message.
    if not SentenceTransformer:
      raise ImportError(
          "sentence-transformers is required to use "
          "HuggingfaceImageEmbeddings. "
          "Please install it with `pip install sentence-transformers`.")
    if not PILImage:
      raise ImportError(
          "Pillow is required to use HuggingfaceImageEmbeddings. "
          "Please install it with `pip install pillow`.")
    image_adapter = _create_hf_image_adapter()
    super().__init__(type_adapter=image_adapter, **kwargs)
    self.model_class = SentenceTransformer
    self.model_name = model_name
    self.max_seq_length = max_seq_length

  def get_model_handler(self):
    """Create the sentence-transformers model handler for this manager."""
    handler_kwargs = {
        'model_class': self.model_class,
        'max_seq_length': self.max_seq_length,
        'model_name': self.model_name,
        'load_model_args': self.load_model_args,
        'min_batch_size': self.min_batch_size,
        'max_batch_size': self.max_batch_size,
        'large_model': self.large_model,
    }
    return _SentenceTransformerModelHandler(**handler_kwargs)

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    """Create the RunInference transform that embeds images."""
    image_handler = _ImageEmbeddingHandler(self)
    inference = RunInference(
        model_handler=image_handler, inference_args=self.inference_args)
    return inference.with_output_types(EmbeddableItem)
0 commit comments