Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/target
.env
review.md
.vscode
.vscode
2 changes: 1 addition & 1 deletion src/core/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ pub struct GenerationConfig {
pub repeat_penalty: Option<f32>,
pub repeat_last_n: Option<usize>,
pub eos_token_ids: Vec<u64>,
}
}
2 changes: 1 addition & 1 deletion src/core/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,4 @@ mod tests {
assert_eq!(messages.last_assistant(), None);
assert_eq!(messages.system(), None);
}
}
}
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
//! # Transformers
//!
//! A Rust library providing pipelines built on top of the
//! [`candle`](https://github.com/huggingface/candle) crate for running
//! large language models locally. See the [README](../README.md) for
//! full details and usage examples.

pub mod core;
mod loaders;
pub mod models;
Expand Down
2 changes: 1 addition & 1 deletion src/models/components/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
// https://github.com/huggingface/candle/pull/2043
Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}
}
}
2 changes: 1 addition & 1 deletion src/models/components/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ pub mod quantization;

pub use attention::repeat_kv;
pub use layers::RmsNorm;
pub use quantization::{QMatMul, VarBuilder};
pub use quantization::{QMatMul, VarBuilder};
2 changes: 1 addition & 1 deletion src/models/components/quantization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ impl VarBuilder {
pub fn contains_key(&self, key: &str) -> bool {
self.data.contains_key(key)
}
}
}
2 changes: 1 addition & 1 deletion src/models/generation/logits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ pub fn apply_repeat_penalty(
}
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, device)
}
}
2 changes: 1 addition & 1 deletion src/models/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ pub mod sampling;

pub use logits::apply_repeat_penalty;
pub use params::GenerationParams;
pub use sampling::{initialize_logits_processor, LogitsProcessor, Sampling};
pub use sampling::{initialize_logits_processor, LogitsProcessor, Sampling};
2 changes: 1 addition & 1 deletion src/models/generation/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ impl GenerationParams {
min_p,
}
}
}
}
2 changes: 1 addition & 1 deletion src/models/implementations/qwen3_reranker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,4 @@ impl RerankModel for Qwen3RerankModel {
fn device(&self) -> &Device {
&self.device
}
}
}
2 changes: 1 addition & 1 deletion src/pipelines/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ pub trait Predict {
pub trait PredictWithScores {
/// Make predictions on the given text, returning labels with their scores
fn predict(&self, tokenizer: &Tokenizer, text: &str, labels: &[&str]) -> anyhow::Result<Vec<(String, f32)>>;
}
}
2 changes: 1 addition & 1 deletion src/pipelines/reranker_pipeline/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ pub trait RerankModel {
fn get_tokenizer(options: Self::Options) -> anyhow::Result<Tokenizer>;

fn device(&self) -> &Device;
}
}
2 changes: 1 addition & 1 deletion src/pipelines/reranker_pipeline/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ where
pub fn device(&self) -> &candle_core::Device {
self.model.device()
}
}
}
19 changes: 7 additions & 12 deletions src/pipelines/sentiment_analysis_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
use super::model::SentimentAnalysisModel;
use super::pipeline::SentimentAnalysisPipeline;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::build_cache_key;
use crate::pipelines::utils::{build_cache_key, DeviceRequest};

pub struct SentimentAnalysisPipelineBuilder<M: SentimentAnalysisModel> {
options: M::Options,
device: Option<candle_core::Device>,
device_request: DeviceRequest,
}

impl<M: SentimentAnalysisModel> SentimentAnalysisPipelineBuilder<M> {
pub fn new(options: M::Options) -> Self {
Self {
options,
device: None,
device_request: DeviceRequest::Default,
}
}

pub fn cpu(mut self) -> Self {
self.device = Some(candle_core::Device::Cpu);
self.device_request = DeviceRequest::Cpu;
self
}

pub fn cuda_device(mut self, index: usize) -> Self {
let dev =
candle_core::Device::new_cuda_with_stream(index).unwrap_or(candle_core::Device::Cpu);
self.device = Some(dev);
self.device_request = DeviceRequest::Cuda(index);
self
}

pub fn device(mut self, device: candle_core::Device) -> Self {
self.device = Some(device);
self.device_request = DeviceRequest::Explicit(device);
self
}

Expand All @@ -38,10 +36,7 @@ impl<M: SentimentAnalysisModel> SentimentAnalysisPipelineBuilder<M> {
M: Clone + Send + Sync + 'static,
M::Options: ModelOptions + Clone,
{
let device = match self.device {
Some(d) => d,
None => crate::pipelines::utils::load_device()?,
};
let device = self.device_request.resolve()?;
let key = build_cache_key(&self.options, &device);
let model = global_cache()
.get_or_create(&key, || M::new(self.options.clone(), device.clone()))
Expand Down
18 changes: 4 additions & 14 deletions src/pipelines/text_generation_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::core::{global_cache, ModelOptions};
use crate::models::{Gemma3Model, Gemma3Size, Qwen3Model, Qwen3Size};
use crate::pipelines::utils::{load_device_with, DeviceRequest};
use crate::pipelines::utils::{DeviceRequest};

use super::parser::XmlParserBuilder;
use super::model::TextGenerationModel;
use super::pipeline::TextGenerationPipeline;
use super::xml_pipeline::XmlGenerationPipeline;
use candle_core::{CudaDevice, Device};
use candle_core::Device;

pub struct TextGenerationPipelineBuilder<M: TextGenerationModel> {
model_options: M::Options,
Expand Down Expand Up @@ -122,12 +122,7 @@ impl<M: TextGenerationModel> TextGenerationPipelineBuilder<M> {
self.top_k.unwrap_or(default_params.top_k),
self.min_p.unwrap_or(default_params.min_p),
);
let device = match self.device_request {
DeviceRequest::Default => load_device_with(None)?,
DeviceRequest::Cpu => Device::Cpu,
DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
DeviceRequest::Explicit(d) => d,
};
let device = self.device_request.resolve()?;

TextGenerationPipeline::new(model, gen_params, device).await
}
Expand Down Expand Up @@ -165,12 +160,7 @@ impl<M: TextGenerationModel> TextGenerationPipelineBuilder<M> {
builder.register_tag(*tag);
}
let xml_parser = builder.build();
let device = match self.device_request {
DeviceRequest::Default => load_device_with(None)?,
DeviceRequest::Cpu => Device::Cpu,
DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
DeviceRequest::Explicit(d) => d,
};
let device = self.device_request.resolve()?;

XmlGenerationPipeline::new(model, gen_params, xml_parser, device).await
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ where
let mut this = self.project();
this.inner.as_mut().as_mut().poll_next(cx)
}
}
}
2 changes: 1 addition & 1 deletion src/pipelines/text_generation_pipeline/streaming/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ pub mod completion_stream;
pub mod event_stream;

pub use completion_stream::CompletionStream;
pub use event_stream::EventStream;
pub use event_stream::EventStream;
2 changes: 1 addition & 1 deletion src/pipelines/text_generation_pipeline/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@ impl IntoTool for Tool {
fn into_tool(self) -> Tool {
self
}
}
}
Loading