Skip to content

Commit b8a5aa3

Browse files
scd31Stephen D
and
Stephen D
authored
feat: allow embedding image bytes, Nix environment (#150)
* allow embedding images from their byte representation
* address feedback
---------
Co-authored-by: Stephen D <[email protected]>
1 parent 6990869 commit b8a5aa3

File tree

5 files changed

+193
-64
lines changed

5 files changed

+193
-64
lines changed

.envrc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
use flake

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,6 @@ tests/rustdoc-gui/src/**.lock
7575
## Rust files
7676
main.rs
7777
Cargo.lock
78+
79+
## Nix
80+
/.direnv

flake.lock

+62
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

flake.nix

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
inputs = {
3+
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
4+
flake-utils.url = "github:numtide/flake-utils?ref=main";
5+
};
6+
7+
outputs = inputs:
8+
inputs.flake-utils.lib.eachDefaultSystem (system:
9+
let
10+
pkgs = inputs.nixpkgs.legacyPackages.${system};
11+
12+
in {
13+
devShells.default = pkgs.mkShell {
14+
packages = (with pkgs; [
15+
openssl
16+
pkg-config
17+
]);
18+
};
19+
});
20+
}

src/image_embedding/impl.rs

+107-64
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ use hf_hub::{
33
api::sync::{ApiBuilder, ApiRepo},
44
Cache,
55
};
6+
use image::DynamicImage;
67
use ndarray::{Array3, ArrayView3};
78
use ort::{
89
session::{builder::GraphOptimizationLevel, Session},
910
value::Value,
1011
};
1112
#[cfg(feature = "hf-hub")]
1213
use std::path::PathBuf;
13-
use std::{path::Path, thread::available_parallelism};
14+
use std::{io::Cursor, path::Path, thread::available_parallelism};
1415

1516
use crate::{
1617
common::normalize, models::image_embedding::models_list, Embedding, ImageEmbeddingModel,
@@ -132,14 +133,12 @@ impl ImageEmbedding {
132133
.expect("Model not found.")
133134
}
134135

135-
/// Method to generate image embeddings for a Vec of image path
136-
// Generic type to accept String, &str, OsString, &OsStr
137-
pub fn embed<S: AsRef<Path> + Send + Sync>(
136+
/// Method to generate image embeddings for a slice of encoded images, each given as raw bytes
137+
pub fn embed_bytes(
138138
&self,
139-
images: Vec<S>,
139+
images: &[&[u8]],
140140
batch_size: Option<usize>,
141141
) -> anyhow::Result<Vec<Embedding>> {
142-
// Determine the batch size, default if not specified
143142
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
144143

145144
let output = images
@@ -149,72 +148,47 @@ impl ImageEmbedding {
149148
let inputs = batch
150149
.iter()
151150
.map(|img| {
152-
let img = image::ImageReader::open(img)?
151+
image::ImageReader::new(Cursor::new(img))
152+
.with_guessed_format()?
153153
.decode()
154-
.map_err(|err| anyhow!("image decode: {}", err))?;
155-
let pixels = self.preprocessor.transform(TransformData::Image(img))?;
156-
match pixels {
157-
TransformData::NdArray(array) => Ok(array),
158-
_ => Err(anyhow!("Preprocessor configuration error!")),
159-
}
154+
.map_err(|err| anyhow!("image decode: {}", err))
160155
})
161-
.collect::<anyhow::Result<Vec<Array3<f32>>>>()?;
162-
163-
// Extract the batch size
164-
let inputs_view: Vec<ArrayView3<f32>> =
165-
inputs.iter().map(|img| img.view()).collect();
166-
let pixel_values_array = ndarray::stack(ndarray::Axis(0), &inputs_view)?;
156+
.collect::<Result<_, _>>()?;
167157

168-
let input_name = self.session.inputs[0].name.clone();
169-
let session_inputs = ort::inputs![
170-
input_name => Value::from_array(pixel_values_array)?,
171-
]?;
158+
self.embed_images(inputs)
159+
})
160+
.collect::<anyhow::Result<Vec<_>>>()?
161+
.into_iter()
162+
.flatten()
163+
.collect();
172164

173-
let outputs = self.session.run(session_inputs)?;
165+
Ok(output)
166+
}
174167

175-
// Try to get the only output key
176-
// If multiple, then default to few known keys `image_embeds` and `last_hidden_state`
177-
let last_hidden_state_key = match outputs.len() {
178-
1 => vec![outputs.keys().next().unwrap()],
179-
_ => vec!["image_embeds", "last_hidden_state"],
180-
};
168+
/// Method to generate image embeddings for a Vec of image paths
169+
// Generic type to accept String, &str, OsString, &OsStr
170+
pub fn embed<S: AsRef<Path> + Send + Sync>(
171+
&self,
172+
images: Vec<S>,
173+
batch_size: Option<usize>,
174+
) -> anyhow::Result<Vec<Embedding>> {
175+
// Determine the batch size, default if not specified
176+
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
181177

182-
// Extract tensor and handle different dimensionalities
183-
let output_data = last_hidden_state_key
178+
let output = images
179+
.par_chunks(batch_size)
180+
.map(|batch| {
181+
// Open and decode the images in the batch
182+
let inputs = batch
184183
.iter()
185-
.find_map(|&key| {
186-
outputs
187-
.get(key)
188-
.and_then(|v| v.try_extract_tensor::<f32>().ok())
184+
.map(|img| {
185+
image::ImageReader::open(img)?
186+
.decode()
187+
.map_err(|err| anyhow!("image decode: {}", err))
189188
})
190-
.ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?;
191-
let shape = output_data.shape();
192-
193-
let embeddings: Vec<Vec<f32>> = match shape.len() {
194-
3 => {
195-
// For 3D output [batch_size, sequence_length, hidden_size]
196-
// Take only the first token, sequence_length[0] (CLS token), embedding
197-
// and return [batch_size, hidden_size]
198-
(0..shape[0])
199-
.map(|batch_idx| {
200-
let cls_embedding =
201-
output_data.slice(ndarray::s![batch_idx, 0, ..]).to_vec();
202-
normalize(&cls_embedding)
203-
})
204-
.collect()
205-
}
206-
2 => {
207-
// For 2D output [batch_size, hidden_size]
208-
output_data
209-
.rows()
210-
.into_iter()
211-
.map(|row| normalize(row.as_slice().unwrap()))
212-
.collect()
213-
}
214-
_ => return Err(anyhow!("Unexpected output tensor shape: {:?}", shape)),
215-
};
216-
217-
Ok(embeddings)
189+
.collect::<Result<_, _>>()?;
190+
191+
self.embed_images(inputs)
218192
})
219193
.collect::<anyhow::Result<Vec<_>>>()?
220194
.into_iter()
@@ -223,4 +197,73 @@ impl ImageEmbedding {
223197

224198
Ok(output)
225199
}
200+
201+
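/// Method to generate image embeddings for a Vec of already-decoded `DynamicImage`s, run through the model as a single batch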
202+
pub fn embed_images(&self, imgs: Vec<DynamicImage>) -> anyhow::Result<Vec<Embedding>> {
203+
let inputs = imgs
204+
.into_iter()
205+
.map(|img| {
206+
let pixels = self.preprocessor.transform(TransformData::Image(img))?;
207+
match pixels {
208+
TransformData::NdArray(array) => Ok(array),
209+
_ => Err(anyhow!("Preprocessor configuration error!")),
210+
}
211+
})
212+
.collect::<anyhow::Result<Vec<Array3<f32>>>>()?;
213+
214+
// Stack the preprocessed images along axis 0 into a single batch array
215+
let inputs_view: Vec<ArrayView3<f32>> = inputs.iter().map(|img| img.view()).collect();
216+
let pixel_values_array = ndarray::stack(ndarray::Axis(0), &inputs_view)?;
217+
218+
let input_name = self.session.inputs[0].name.clone();
219+
let session_inputs = ort::inputs![
220+
input_name => Value::from_array(pixel_values_array)?,
221+
]?;
222+
223+
let outputs = self.session.run(session_inputs)?;
224+
225+
// Try to get the only output key
226+
// If there are multiple, fall back to a few known keys: `image_embeds` and `last_hidden_state`
227+
let last_hidden_state_key = match outputs.len() {
228+
1 => vec![outputs.keys().next().unwrap()],
229+
_ => vec!["image_embeds", "last_hidden_state"],
230+
};
231+
232+
// Extract tensor and handle different dimensionalities
233+
let output_data = last_hidden_state_key
234+
.iter()
235+
.find_map(|&key| {
236+
outputs
237+
.get(key)
238+
.and_then(|v| v.try_extract_tensor::<f32>().ok())
239+
})
240+
.ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?;
241+
let shape = output_data.shape();
242+
243+
let embeddings = match shape.len() {
244+
3 => {
245+
// For 3D output [batch_size, sequence_length, hidden_size]
246+
// Take only the embedding of the first token (index 0 along sequence_length, the CLS token)
247+
// and return [batch_size, hidden_size]
248+
(0..shape[0])
249+
.map(|batch_idx| {
250+
let cls_embedding =
251+
output_data.slice(ndarray::s![batch_idx, 0, ..]).to_vec();
252+
normalize(&cls_embedding)
253+
})
254+
.collect()
255+
}
256+
2 => {
257+
// For 2D output [batch_size, hidden_size]
258+
output_data
259+
.rows()
260+
.into_iter()
261+
.map(|row| normalize(row.as_slice().unwrap()))
262+
.collect()
263+
}
264+
_ => return Err(anyhow!("Unexpected output tensor shape: {:?}", shape)),
265+
};
266+
267+
Ok(embeddings)
268+
}
226269
}

0 commit comments

Comments
 (0)