@@ -3,14 +3,15 @@ use hf_hub::{
3
3
api:: sync:: { ApiBuilder , ApiRepo } ,
4
4
Cache ,
5
5
} ;
6
+ use image:: DynamicImage ;
6
7
use ndarray:: { Array3 , ArrayView3 } ;
7
8
use ort:: {
8
9
session:: { builder:: GraphOptimizationLevel , Session } ,
9
10
value:: Value ,
10
11
} ;
11
12
#[ cfg( feature = "hf-hub" ) ]
12
13
use std:: path:: PathBuf ;
13
- use std:: { path:: Path , thread:: available_parallelism} ;
14
+ use std:: { io :: Cursor , path:: Path , thread:: available_parallelism} ;
14
15
15
16
use crate :: {
16
17
common:: normalize, models:: image_embedding:: models_list, Embedding , ImageEmbeddingModel ,
@@ -132,14 +133,12 @@ impl ImageEmbedding {
132
133
. expect ( "Model not found." )
133
134
}
134
135
135
- /// Method to generate image embeddings for a Vec of image path
136
- // Generic type to accept String, &str, OsString, &OsStr
137
- pub fn embed < S : AsRef < Path > + Send + Sync > (
136
+ /// Method to generate image embeddings for a slice of encoded image byte buffers
137
+ pub fn embed_bytes (
138
138
& self ,
139
- images : Vec < S > ,
139
+ images : & [ & [ u8 ] ] ,
140
140
batch_size : Option < usize > ,
141
141
) -> anyhow:: Result < Vec < Embedding > > {
142
- // Determine the batch size, default if not specified
143
142
let batch_size = batch_size. unwrap_or ( DEFAULT_BATCH_SIZE ) ;
144
143
145
144
let output = images
@@ -149,72 +148,47 @@ impl ImageEmbedding {
149
148
let inputs = batch
150
149
. iter ( )
151
150
. map ( |img| {
152
- let img = image:: ImageReader :: open ( img) ?
151
+ image:: ImageReader :: new ( Cursor :: new ( img) )
152
+ . with_guessed_format ( ) ?
153
153
. decode ( )
154
- . map_err ( |err| anyhow ! ( "image decode: {}" , err) ) ?;
155
- let pixels = self . preprocessor . transform ( TransformData :: Image ( img) ) ?;
156
- match pixels {
157
- TransformData :: NdArray ( array) => Ok ( array) ,
158
- _ => Err ( anyhow ! ( "Preprocessor configuration error!" ) ) ,
159
- }
154
+ . map_err ( |err| anyhow ! ( "image decode: {}" , err) )
160
155
} )
161
- . collect :: < anyhow:: Result < Vec < Array3 < f32 > > > > ( ) ?;
162
-
163
- // Extract the batch size
164
- let inputs_view: Vec < ArrayView3 < f32 > > =
165
- inputs. iter ( ) . map ( |img| img. view ( ) ) . collect ( ) ;
166
- let pixel_values_array = ndarray:: stack ( ndarray:: Axis ( 0 ) , & inputs_view) ?;
156
+ . collect :: < Result < _ , _ > > ( ) ?;
167
157
168
- let input_name = self . session . inputs [ 0 ] . name . clone ( ) ;
169
- let session_inputs = ort:: inputs![
170
- input_name => Value :: from_array( pixel_values_array) ?,
171
- ] ?;
158
+ self . embed_images ( inputs)
159
+ } )
160
+ . collect :: < anyhow:: Result < Vec < _ > > > ( ) ?
161
+ . into_iter ( )
162
+ . flatten ( )
163
+ . collect ( ) ;
172
164
173
- let outputs = self . session . run ( session_inputs) ?;
165
+ Ok ( output)
166
+ }
174
167
175
- // Try to get the only output key
176
- // If multiple, then default to few known keys `image_embeds` and `last_hidden_state`
177
- let last_hidden_state_key = match outputs. len ( ) {
178
- 1 => vec ! [ outputs. keys( ) . next( ) . unwrap( ) ] ,
179
- _ => vec ! [ "image_embeds" , "last_hidden_state" ] ,
180
- } ;
168
+ /// Method to generate image embeddings for a Vec of image paths
169
+ // Generic type to accept String, &str, OsString, &OsStr
170
+ pub fn embed < S : AsRef < Path > + Send + Sync > (
171
+ & self ,
172
+ images : Vec < S > ,
173
+ batch_size : Option < usize > ,
174
+ ) -> anyhow:: Result < Vec < Embedding > > {
175
+ // Determine the batch size, default if not specified
176
+ let batch_size = batch_size. unwrap_or ( DEFAULT_BATCH_SIZE ) ;
181
177
182
- // Extract tensor and handle different dimensionalities
183
- let output_data = last_hidden_state_key
178
+ let output = images
179
+ . par_chunks ( batch_size)
180
+ . map ( |batch| {
181
+ // Open and decode the images in the batch
182
+ let inputs = batch
184
183
. iter ( )
185
- . find_map ( | & key | {
186
- outputs
187
- . get ( key )
188
- . and_then ( |v| v . try_extract_tensor :: < f32 > ( ) . ok ( ) )
184
+ . map ( |img | {
185
+ image :: ImageReader :: open ( img ) ?
186
+ . decode ( )
187
+ . map_err ( |err| anyhow ! ( "image decode: {}" , err ) )
189
188
} )
190
- . ok_or_else ( || anyhow ! ( "Could not extract tensor from any known output key" ) ) ?;
191
- let shape = output_data. shape ( ) ;
192
-
193
- let embeddings: Vec < Vec < f32 > > = match shape. len ( ) {
194
- 3 => {
195
- // For 3D output [batch_size, sequence_length, hidden_size]
196
- // Take only the first token, sequence_length[0] (CLS token), embedding
197
- // and return [batch_size, hidden_size]
198
- ( 0 ..shape[ 0 ] )
199
- . map ( |batch_idx| {
200
- let cls_embedding =
201
- output_data. slice ( ndarray:: s![ batch_idx, 0 , ..] ) . to_vec ( ) ;
202
- normalize ( & cls_embedding)
203
- } )
204
- . collect ( )
205
- }
206
- 2 => {
207
- // For 2D output [batch_size, hidden_size]
208
- output_data
209
- . rows ( )
210
- . into_iter ( )
211
- . map ( |row| normalize ( row. as_slice ( ) . unwrap ( ) ) )
212
- . collect ( )
213
- }
214
- _ => return Err ( anyhow ! ( "Unexpected output tensor shape: {:?}" , shape) ) ,
215
- } ;
216
-
217
- Ok ( embeddings)
189
+ . collect :: < Result < _ , _ > > ( ) ?;
190
+
191
+ self . embed_images ( inputs)
218
192
} )
219
193
. collect :: < anyhow:: Result < Vec < _ > > > ( ) ?
220
194
. into_iter ( )
@@ -223,4 +197,73 @@ impl ImageEmbedding {
223
197
224
198
Ok ( output)
225
199
}
200
+
201
+ /// Embed DynamicImages
202
+ pub fn embed_images ( & self , imgs : Vec < DynamicImage > ) -> anyhow:: Result < Vec < Embedding > > {
203
+ let inputs = imgs
204
+ . into_iter ( )
205
+ . map ( |img| {
206
+ let pixels = self . preprocessor . transform ( TransformData :: Image ( img) ) ?;
207
+ match pixels {
208
+ TransformData :: NdArray ( array) => Ok ( array) ,
209
+ _ => Err ( anyhow ! ( "Preprocessor configuration error!" ) ) ,
210
+ }
211
+ } )
212
+ . collect :: < anyhow:: Result < Vec < Array3 < f32 > > > > ( ) ?;
213
+
214
+ // Extract the batch size
215
+ let inputs_view: Vec < ArrayView3 < f32 > > = inputs. iter ( ) . map ( |img| img. view ( ) ) . collect ( ) ;
216
+ let pixel_values_array = ndarray:: stack ( ndarray:: Axis ( 0 ) , & inputs_view) ?;
217
+
218
+ let input_name = self . session . inputs [ 0 ] . name . clone ( ) ;
219
+ let session_inputs = ort:: inputs![
220
+ input_name => Value :: from_array( pixel_values_array) ?,
221
+ ] ?;
222
+
223
+ let outputs = self . session . run ( session_inputs) ?;
224
+
225
+ // Try to get the only output key
226
+ // If multiple, then default to few known keys `image_embeds` and `last_hidden_state`
227
+ let last_hidden_state_key = match outputs. len ( ) {
228
+ 1 => vec ! [ outputs. keys( ) . next( ) . unwrap( ) ] ,
229
+ _ => vec ! [ "image_embeds" , "last_hidden_state" ] ,
230
+ } ;
231
+
232
+ // Extract tensor and handle different dimensionalities
233
+ let output_data = last_hidden_state_key
234
+ . iter ( )
235
+ . find_map ( |& key| {
236
+ outputs
237
+ . get ( key)
238
+ . and_then ( |v| v. try_extract_tensor :: < f32 > ( ) . ok ( ) )
239
+ } )
240
+ . ok_or_else ( || anyhow ! ( "Could not extract tensor from any known output key" ) ) ?;
241
+ let shape = output_data. shape ( ) ;
242
+
243
+ let embeddings = match shape. len ( ) {
244
+ 3 => {
245
+ // For 3D output [batch_size, sequence_length, hidden_size]
246
+ // Take only the first token, sequence_length[0] (CLS token), embedding
247
+ // and return [batch_size, hidden_size]
248
+ ( 0 ..shape[ 0 ] )
249
+ . map ( |batch_idx| {
250
+ let cls_embedding =
251
+ output_data. slice ( ndarray:: s![ batch_idx, 0 , ..] ) . to_vec ( ) ;
252
+ normalize ( & cls_embedding)
253
+ } )
254
+ . collect ( )
255
+ }
256
+ 2 => {
257
+ // For 2D output [batch_size, hidden_size]
258
+ output_data
259
+ . rows ( )
260
+ . into_iter ( )
261
+ . map ( |row| normalize ( row. as_slice ( ) . unwrap ( ) ) )
262
+ . collect ( )
263
+ }
264
+ _ => return Err ( anyhow ! ( "Unexpected output tensor shape: {:?}" , shape) ) ,
265
+ } ;
266
+
267
+ Ok ( embeddings)
268
+ }
226
269
}
0 commit comments