Skip to content

Commit efa636c

Browse files
style: apply rustfmt formatting
1 parent 6cd46a7 commit efa636c

6 files changed

Lines changed: 51 additions & 39 deletions

File tree

crates/runtime/src/embeddings/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616
pub mod common;
17-
pub mod params;
1817
pub mod connector;
1918
pub mod execution_plan;
19+
pub mod params;
2020

2121
pub mod index;
2222
pub mod metrics;

crates/runtime/src/embeddings/params/azure.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ const AZURE_PARAM_LEN: usize = 5;
2121
pub const PARAMETERS: &[ParameterSpec] = &AZURE_PARAMETERS;
2222

2323
pub(crate) const AZURE_PARAMETERS: [ParameterSpec; AZURE_PARAM_LEN] = [
24-
ParameterSpec::runtime("endpoint")
25-
.description("The Azure OpenAI resource endpoint, e.g., https://resource-name.openai.azure.com."),
24+
ParameterSpec::runtime("endpoint").description(
25+
"The Azure OpenAI resource endpoint, e.g., https://resource-name.openai.azure.com.",
26+
),
2627
ParameterSpec::component("api_version")
2728
.description("The API version used for the Azure OpenAI service."),
28-
ParameterSpec::component("deployment_name")
29-
.description("The name of the model deployment."),
29+
ParameterSpec::component("deployment_name").description("The name of the model deployment."),
3030
ParameterSpec::component("api_key")
3131
.secret()
3232
.description("The Azure OpenAI API key."),

crates/runtime/src/embeddings/params/databricks.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,18 @@ const DATABRICKS_PARAM_LEN: usize = 4;
2121
pub const PARAMETERS: &[ParameterSpec] = &DATABRICKS_PARAMETERS;
2222

2323
pub(crate) const DATABRICKS_PARAMETERS: [ParameterSpec; DATABRICKS_PARAM_LEN] = [
24-
ParameterSpec::component("endpoint")
25-
.description("The Databricks workspace endpoint, e.g., dbc-a12cd3e4-56f7.cloud.databricks.com."),
24+
ParameterSpec::component("endpoint").description(
25+
"The Databricks workspace endpoint, e.g., dbc-a12cd3e4-56f7.cloud.databricks.com.",
26+
),
2627
ParameterSpec::component("token")
2728
.secret()
2829
.description("The Databricks API token."),
29-
ParameterSpec::component("client_id")
30-
.description("The Databricks Service Principal Client ID. Cannot be used with databricks_token."),
30+
ParameterSpec::component("client_id").description(
31+
"The Databricks Service Principal Client ID. Cannot be used with databricks_token.",
32+
),
3133
ParameterSpec::component("client_secret")
3234
.secret()
33-
.description("The Databricks Service Principal Client Secret. Cannot be used with databricks_token."),
35+
.description(
36+
"The Databricks Service Principal Client Secret. Cannot be used with databricks_token.",
37+
),
3438
];

crates/runtime/src/embeddings/params/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ pub mod huggingface;
2323
pub mod model2vec;
2424
pub mod openai;
2525

26-
use spicepod::component::embeddings::EmbeddingPrefix;
2726
pub use crate::parameters::ParameterSpec;
27+
use spicepod::component::embeddings::EmbeddingPrefix;
2828

2929
/// Returns the parameter specifications for a given embedding source.
3030
#[must_use]

crates/runtime/src/embeddings/params/openai.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ pub(crate) const OPENAI_PARAMETERS: [ParameterSpec; OPENAI_PARAM_LEN] = [
2727
ParameterSpec::component("api_key")
2828
.secret()
2929
.description("The OpenAI API key."),
30-
ParameterSpec::component("org_id")
31-
.description("The OpenAI organization ID."),
32-
ParameterSpec::component("project_id")
33-
.description("The OpenAI project ID."),
30+
ParameterSpec::component("org_id").description("The OpenAI organization ID."),
31+
ParameterSpec::component("project_id").description("The OpenAI project ID."),
3432
ParameterSpec::component("usage_tier")
3533
.description("The current usage tier for the OpenAI account: 'free', 'tier1'-'tier5'.")
3634
.one_of(&["free", "tier1", "tier2", "tier3", "tier4", "tier5"])

crates/runtime/src/model/embed.rs

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#![allow(clippy::implicit_hasher)]
1717

1818
use crate::token_providers::databricks::{DatabricksM2MTokenProvider, DatabricksU2MTokenProvider};
19+
use crate::{embeddings::params::get_params_spec, parameters::Parameters};
1920
use bytes::Bytes;
2021
use cache::CacheProvider;
2122
use cache::result::embeddings::CachedEmbeddingResult;
@@ -31,10 +32,6 @@ use llms::bedrock::{
3132
},
3233
};
3334
use runtime_secrets::{Secrets, get_params_with_secrets};
34-
use crate::{
35-
embeddings::params::get_params_spec,
36-
parameters::Parameters,
37-
};
3835

3936
#[cfg(feature = "models")]
4037
use llms::embeddings::candle::{download_hf_file, tei::TeiEmbed};
@@ -62,8 +59,6 @@ use url::Url;
6259

6360
pub type EmbeddingModelStore = HashMap<String, Arc<dyn Embed>>;
6461

65-
66-
6762
pub async fn try_to_embedding(
6863
component: &spicepod::component::embeddings::Embeddings,
6964
secrets: Arc<RwLock<Secrets>>,
@@ -102,9 +97,7 @@ pub async fn try_to_embedding(
10297
param_spec,
10398
)
10499
.await
105-
.map_err(|e| EmbedError::FailedToInstantiateEmbeddingModel {
106-
source: e,
107-
})?;
100+
.map_err(|e| EmbedError::FailedToInstantiateEmbeddingModel { source: e })?;
108101

109102
match prefix {
110103
EmbeddingPrefix::Azure => azure(
@@ -240,7 +233,10 @@ fn google(
240233
});
241234
};
242235

243-
let dimensions: Option<u32> = params.get("dimensions").expose().ok()
236+
let dimensions: Option<u32> = params
237+
.get("dimensions")
238+
.expose()
239+
.ok()
244240
.map(|d| d.parse::<u32>())
245241
.transpose()
246242
// Only error if user provided dimensions.
@@ -278,15 +274,21 @@ async fn bedrock(
278274
.map_err(|e| EmbedError::FailedToInstantiateEmbeddingModel { source: e })?;
279275

280276
if model_id.starts_with("amazon.titan-embed") {
281-
let normalize = params.get("normalize").expose().ok()
277+
let normalize = params
278+
.get("normalize")
279+
.expose()
280+
.ok()
282281
.map(|s| s.parse::<bool>())
283282
.transpose()
284283
.map_err(|e| EmbedError::FailedToInstantiateEmbeddingModel {
285284
source: format!("Failed to parse 'normalize' parameter: {e}").into(),
286285
})?
287286
.unwrap_or(true);
288287

289-
let Some(dimensions) = params.get("dimensions").expose().ok()
288+
let Some(dimensions) = params
289+
.get("dimensions")
290+
.expose()
291+
.ok()
290292
.map(|s| s.parse::<u32>())
291293
.transpose()
292294
.map_err(|e| EmbedError::FailedToInstantiateEmbeddingModel {
@@ -311,8 +313,11 @@ async fn bedrock(
311313
bedrock::embed::new_titan_v2(client, normalize, dimensions).set_cache(embeddings_cache),
312314
) as Arc<dyn Embed>)
313315
} else if model_id.starts_with("cohere.embed") {
314-
let truncate = if let Some(truncate_str) =
315-
params.get("truncate_mode").expose().ok().or_else(|| params.get("truncate").expose().ok())
316+
let truncate = if let Some(truncate_str) = params
317+
.get("truncate_mode")
318+
.expose()
319+
.ok()
320+
.or_else(|| params.get("truncate").expose().ok())
316321
{
317322
CohereEmbeddingTruncate::from_str(truncate_str)
318323
.boxed()
@@ -345,7 +350,10 @@ async fn bedrock(
345350
.set_cache(embeddings_cache),
346351
) as Arc<dyn Embed>)
347352
} else if model_id.starts_with("amazon.nova-2-multimodal-embeddings") {
348-
let Some(dimensions) = params.get("dimensions").expose().ok()
353+
let Some(dimensions) = params
354+
.get("dimensions")
355+
.expose()
356+
.ok()
349357
.map(|s| s.parse::<u32>())
350358
.transpose()
351359
.map_err(|e| EmbedError::FailedToInstantiateEmbeddingModel {
@@ -379,8 +387,11 @@ async fn bedrock(
379387
})?
380388
.unwrap_or_default();
381389

382-
let truncate = if let Some(truncate_str) =
383-
params.get("truncate_mode").expose().ok().or_else(|| params.get("truncate").expose().ok())
390+
let truncate = if let Some(truncate_str) = params
391+
.get("truncate_mode")
392+
.expose()
393+
.ok()
394+
.or_else(|| params.get("truncate").expose().ok())
384395
{
385396
NovaTruncationMode::from_str(truncate_str)
386397
.boxed()
@@ -638,20 +649,19 @@ async fn openai(
638649
) -> Result<Arc<dyn Embed>, EmbedError> {
639650
// If parameter is from secret store, it will have `openai_` prefix
640651
let openai_usage_tier = params
641-
.get("usage_tier").expose().ok()
652+
.get("usage_tier")
653+
.expose()
654+
.ok()
642655
.map(UsageTier::from_str)
643656
.transpose()?;
644657

645658
let mut embed = OpenaiEmbed::new(
646659
llms::openai::new_openai_client(
647660
model_id.unwrap_or(DEFAULT_EMBEDDING_MODEL.to_string()),
648661
params.get("endpoint").expose().ok(),
649-
params
650-
.get("api_key").expose().ok(),
651-
params
652-
.get("org_id").expose().ok(),
653-
params
654-
.get("project_id").expose().ok(),
662+
params.get("api_key").expose().ok(),
663+
params.get("org_id").expose().ok(),
664+
params.get("project_id").expose().ok(),
655665
openai_usage_tier,
656666
),
657667
openai_usage_tier.map(Into::into),

0 commit comments

Comments
 (0)