Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 32 additions & 39 deletions src/loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
//! and Hugging Face Hub lock acquisition failures.

use crate::core::GenerationConfig;
use serde::Deserialize;
use hf_hub::api::tokio::Api as HfApi;
use std::path::PathBuf;
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -110,6 +111,19 @@ pub struct GenerationConfigLoader {
pub generation_config_file_loader: HfLoader,
}

/// Deserialization target for the contents of a generation-config JSON file.
///
/// Every field is an `Option`, so a key that is absent from the file (or set
/// to `null`) deserializes to `None` instead of failing the whole parse.
#[derive(Deserialize)]
struct RawGenerationConfig {
    temperature: Option<f64>,
    top_p: Option<f64>,
    top_k: Option<u64>,
    min_p: Option<f64>,
    /// Accepted under the key `repeat_penalty` (the field's own name) or
    /// `repetition_penalty`.
    // The field name is always accepted by serde, so only the alternative
    // spelling needs to be listed as an alias.
    // NOTE(review): a file that contains BOTH spellings is presumably rejected
    // by the derived deserializer as a duplicate field — confirm this is the
    // intended behavior.
    #[serde(alias = "repetition_penalty")]
    repeat_penalty: Option<f32>,
    repeat_last_n: Option<usize>,
    /// Accepted under the key `eos_token_ids` (the field's own name) or
    /// `eos_token_id`.
    ///
    /// Kept as an untyped JSON value because the file may carry either a
    /// single number or an array of numbers; the caller inspects the variant.
    // NOTE(review): as above, a file carrying both `eos_token_id` and
    // `eos_token_ids` presumably fails with a duplicate-field error — confirm.
    #[serde(alias = "eos_token_id")]
    eos_token_ids: Option<serde_json::Value>,
}

