@@ -16,6 +16,7 @@ limitations under the License.
1616#![ allow( clippy:: implicit_hasher) ]
1717
1818use crate :: token_providers:: databricks:: { DatabricksM2MTokenProvider , DatabricksU2MTokenProvider } ;
19+ use crate :: { embeddings:: params:: get_params_spec, parameters:: Parameters } ;
1920use bytes:: Bytes ;
2021use cache:: CacheProvider ;
2122use cache:: result:: embeddings:: CachedEmbeddingResult ;
@@ -31,10 +32,6 @@ use llms::bedrock::{
3132 } ,
3233} ;
3334use runtime_secrets:: { Secrets , get_params_with_secrets} ;
34- use crate :: {
35- embeddings:: params:: get_params_spec,
36- parameters:: Parameters ,
37- } ;
3835
3936#[ cfg( feature = "models" ) ]
4037use llms:: embeddings:: candle:: { download_hf_file, tei:: TeiEmbed } ;
@@ -62,8 +59,6 @@ use url::Url;
6259
6360pub type EmbeddingModelStore = HashMap < String , Arc < dyn Embed > > ;
6461
65-
66-
6762pub async fn try_to_embedding (
6863 component : & spicepod:: component:: embeddings:: Embeddings ,
6964 secrets : Arc < RwLock < Secrets > > ,
@@ -102,9 +97,7 @@ pub async fn try_to_embedding(
10297 param_spec,
10398 )
10499 . await
105- . map_err ( |e| EmbedError :: FailedToInstantiateEmbeddingModel {
106- source : e,
107- } ) ?;
100+ . map_err ( |e| EmbedError :: FailedToInstantiateEmbeddingModel { source : e } ) ?;
108101
109102 match prefix {
110103 EmbeddingPrefix :: Azure => azure (
@@ -240,7 +233,10 @@ fn google(
240233 } ) ;
241234 } ;
242235
243- let dimensions: Option < u32 > = params. get ( "dimensions" ) . expose ( ) . ok ( )
236+ let dimensions: Option < u32 > = params
237+ . get ( "dimensions" )
238+ . expose ( )
239+ . ok ( )
244240 . map ( |d| d. parse :: < u32 > ( ) )
245241 . transpose ( )
246242 // Only error if user provided dimensions.
@@ -278,15 +274,21 @@ async fn bedrock(
278274 . map_err ( |e| EmbedError :: FailedToInstantiateEmbeddingModel { source : e } ) ?;
279275
280276 if model_id. starts_with ( "amazon.titan-embed" ) {
281- let normalize = params. get ( "normalize" ) . expose ( ) . ok ( )
277+ let normalize = params
278+ . get ( "normalize" )
279+ . expose ( )
280+ . ok ( )
282281 . map ( |s| s. parse :: < bool > ( ) )
283282 . transpose ( )
284283 . map_err ( |e| EmbedError :: FailedToInstantiateEmbeddingModel {
285284 source : format ! ( "Failed to parse 'normalize' parameter: {e}" ) . into ( ) ,
286285 } ) ?
287286 . unwrap_or ( true ) ;
288287
289- let Some ( dimensions) = params. get ( "dimensions" ) . expose ( ) . ok ( )
288+ let Some ( dimensions) = params
289+ . get ( "dimensions" )
290+ . expose ( )
291+ . ok ( )
290292 . map ( |s| s. parse :: < u32 > ( ) )
291293 . transpose ( )
292294 . map_err ( |e| EmbedError :: FailedToInstantiateEmbeddingModel {
@@ -311,8 +313,11 @@ async fn bedrock(
311313 bedrock:: embed:: new_titan_v2 ( client, normalize, dimensions) . set_cache ( embeddings_cache) ,
312314 ) as Arc < dyn Embed > )
313315 } else if model_id. starts_with ( "cohere.embed" ) {
314- let truncate = if let Some ( truncate_str) =
315- params. get ( "truncate_mode" ) . expose ( ) . ok ( ) . or_else ( || params. get ( "truncate" ) . expose ( ) . ok ( ) )
316+ let truncate = if let Some ( truncate_str) = params
317+ . get ( "truncate_mode" )
318+ . expose ( )
319+ . ok ( )
320+ . or_else ( || params. get ( "truncate" ) . expose ( ) . ok ( ) )
316321 {
317322 CohereEmbeddingTruncate :: from_str ( truncate_str)
318323 . boxed ( )
@@ -345,7 +350,10 @@ async fn bedrock(
345350 . set_cache ( embeddings_cache) ,
346351 ) as Arc < dyn Embed > )
347352 } else if model_id. starts_with ( "amazon.nova-2-multimodal-embeddings" ) {
348- let Some ( dimensions) = params. get ( "dimensions" ) . expose ( ) . ok ( )
353+ let Some ( dimensions) = params
354+ . get ( "dimensions" )
355+ . expose ( )
356+ . ok ( )
349357 . map ( |s| s. parse :: < u32 > ( ) )
350358 . transpose ( )
351359 . map_err ( |e| EmbedError :: FailedToInstantiateEmbeddingModel {
@@ -379,8 +387,11 @@ async fn bedrock(
379387 } ) ?
380388 . unwrap_or_default ( ) ;
381389
382- let truncate = if let Some ( truncate_str) =
383- params. get ( "truncate_mode" ) . expose ( ) . ok ( ) . or_else ( || params. get ( "truncate" ) . expose ( ) . ok ( ) )
390+ let truncate = if let Some ( truncate_str) = params
391+ . get ( "truncate_mode" )
392+ . expose ( )
393+ . ok ( )
394+ . or_else ( || params. get ( "truncate" ) . expose ( ) . ok ( ) )
384395 {
385396 NovaTruncationMode :: from_str ( truncate_str)
386397 . boxed ( )
@@ -638,20 +649,19 @@ async fn openai(
638649) -> Result < Arc < dyn Embed > , EmbedError > {
639650 // If parameter is from secret store, it will have `openai_` prefix
640651 let openai_usage_tier = params
641- . get ( "usage_tier" ) . expose ( ) . ok ( )
652+ . get ( "usage_tier" )
653+ . expose ( )
654+ . ok ( )
642655 . map ( UsageTier :: from_str)
643656 . transpose ( ) ?;
644657
645658 let mut embed = OpenaiEmbed :: new (
646659 llms:: openai:: new_openai_client (
647660 model_id. unwrap_or ( DEFAULT_EMBEDDING_MODEL . to_string ( ) ) ,
648661 params. get ( "endpoint" ) . expose ( ) . ok ( ) ,
649- params
650- . get ( "api_key" ) . expose ( ) . ok ( ) ,
651- params
652- . get ( "org_id" ) . expose ( ) . ok ( ) ,
653- params
654- . get ( "project_id" ) . expose ( ) . ok ( ) ,
662+ params. get ( "api_key" ) . expose ( ) . ok ( ) ,
663+ params. get ( "org_id" ) . expose ( ) . ok ( ) ,
664+ params. get ( "project_id" ) . expose ( ) . ok ( ) ,
655665 openai_usage_tier,
656666 ) ,
657667 openai_usage_tier. map ( Into :: into) ,
0 commit comments