impl GenerationConfigLoader {
pub fn new(repo: &str, filename: &str) -> Self {
let generation_config_file_loader = HfLoader::new(repo, filename);
Expand All @@ -128,49 +142,28 @@ impl GenerationConfigLoader {
let generation_config_content =
std::fs::read_to_string(generation_config_file_path)?;

let config_json: serde_json::Value = serde_json::from_str(&generation_config_content)?;

// All fields are optional to handle inconsistent JSON files
let temperature = config_json.get("temperature").and_then(|v| v.as_f64());
let top_p = config_json.get("top_p").and_then(|v| v.as_f64());
let top_k = config_json.get("top_k").and_then(|v| v.as_u64());
let min_p = config_json.get("min_p").and_then(|v| v.as_f64());
let repeat_penalty = config_json
.get("repetition_penalty")
.or_else(|| config_json.get("repeat_penalty"))
.and_then(|v| v.as_f64())
.map(|v| v as f32);
let repeat_last_n = config_json
.get("repeat_last_n")
.and_then(|v| v.as_u64())
.map(|v| v as usize);

// Handle both single EOS token ID and array of EOS token IDs
let eos_token_ids = match config_json.get("eos_token_id") {
Some(serde_json::Value::Number(n)) => vec![n.as_u64().expect("Invalid EOS token ID")],
let raw: RawGenerationConfig = serde_json::from_str(&generation_config_content)?;

let eos_token_ids = match raw.eos_token_ids {
Some(serde_json::Value::Number(n)) => vec![n
.as_u64()
.ok_or_else(|| anyhow::anyhow!("Invalid EOS token ID"))?],
Some(serde_json::Value::Array(arr)) => arr
.iter()
.map(|v| v.as_u64().expect("Invalid EOS token ID in array"))
.collect(),
_ => {
// Try alternative field names
match config_json.get("eos_token_ids") {
Some(serde_json::Value::Array(arr)) => arr
.iter()
.map(|v| v.as_u64().expect("Invalid EOS token ID in array"))
.collect(),
_ => vec![], // Empty vec instead of panic
}
}
.into_iter()
.map(|v| v
.as_u64()
.ok_or_else(|| anyhow::anyhow!("Invalid EOS token ID in array")))
.collect::<Result<Vec<_>, _>>()?,
_ => Vec::new(),
};

Ok(GenerationConfig {
temperature,
top_p,
top_k,
min_p,
repeat_penalty,
repeat_last_n,
temperature: raw.temperature,
top_p: raw.top_p,
top_k: raw.top_k,
min_p: raw.min_p,
repeat_penalty: raw.repeat_penalty,
repeat_last_n: raw.repeat_last_n,
eos_token_ids,
})
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/implementations/gemma3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ Pipeline stuff

*/

use crate::pipelines::text_generation_pipeline::text_generation_model::{
use crate::pipelines::text_generation_pipeline::model::{
LanguageModelContext, TextGenerationModel,
};
use async_trait::async_trait;
Expand Down
6 changes: 3 additions & 3 deletions src/models/implementations/modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ impl FillMaskModernBertModel {
}
}

impl crate::pipelines::fill_mask_pipeline::fill_mask_model::FillMaskModel
impl crate::pipelines::fill_mask_pipeline::model::FillMaskModel
for FillMaskModernBertModel
{
type Options = ModernBertSize;
Expand Down Expand Up @@ -1016,7 +1016,7 @@ impl ZeroShotModernBertModel {
}
}

impl crate::pipelines::zero_shot_classification_pipeline::zero_shot_classification_model::ZeroShotClassificationModel
impl crate::pipelines::zero_shot_classification_pipeline::model::ZeroShotClassificationModel
for ZeroShotModernBertModel
{
type Options = ModernBertSize;
Expand Down Expand Up @@ -1205,7 +1205,7 @@ impl SentimentModernBertModel {
}
}

impl crate::pipelines::sentiment_analysis_pipeline::sentiment_analysis_model::SentimentAnalysisModel
impl crate::pipelines::sentiment_analysis_pipeline::model::SentimentAnalysisModel
for SentimentModernBertModel
{
type Options = ModernBertSize;
Expand Down
4 changes: 2 additions & 2 deletions src/models/implementations/qwen3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ Pipeline Stuff

*/

use crate::pipelines::text_generation_pipeline::text_generation_model::{
use crate::pipelines::text_generation_pipeline::model::{
LanguageModelContext, TextGenerationModel, ToggleableReasoning, ToolCalling,
};

Expand Down Expand Up @@ -999,7 +999,7 @@ impl ToggleableReasoning for Qwen3Model {
}

use crate::core::ToolError;
use crate::pipelines::text_generation_pipeline::text_generation_model::Tool;
use crate::pipelines::text_generation_pipeline::model::Tool;
use async_trait::async_trait;

impl ToolCalling for Qwen3Model {
Expand Down
2 changes: 1 addition & 1 deletion src/models/implementations/qwen3_embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn l2_normalise(t: Tensor) -> Result<Tensor> {
t.broadcast_div(&norm)
}

use crate::pipelines::embedding_pipeline::embedding_model::EmbeddingModel;
use crate::pipelines::embedding_pipeline::model::EmbeddingModel;

impl EmbeddingModel for Qwen3EmbeddingModel {
type Options = Qwen3EmbeddingSize;
Expand Down
4 changes: 2 additions & 2 deletions src/models/implementations/qwen3_reranker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ impl Qwen3RerankModel {
}


use crate::pipelines::reranker_pipeline::reranker_model::RerankModel;
use crate::pipelines::reranker_pipeline::reranker_pipeline::RerankResult;
use crate::pipelines::reranker_pipeline::model::RerankModel;
use crate::pipelines::reranker_pipeline::pipeline::RerankResult;

impl RerankModel for Qwen3RerankModel {
type Options = Qwen3RerankSize;
Expand Down
9 changes: 5 additions & 4 deletions src/pipelines/embedding_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::embedding_model::EmbeddingModel;
use super::embedding_pipeline::EmbeddingPipeline;
use super::model::EmbeddingModel;
use super::pipeline::EmbeddingPipeline;
use std::sync::Arc;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::DeviceRequest;
use crate::pipelines::utils::{build_cache_key, DeviceRequest};

pub struct EmbeddingPipelineBuilder<M: EmbeddingModel> {
options: M::Options,
Expand Down Expand Up @@ -38,7 +38,7 @@ impl<M: EmbeddingModel> EmbeddingPipelineBuilder<M> {
M::Options: ModelOptions + Clone,
{
let device = self.device_request.resolve()?;
let key = format!("{}-{:?}", self.options.cache_key(), device.location());
let key = build_cache_key(&self.options, &device);
let model = global_cache()
.get_or_create(&key, || M::new(self.options.clone(), device.clone()))
.await?;
Expand All @@ -52,3 +52,4 @@ impl EmbeddingPipelineBuilder<crate::models::implementations::qwen3_embeddings::
Self::new(size)
}
}

8 changes: 4 additions & 4 deletions src/pipelines/embedding_pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
//! ```

pub mod builder;
pub mod embedding_model;
pub mod embedding_pipeline;
pub mod model;
pub mod pipeline;

pub use builder::EmbeddingPipelineBuilder;
pub use embedding_model::EmbeddingModel;
pub use embedding_pipeline::EmbeddingPipeline;
pub use model::EmbeddingModel;
pub use pipeline::EmbeddingPipeline;

pub use crate::models::implementations::qwen3_embeddings::Qwen3EmbeddingModel;
pub use crate::models::implementations::qwen3_embeddings::Qwen3EmbeddingSize;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::embedding_model::EmbeddingModel;
use super::model::EmbeddingModel;
use tokenizers::Tokenizer;
use std::sync::Arc;

Expand Down
25 changes: 11 additions & 14 deletions src/pipelines/fill_mask_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,33 @@
use super::fill_mask_model::FillMaskModel;
use super::fill_mask_pipeline::FillMaskPipeline;
use super::model::FillMaskModel;
use super::pipeline::FillMaskPipeline;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::{build_cache_key, DeviceRequest};

pub struct FillMaskPipelineBuilder<M: FillMaskModel> {
options: M::Options,
device: Option<candle_core::Device>,
device_request: DeviceRequest,
}

impl<M: FillMaskModel> FillMaskPipelineBuilder<M> {
pub fn new(options: M::Options) -> Self {
Self {
options,
device: None,
device_request: DeviceRequest::Default,
}
}

pub fn cpu(mut self) -> Self {
self.device = Some(candle_core::Device::Cpu);
self.device_request = DeviceRequest::Cpu;
self
}

pub fn cuda_device(mut self, index: usize) -> Self {
let dev =
candle_core::Device::new_cuda_with_stream(index).unwrap_or(candle_core::Device::Cpu);
self.device = Some(dev);
self.device_request = DeviceRequest::Cuda(index);
self
}

pub fn device(mut self, device: candle_core::Device) -> Self {
self.device = Some(device);
self.device_request = DeviceRequest::Explicit(device);
self
}

Expand All @@ -37,11 +36,8 @@ impl<M: FillMaskModel> FillMaskPipelineBuilder<M> {
M: Clone + Send + Sync + 'static,
M::Options: ModelOptions + Clone,
{
let device = match self.device {
Some(d) => d,
None => crate::pipelines::utils::load_device()?,
};
let key = format!("{}-{:?}", self.options.cache_key(), device.location());
let device = self.device_request.resolve()?;
let key = build_cache_key(&self.options, &device);
let model = global_cache()
.get_or_create(&key, || M::new(self.options.clone(), device.clone()))
.await?;
Expand All @@ -55,3 +51,4 @@ impl FillMaskPipelineBuilder<crate::models::implementations::modernbert::FillMas
Self::new(size)
}
}

8 changes: 4 additions & 4 deletions src/pipelines/fill_mask_pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
//! ```

pub mod builder;
pub mod fill_mask_model;
pub mod fill_mask_pipeline;
pub mod model;
pub mod pipeline;

pub use builder::FillMaskPipelineBuilder;
pub use fill_mask_model::FillMaskModel;
pub use fill_mask_pipeline::FillMaskPipeline;
pub use model::FillMaskModel;
pub use pipeline::FillMaskPipeline;

pub use crate::models::ModernBertSize;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::fill_mask_model::FillMaskModel;
use super::model::FillMaskModel;
use tokenizers::Tokenizer;

pub struct FillMaskPipeline<M: FillMaskModel> {
Expand Down
10 changes: 5 additions & 5 deletions src/pipelines/reranker_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::reranker_model::RerankModel;
use super::reranker_pipeline::RerankPipeline;
use super::model::RerankModel;
use super::pipeline::RerankPipeline;
use std::sync::Arc;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::DeviceRequest;
use crate::pipelines::utils::{build_cache_key, DeviceRequest};

pub struct RerankPipelineBuilder<M: RerankModel> {
options: M::Options,
Expand Down Expand Up @@ -38,7 +38,7 @@ impl<M: RerankModel> RerankPipelineBuilder<M> {
M::Options: ModelOptions + Clone,
{
let device = self.device_request.resolve()?;
let key = format!("{}-{:?}", self.options.cache_key(), device.location());
let key = build_cache_key(&self.options, &device);
let model = global_cache()
.get_or_create(&key, || M::new(self.options.clone(), device.clone()))
.await?;
Expand All @@ -51,4 +51,4 @@ impl RerankPipelineBuilder<crate::models::implementations::qwen3_reranker::Qwen3
pub fn qwen3(size: crate::models::implementations::qwen3_reranker::Qwen3RerankSize) -> Self {
Self::new(size)
}
}
}
10 changes: 5 additions & 5 deletions src/pipelines/reranker_pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
//! ```

pub mod builder;
pub mod reranker_model;
pub mod reranker_pipeline;
pub mod model;
pub mod pipeline;

pub use builder::RerankPipelineBuilder;
pub use reranker_model::RerankModel;
pub use reranker_pipeline::{RerankPipeline, RerankResult};
pub use model::RerankModel;
pub use pipeline::{RerankPipeline, RerankResult};

pub use crate::models::implementations::qwen3_reranker::{Qwen3RerankModel, Qwen3RerankSize};

pub use anyhow::Result;
pub use anyhow::Result;
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use tokenizers::Tokenizer;
use candle_core::Device;
use super::reranker_pipeline::RerankResult;
use super::pipeline::RerankResult;

/// Trait for reranking models.
pub trait RerankModel {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::reranker_model::RerankModel;
use super::model::RerankModel;
use tokenizers::Tokenizer;
use std::sync::Arc;

Expand Down
8 changes: 5 additions & 3 deletions src/pipelines/sentiment_analysis_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::sentiment_analysis_model::SentimentAnalysisModel;
use super::sentiment_analysis_pipeline::SentimentAnalysisPipeline;
use super::model::SentimentAnalysisModel;
use super::pipeline::SentimentAnalysisPipeline;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::build_cache_key;

pub struct SentimentAnalysisPipelineBuilder<M: SentimentAnalysisModel> {
options: M::Options,
Expand Down Expand Up @@ -41,7 +42,7 @@ impl<M: SentimentAnalysisModel> SentimentAnalysisPipelineBuilder<M> {
Some(d) => d,
None => crate::pipelines::utils::load_device()?,
};
let key = format!("{}-{:?}", self.options.cache_key(), device.location());
let key = build_cache_key(&self.options, &device);
let model = global_cache()
.get_or_create(&key, || M::new(self.options.clone(), device.clone()))
.await?;
Expand All @@ -59,3 +60,4 @@ impl
Self::new(size)
}
}

Loading
Loading