diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 8fd31ad5aa..584c42b41b 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -18,6 +18,7 @@ candle-transformers = { workspace = true } candle-flash-attn = { workspace = true, optional = true } candle-onnx = { workspace = true, optional = true } +chrono = "0.4" csv = "1.3.0" cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } diff --git a/candle-examples/examples/smollm3/README.md b/candle-examples/examples/smollm3/README.md new file mode 100644 index 0000000000..1051816b63 --- /dev/null +++ b/candle-examples/examples/smollm3/README.md @@ -0,0 +1,120 @@ +# SmolLM3 Unified Inference + +A unified Rust implementation for running SmolLM3 models using the Candle ML framework. Supports both quantized (GGUF) and full precision (safetensors) models with a single codebase. + +## Features + +- **Dual Model Support**: Run either quantized or full precision models +- **Multiple Quantization Levels**: Q4_K_M (1.9GB), Q8_0 (3.3GB), F16 (6.2GB) +- **Chat Template Support**: Automatic formatting for instruction-tuned models +- **Thinking Mode**: Enable reasoning traces with `/think` mode +- **NoPE Architecture**: Supports SmolLM3's mixed RoPE/NoPE layer configuration +- **Auto-download**: Automatically fetches models from HuggingFace Hub + +## Quick Start + +### Quantized Model (Recommended) +```bash +cargo run --release --example smollm3 -- \ + --model-type quantized \ + --quantization q8_0 \ + --prompt "Explain Rust's ownership system" +``` + +### Full Precision Model +```bash +cargo run --release --example smollm3 -- \ + --model-type full \ + --dtype f16 \ + --prompt "Write a sorting algorithm in Rust" +``` + +## Command Line Options + +### Model Selection +- `--model-type `: Choose `quantized` or `full` (default: quantized) +- `--model `: Choose `3b` (instruct) or `3b-base` (default: 3b) +- `--quantization `: For quantized models - `q4_k_m`, `q8_0`, or `f16` (default: q8_0) +- `--dtype `: For full models - `f32`, `f16`, `bf16`, or `auto` (default: auto) + +### Generation Parameters +- `--prompt `: The prompt to generate from +- `-n, --sample-len `: Number of tokens to generate (default: 1000) +- `--temperature `: Sampling temperature, 0 for greedy (default: 0.8) +- `--top-p `: Nucleus sampling probability cutoff +- `--top-k `: Only sample among top K tokens +- `--repeat-penalty `: Penalty for repeating tokens (default: 1.1) +- `--repeat-last-n `: Context size for repeat penalty (default: 64) + +### Advanced Options +- `--no-chat-template`: Disable chat template formatting (use for base models) +- `--thinking`: Enable thinking/reasoning mode with `/think` tags +- `--split-prompt`: Process prompt tokens individually (for debugging) +- `--tracing`: Enable performance tracing (generates trace JSON) +- `--model-path `: Use local model file instead of auto-download +- `--tokenizer `: Use local tokenizer instead of auto-download + +## Quantization Comparison + +| Level | Size | Quality | Use Case | +|--------|-------|---------|----------| +| Q4_K_M | 1.9GB | Good | Fast inference, constrained environments | +| Q8_0 | 3.3GB | Better | Balanced quality and speed | +| F16 | 6.2GB | Best | Maximum quality in GGUF format | + +## Examples + +### Creative Writing with Thinking Mode +```bash +cargo run --release --example smollm3 -- \ + --thinking \ + --temperature 0.9 \ + --prompt "Write a short sci-fi story about AI" +``` + +### Code Generation (Base Model) +```bash +cargo run 
--release --example smollm3 -- \ + --model 3b-base \ + --no-chat-template \ + --temperature 0.2 \ + --prompt "def fibonacci(n):" +``` + +### High Quality Output +```bash +cargo run --release --example smollm3 -- \ + --model-type full \ + --dtype f16 \ + --temperature 0.7 \ + --prompt "Explain quantum entanglement" +``` + +## Model Architecture + +SmolLM3 uses a hybrid RoPE/NoPE architecture: +- **RoPE layers**: Standard rotary position embeddings (75% of layers) +- **NoPE layers**: No position embeddings (25% of layers - every 4th layer) + +This configuration is automatically detected and handled by the implementation. + +## Hardware Requirements + +- **Quantized Q4_K_M**: ~2.5GB RAM +- **Quantized Q8_0**: ~4GB RAM +- **Full F16**: ~7GB RAM +- **Full F32**: ~13GB RAM + +GPU acceleration supported via CUDA (with `cuda` feature) or Metal (macOS). + +## Troubleshooting + +**Model download fails**: Check internet connection and HuggingFace Hub access + +**Out of memory**: Try a smaller quantization level or use `--sample-len` to reduce generation length + +**Compilation errors**: Ensure you're using the latest version of the Candle crate + +## License + +This implementation follows the Candle framework license. SmolLM3 models are available under Apache 2.0. \ No newline at end of file diff --git a/candle-examples/examples/smollm3/main.rs b/candle-examples/examples/smollm3/main.rs new file mode 100644 index 0000000000..397417121e --- /dev/null +++ b/candle-examples/examples/smollm3/main.rs @@ -0,0 +1,618 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::{Parser, ValueEnum}; +use std::io::Write; + +use candle::{DType, Device, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::{LogitsProcessor, Sampling}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +// Import both model implementations +use candle_transformers::models::smol::quantized_smollm3::QuantizedModelForCausalLM; +use candle_transformers::models::smol::smollm3::{Config, ModelForCausalLM}; + +const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number."; + +// ==================== Model Type Enum ==================== + +enum SmolLM3Model { + Quantized(QuantizedModelForCausalLM), + Full(ModelForCausalLM, Config), // Store config alongside model +} + +impl SmolLM3Model { + fn forward(&mut self, input: &Tensor, pos: usize) -> Result { + match self { + Self::Quantized(model) => Ok(model.forward(input, pos)?), + Self::Full(model, _) => Ok(model.forward(input, pos)?), + } + } + + fn config(&self) -> ModelConfig { + match self { + Self::Quantized(model) => { + let cfg = model.config(); + ModelConfig { + vocab_size: cfg.vocab_size, + hidden_size: cfg.hidden_size, + num_hidden_layers: cfg.num_hidden_layers, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + rope_theta: cfg.rope_theta as f32, // Convert f64 to f32 + eos_token_id: Some(128012), // Default SmolLM3 EOS + no_rope_layers: None, + no_rope_layer_interval: None, + } + } + Self::Full(_, cfg) => { + ModelConfig { + vocab_size: cfg.vocab_size, + hidden_size: cfg.hidden_size, + num_hidden_layers: cfg.num_hidden_layers, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + rope_theta: cfg.rope_theta as f32, // Convert f64 to f32 + 
eos_token_id: cfg.eos_token_id, + no_rope_layers: cfg + .no_rope_layers + .as_ref() + .map(|v| v.iter().map(|&x| x as u32).collect()), // Convert Vec to Vec + no_rope_layer_interval: cfg.no_rope_layer_interval, + } + } + } + } +} + +// Unified config representation +struct ModelConfig { + vocab_size: usize, + hidden_size: usize, + num_hidden_layers: usize, + num_attention_heads: usize, + num_key_value_heads: usize, + rope_theta: f32, + eos_token_id: Option, + no_rope_layers: Option>, + no_rope_layer_interval: Option, +} + +impl ModelConfig { + fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } +} + +// ==================== CLI Arguments ==================== + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum ModelType { + /// Use quantized GGUF model (smaller, faster) + #[value(name = "quantized")] + Quantized, + /// Use full precision safetensors model (larger, more accurate) + #[value(name = "full")] + Full, +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Quantization { + #[value(name = "q4_k_m")] + Q4KM, + #[value(name = "q8_0")] + Q8_0, + #[value(name = "f16")] + F16, +} + +impl Quantization { + fn filename_unsloth(&self) -> &'static str { + match self { + Self::Q4KM => "SmolLM3-3B-Q4_K_M.gguf", + Self::Q8_0 => "SmolLM3-3B-Q8_0.gguf", + Self::F16 => "SmolLM3-3B-F16.gguf", + } + } + + fn size_gb(&self) -> f32 { + match self { + Self::Q4KM => 1.92, + Self::Q8_0 => 3.28, + Self::F16 => 6.16, + } + } +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum WhichModel { + #[value(name = "3b")] + W3b, + #[value(name = "3b-base")] + W3bBase, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Model type: 'quantized' for GGUF or 'full' for safetensors + #[arg(long, default_value = "quantized")] + model_type: ModelType, + + /// Which model variant to use + #[arg(long, default_value = "3b")] + model: WhichModel, + + /// Quantization level (only for quantized models) + /// Q8_0: 3.3GB, best quality | Q4_K_M: 1.9GB, good balance | F16: 6.2GB, full precision + #[arg(long, default_value = "q8_0")] + quantization: Quantization, + + /// Data type (only for full models: f32, f16, bf16, or auto) + #[arg(long, default_value = "auto")] + dtype: String, + + /// Path to model file (optional, will auto-download if not provided) + #[arg(long)] + model_path: Option, + + /// Path to tokenizer file (optional, will auto-download if not provided) + #[arg(long)] + tokenizer: Option, + + /// The initial prompt + #[arg(long)] + prompt: Option, + + /// The length of the sample to generate (in tokens) + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The temperature used to generate samples, use 0 for greedy sampling + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Penalty to be applied for repeating tokens, 1. 
means no penalty + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// Skip chat template formatting (use raw prompt, like base model) + #[arg(long)] + no_chat_template: bool, + + /// Enable thinking/reasoning mode (allows model to show its reasoning process) + #[arg(long)] + thinking: bool, + + /// Process prompt elements separately (slower, for debugging) + #[arg(long)] + split_prompt: bool, + + /// Enable tracing (generates a trace-timestamp.json file) + #[arg(long)] + tracing: bool, +} + +impl Args { + fn get_tokenizer(&self) -> Result { + let tokenizer_path = match &self.tokenizer { + Some(path) => std::path::PathBuf::from(path), + None => { + let api = Api::new()?; + let api = api.model("HuggingFaceTB/SmolLM3-3B".to_string()); + api.get("tokenizer.json")? + } + }; + Tokenizer::from_file(tokenizer_path).map_err(E::msg) + } + + fn should_use_chat_template(&self) -> bool { + matches!(self.model, WhichModel::W3b) && !self.no_chat_template + } +} + +// ==================== Model Loading ==================== + +fn load_quantized_model(args: &Args, device: &Device) -> Result { + let model_path = match &args.model_path { + Some(path) => std::path::PathBuf::from(path), + None => { + let filename = args.quantization.filename_unsloth(); + let repo_id = "unsloth/SmolLM3-3B-GGUF"; + let api = Api::new()?; + println!( + "Downloading {} from {} (~{:.2}GB)...", + filename, + repo_id, + args.quantization.size_gb() + ); + api.repo(Repo::with_revision( + repo_id.to_string(), + RepoType::Model, + "main".to_string(), + )) + .get(filename)? + } + }; + + println!("Loading quantized model from {:?}...", model_path); + let model = QuantizedModelForCausalLM::from_gguf(&model_path, device)?; + Ok(SmolLM3Model::Quantized(model)) +} + +fn load_full_model(args: &Args, device: &Device) -> Result { + let api = Api::new()?; + let model_id = match args.model { + WhichModel::W3b => "HuggingFaceTB/SmolLM3-3B", + WhichModel::W3bBase => "HuggingFaceTB/SmolLM3-3B-Base", + }; + + println!("Loading full model from: {}", model_id); + let repo = api.repo(Repo::with_revision( + model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + let filenames = match &args.model_path { + Some(path) => vec![std::path::PathBuf::from(path)], + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + + let config_file = repo.get("config.json")?; + let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?; + + let dtype = match args.dtype.as_str() { + "f16" => DType::F16, + "bf16" => DType::BF16, + "f32" => DType::F32, + "auto" => { + if device.is_cuda() || device.is_metal() { + DType::BF16 + } else { + DType::F32 + } + } + other => anyhow::bail!("Unsupported dtype: {}, use f16, bf16, f32, or auto", other), + }; + + println!("Using dtype: {:?}", dtype); + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, device)? 
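+        // Safety: `from_mmaped_safetensors` memory-maps the checkpoint files, so they must
+        // not be modified while this VarBuilder is alive; individual tensors are only
+        // materialized on `device` in the requested `dtype` when they are looked up.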
}; + let model = ModelForCausalLM::new(&config, vb)?; + + Ok(SmolLM3Model::Full(model, config)) +} + +// ==================== Text Generation ==================== + +fn format_prompt(prompt: &str, use_chat_template: bool, enable_thinking: bool) -> String { + if use_chat_template { + // Generate current date dynamically + let now = chrono::Local::now(); + let today_date = now.format("%d %B %Y").to_string(); + + // Set reasoning mode based on thinking flag + let reasoning_mode = if enable_thinking { + "/think" + } else { + "/no_think" + }; + + // Build the assistant start with or without thinking tags + let assistant_start = if enable_thinking { + "<|im_start|>assistant\n\n" // Open for reasoning + } else { + "<|im_start|>assistant\n\n\n\n" // Empty = skip reasoning + }; + + format!( + "<|im_start|>system\n\ +## Metadata\n\ +\n\ +Knowledge Cutoff Date: June 2025\n\ +Today Date: {}\n\ +Reasoning Mode: {}\n\ +\n\ +## Custom Instructions\n\ +\n\ +You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\ +\n\ +<|im_start|>user\n\ +{}<|im_end|>\n\ +{}", + today_date, reasoning_mode, prompt, assistant_start + ) + } else { + prompt.to_string() + } +} + +fn get_eos_token(tokenizer: &Tokenizer, config: &ModelConfig) -> u32 { + if let Some(eos_id) = config.eos_token_id { + return eos_id; + } + + let vocab = tokenizer.get_vocab(true); + if let Some(&eos_id) = vocab.get("<|im_end|>") { + return eos_id; + } + if let Some(&eos_id) = vocab.get("<|endoftext|>") { + return eos_id; + } + + 128012 // Default SmolLM3 EOS token +} + +fn run_generation( + model: &mut SmolLM3Model, + tokenizer: Tokenizer, + args: &Args, + device: &Device, +) -> Result<()> { + let mut tos = TokenOutputStream::new(tokenizer); + + // Prepare prompt + let prompt_str = args + .prompt + .clone() + .unwrap_or_else(|| DEFAULT_PROMPT.to_string()); + let use_chat_template = args.should_use_chat_template(); + let formatted_prompt = format_prompt(&prompt_str, use_chat_template, args.thinking); + + println!("\n=== Generation Settings ==="); + println!("Model type: {:?}", args.model_type); + println!( + "Chat template: {}", + if use_chat_template { + "enabled" + } else { + "disabled" + } + ); + println!( + "Thinking mode: {}", + if args.thinking { + "enabled (/think)" + } else { + "disabled (/no_think)" + } + ); + println!("Raw prompt: {}", prompt_str); + + // Encode prompt + let tokens = tos + .tokenizer() + .encode(formatted_prompt.as_str(), false) + .map_err(E::msg)?; + let tokens = tokens.get_ids(); + println!("Encoded {} tokens", tokens.len()); + + // Setup logits processor + let sampling = if args.temperature <= 0.0 { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { + temperature: args.temperature, + }, + (Some(k), None) => Sampling::TopK { + k, + temperature: args.temperature, + }, + (None, Some(p)) => Sampling::TopP { + p, + temperature: args.temperature, + }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { + k, + p, + temperature: args.temperature, + }, + } + }; + let mut logits_processor = LogitsProcessor::from_sampling(args.seed, sampling); + + // Process prompt + let start_prompt = std::time::Instant::now(); + let mut next_token = if !args.split_prompt { + let input = Tensor::new(tokens, device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + logits_processor.sample(&logits)? 
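+        // `forward` returns logits for the final position only, so a single sample here
+        // yields the first generated token after processing the whole prompt at once.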
+ } else { + let mut next_token = 0; + for (pos, &token) in tokens.iter().enumerate() { + let input = Tensor::new(&[token], device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + next_token = logits_processor.sample(&logits)?; + } + next_token + }; + let prompt_dt = start_prompt.elapsed(); + + // Get EOS token + let config = model.config(); + let eos_token = get_eos_token(tos.tokenizer(), &config); + + // Generate tokens + let mut all_tokens = vec![next_token]; + print!("\n=== Output ===\n"); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + let start_generation = std::time::Instant::now(); + let to_sample = args.sample_len.saturating_sub(1); + let mut sampled = 0; + + for index in 0..to_sample { + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = model.forward(&input, tokens.len() + index)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + + let logits = if args.repeat_penalty == 1.0 { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + sampled += 1; + if next_token == eos_token { + break; + } + } + + if let Some(rest) = tos.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + + let generation_dt = start_generation.elapsed(); + + // Print statistics + println!( + "\n\n=== Statistics ===\n\ + {:4} prompt tokens processed: {:.2} token/s\n\ + {:4} tokens generated: {:.2} token/s", + tokens.len(), + tokens.len() as f64 / prompt_dt.as_secs_f64(), + sampled, + sampled as f64 / generation_dt.as_secs_f64(), + ); + + Ok(()) +} + +// ==================== Main ==================== + +fn print_model_info(config: &ModelConfig) { + println!("\n=== Model Configuration ==="); + println!("Vocab size: {}", config.vocab_size); + println!("Hidden size: {}", config.hidden_size); + println!("Num layers: {}", config.num_hidden_layers); + println!("Num attention heads: {}", config.num_attention_heads); + println!("Num KV heads: {}", config.num_key_value_heads); + println!("Head dim: {}", config.head_dim()); + println!("RoPE theta: {:.0}", config.rope_theta); + + // Print RoPE/NoPE layer info for full models + if let Some(ref no_rope_layers) = config.no_rope_layers { + let num_rope_layers = no_rope_layers.iter().filter(|&&x| x == 1).count(); + let num_nope_layers = no_rope_layers.iter().filter(|&&x| x == 0).count(); + println!("\nLayer Configuration:"); + println!( + " RoPE layers: {} ({}%)", + num_rope_layers, + num_rope_layers * 100 / config.num_hidden_layers + ); + println!( + " NoPE layers: {} ({}%)", + num_nope_layers, + num_nope_layers * 100 / config.num_hidden_layers + ); + } else if let Some(interval) = config.no_rope_layer_interval { + let num_nope_layers = config.num_hidden_layers / interval; + let num_rope_layers = config.num_hidden_layers - num_nope_layers; + println!("\nLayer Configuration:"); + println!( + " RoPE layers: {} ({}%)", + num_rope_layers, + num_rope_layers * 100 / config.num_hidden_layers + ); + println!( + " NoPE layers: {} ({}%) - every {}th layer", + num_nope_layers, + num_nope_layers * 100 / config.num_hidden_layers, + interval + ); + } +} + +fn 
main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!("=== SmolLM3 Unified Inference ==="); + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2}, repeat-penalty: {:.2}, repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let device = candle_examples::device(false)?; + + // Load model + let mut model = match args.model_type { + ModelType::Quantized => load_quantized_model(&args, &device)?, + ModelType::Full => load_full_model(&args, &device)?, + }; + + println!("Model loaded in {:.2}s", start.elapsed().as_secs_f32()); + + // Print model info + let config = model.config(); + print_model_info(&config); + + // Load tokenizer + let tokenizer = args.get_tokenizer()?; + + // Run generation + run_generation(&mut model, tokenizer, &args, &device)?; + + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index e77ba4a36f..b087553fb8 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -112,6 +112,7 @@ pub mod rwkv_v6; pub mod segformer; pub mod segment_anything; pub mod siglip; +pub mod smol; pub mod snac; pub mod stable_diffusion; pub mod stable_lm; diff --git a/candle-transformers/src/models/smol/README.md b/candle-transformers/src/models/smol/README.md new file mode 100644 index 0000000000..5a9e260c9b --- /dev/null +++ b/candle-transformers/src/models/smol/README.md @@ -0,0 +1,259 @@ +# SmolLM Model Family + +This directory contains implementations for the SmolLM family of models +developed by HuggingFace. + +## Models + +### SmolLM2 (see `models/llama`) +SmolLM2 models (135M, 360M, 1.7B) use the standard Llama3 architecture +and are implemented in `models/llama.rs`. No separate implementation +is needed. + +**Variants:** +- HuggingFaceTB/SmolLM2-135M +- HuggingFaceTB/SmolLM2-360M +- HuggingFaceTB/SmolLM2-1.7B + +### SmolLM3 +SmolLM3-3B introduces NoPE (No Positional Encoding) which requires +a custom implementation in `smollm3.rs`. + +**Key innovations:** +- Hybrid RoPE/NoPE (3:1 ratio - every 4th layer uses NoPE) +- GQA with 4 groups (32 attention heads, 8 KV heads) +- Very high rope_theta (5M vs typical 10k-500k) +- Long context support (64k-128k tokens) +- Thinking mode support with `` tags + +**Implementations:** +- `smollm3.rs` - Full precision model (safetensors) +- `quantized_smollm3.rs` - Quantized GGUF model with weight reconstruction + +**Available Models:** +- HuggingFaceTB/SmolLM3-3B (Instruct-tuned) +- HuggingFaceTB/SmolLM3-3B-Base (Base model) +- unsloth/SmolLM3-3B-GGUF (Quantized: Q4_K_M, Q8_0, F16) + +### SmolVLM (planned) +Vision-language model variant, to be implemented. 
+ +## Implementation Details + +### NoPE Architecture +SmolLM3 uses a mixed approach to positional encoding: +```rust +pub fn should_skip_rope(&self, layer_idx: usize) -> bool { + // Method 1: Explicit array from config + if let Some(ref no_rope_layers) = self.no_rope_layers { + if layer_idx < no_rope_layers.len() { + return no_rope_layers[layer_idx] == 0; + } + } + + // Method 2: Interval pattern (SmolLM3-3B default) + // Every 4th layer (indices 3, 7, 11, ...) skips RoPE + if let Some(interval) = self.no_rope_layer_interval { + return (layer_idx + 1) % interval == 0; + } + + false // Default: use RoPE +} +``` + +### Quantized Weight Reconstruction +The quantized implementation includes special handling for Q/K weight +reconstruction to maintain compatibility with the GGUF format's +interleaved weight storage. + +### Thinking Mode +SmolLM3 supports explicit reasoning with thinking tags: +- **Enabled**: `<|im_start|>assistant\n\n` (model generates reasoning) +- **Disabled**: `<|im_start|>assistant\n\n\n\n` (skip to answer) + +## Usage Example + +See `examples/smollm3/main.rs` for a unified implementation that supports +both quantized and full precision models with a single codebase. + +```bash +# Quantized model (recommended) +cargo run --release --example smollm3 -- \ + --model-type quantized \ + --quantization q8_0 \ + --prompt "Explain Rust's ownership system" + +# Full precision model +cargo run --release --example smollm3 -- \ + --model-type full \ + --dtype f16 \ + --prompt "Write a sorting algorithm" + +# Enable thinking mode +cargo run --release --example smollm3 -- \ + --thinking \ + --prompt "Solve this logic puzzle step by step" +``` + +## Performance Characteristics + +| Model Type | Size | Speed | Quality | Use Case | +|------------|-------|-------|---------|----------| +| Q4_K_M | 1.9GB | Fast | Good | Resource-constrained | +| Q8_0 | 3.3GB | Fast | Better | Balanced | +| F16 (GGUF) | 6.2GB | Med | Best | High quality GGUF | +| F16 (Safe) | 6.2GB | Med | Best | Maximum quality | +| F32 (Safe) | 12GB | Slow | Best | Research/debugging | + +# Credits & Attribution + +## SmolLM3 Model + +### Developers +**HuggingFace Team (HuggingFaceTB)** + +The SmolLM family of models represents cutting-edge work in efficient language models, demonstrating that small models can achieve impressive capabilities when trained on high-quality data. + +### Resources +- **Model Card**: https://huggingface.co/HuggingFaceTB/SmolLM3-3B +- **Model Card (Base)**: https://huggingface.co/HuggingFaceTB/SmolLM3-3B-Base +- **Collection**: https://huggingface.co/collections/HuggingFaceTB/smollm3-6723884a9c35673e4f9b74a2 +- **Blog Post**: https://huggingface.co/blog/smollm3 +- **GitHub Repository**: https://github.com/huggingface/smollm +- **License**: Apache 2.0 + +### Key Contributors +The SmolLM project is developed by the HuggingFace team with contributions from researchers focused on efficient LLM architectures and training methods. 
+ +## NoPE Architecture + +### Research Paper +**Title**: "Length Generalization of Causal Transformers without Position Encoding" + +**Authors**: +- Jie Wang (Fudan University) +- Tao Ji (Fudan University) +- Yuanbin Wu (Fudan University) +- Hang Yan (Fudan University) +- Tao Gui (Fudan University) +- Qi Zhang (Fudan University) +- Xuanjing Huang (Fudan University) +- Xiaoling Wang (Fudan University) + +**Published**: NeurIPS 2024 (Thirty-Eighth Annual Conference on Neural Information Processing Systems) + +**Abstract Summary**: The paper demonstrates that removing positional encoding from selected layers (NoPE - No Positional Encoding) can improve length generalization in causal transformers while maintaining or improving performance. SmolLM3 implements this with a 3:1 RoPE/NoPE ratio. + +**Resources**: +- **arXiv**: https://arxiv.org/abs/2410.01926 +- **Conference**: NeurIPS 2024 + +### Key Innovation +The hybrid approach uses: +- **RoPE layers** (75%): Standard rotary positional embeddings for local context +- **NoPE layers** (25%): No positional encoding for improved length generalization +- **Pattern**: Every 4th layer uses NoPE (layers 3, 7, 11, 15, etc.) + +This architecture enables SmolLM3 to handle much longer contexts (64k-128k tokens) while maintaining efficiency. + +## Quantized Models + +### Unsloth +Quantized GGUF models are provided by **Unsloth**, a team focused on making LLM inference and fine-tuning more accessible. + +**Resources**: +- **GGUF Repository**: https://huggingface.co/unsloth/SmolLM3-3B-GGUF +- **Available Quantizations**: Q4_K_M, Q8_0, F16 +- **Website**: https://unsloth.ai/ + +The quantization work enables running SmolLM3 efficiently on consumer hardware with minimal quality loss. + +## Implementation Credits + +### This Candle Implementation +**Implemented for**: Candle ML Framework +**Implementation Date**: Nov 2025 +**Features**: +- Full precision model (F32/F16/BF16) +- Quantized model (Q4_K_M/Q8_0/F16 GGUF) +- Unified example supporting both +- Verified against reference implementations + +**Verification**: +- Full precision: Validated against HuggingFace Transformers Python implementation +- Quantized: Validated against llama.cpp implementation + +### Related Tools & Frameworks + +**Candle**: Minimalist ML framework in Rust by HuggingFace +- GitHub: https://github.com/huggingface/candle + +**llama.cpp**: Efficient LLM inference in C/C++ +- GitHub: https://github.com/ggerganov/llama.cpp +- Used for quantized model verification + +**HuggingFace Transformers**: Reference Python implementation +- GitHub: https://github.com/huggingface/transformers +- Used for full model verification + +## Acknowledgments + +Special thanks to: + +1. **HuggingFace Team** - For developing SmolLM3 and making it openly available under Apache 2.0 license +2. **NoPE Researchers** - For advancing the field with novel positional encoding approaches +3. **Unsloth** - For providing optimized quantized versions +4. **Candle Contributors** - For building an excellent ML framework in Rust +5. 
**Open Source Community** - For tools like llama.cpp that enable verification and benchmarking + +## Citation + +If you use SmolLM3 in your research or applications, please cite: + +### SmolLM3 Model +```bibtex +@misc{smollm3, + title={SmolLM3}, + author={HuggingFace Team}, + year={2024}, + publisher={HuggingFace}, + howpublished={\url{https://huggingface.co/HuggingFaceTB/SmolLM3-3B}} +} +``` + +### NoPE Paper +```bibtex +@inproceedings{wang2024length, + title={Length Generalization of Causal Transformers without Position Encoding}, + author={Wang, Jie and Ji, Tao and Wu, Yuanbin and Yan, Hang and Gui, Tao and Zhang, Qi and Huang, Xuanjing and Wang, Xiaoling}, + booktitle={Thirty-Eighth Annual Conference on Neural Information Processing Systems}, + year={2024} +} +``` + +### Candle Framework +```bibtex +@software{candle, + title={Candle: Minimalist ML Framework}, + author={HuggingFace}, + year={2024}, + url={https://github.com/huggingface/candle} +} +``` + +## License + +- **SmolLM3 Model**: Apache 2.0 +- **This Implementation**: Follows Candle framework license +- **Candle Framework**: Apache 2.0 and MIT dual-licensed + +## Further Reading + +- **SmolLM Blog Series**: https://huggingface.co/blog/smollm and https://huggingface.co/blog/smollm3 +- **Model Card Details**: https://huggingface.co/HuggingFaceTB/SmolLM3-3B +- **NoPE Paper**: https://arxiv.org/abs/2410.01926 +- **Candle Documentation**: https://huggingface.github.io/candle/ + +--- + +This implementation stands on the shoulders of giants. Thank you to all the researchers, engineers, and open source contributors who make this work possible. diff --git a/candle-transformers/src/models/smol/mod.rs b/candle-transformers/src/models/smol/mod.rs new file mode 100644 index 0000000000..b3744385e0 --- /dev/null +++ b/candle-transformers/src/models/smol/mod.rs @@ -0,0 +1,67 @@ +//! SmolLM model family implementations. +//! +//! The SmolLM family consists of efficient language models developed by HuggingFace: +//! - **SmolLM2** (135M, 360M, 1.7B): Uses standard Llama architecture (see `models::llama`) +//! - **SmolLM3** (3B): Introduces hybrid RoPE/NoPE architecture (implemented here) +//! +//! # SmolLM3 Architecture +//! +//! SmolLM3-3B introduces NoPE (No Positional Encoding) as a key innovation: +//! - 3:1 RoPE/NoPE ratio: every 4th layer skips positional encoding +//! - Grouped Query Attention: 32 attention heads, 8 KV heads (4 groups) +//! - High RoPE theta: 5,000,000 (vs typical 10,000-500,000) +//! - Extended context: 64k-128k tokens +//! +//! # Module Structure +//! +//! - [`smollm3`]: Full precision model implementation (safetensors) +//! - [`quantized_smollm3`]: Quantized model implementation (GGUF) +//! +//! # Example Usage +//! +//! ```rust,no_run +//! use candle_transformers::models::smol::smollm3::{Config, ModelForCausalLM}; +//! use candle_transformers::models::smol::quantized_smollm3::QuantizedModelForCausalLM; +//! use candle::{Device, Tensor}; +//! use candle_nn::VarBuilder; +//! +//! # fn main() -> anyhow::Result<()> { +//! let device = Device::Cpu; +//! +//! // Load full precision model +//! let vb = VarBuilder::zeros(candle::DType::F32, &device); +//! let config = Config::default(); +//! let model = ModelForCausalLM::new(&config, vb)?; +//! +//! // Or load quantized model +//! // let model = QuantizedModelForCausalLM::from_gguf(path, &device)?; +//! +//! // Run inference +//! let input = Tensor::new(&[1u32, 2, 3], &device)?.unsqueeze(0)?; +//! let logits = model.forward(&input, 0)?; +//! # Ok(()) +//! # } +//! ``` +//! 
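+//! # Incremental Decoding
+//!
+//! `forward(input, offset)` maintains a per-layer KV cache, so generation calls the
+//! model once per new token with the running position offset. A minimal sketch of the
+//! loop used in `candle-examples/examples/smollm3/main.rs` (here `sample`, `device`,
+//! `prompt_len` and `max_new_tokens` are placeholders, not part of this API):
+//!
+//! ```rust,ignore
+//! // Prompt pass at offset 0; the model returns logits for the last prompt position.
+//! let logits = model.forward(&prompt_ids, 0)?;
+//! let mut next = sample(&logits)?;
+//! for step in 0..max_new_tokens {
+//!     let input = Tensor::new(&[next], &device)?.unsqueeze(0)?;
+//!     let logits = model.forward(&input, prompt_len + step)?;
+//!     next = sample(&logits)?;
+//! }
+//! // Reset the cache with `clear_kv_cache()` before running a new prompt.
+//! ```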
+//! # Thinking Mode +//! +//! SmolLM3 supports explicit reasoning via thinking tags in chat templates: +//! - Thinking enabled: `<|im_start|>assistant\n\n` (model generates reasoning) +//! - Thinking disabled: `<|im_start|>assistant\n\n\n\n` (skip to answer) +//! +//! # Performance Considerations +//! +//! | Format | Size | Inference Speed | Quality | +//! |--------|-------|-----------------|---------| +//! | Q4_K_M | 1.9GB | Fastest | Good | +//! | Q8_0 | 3.3GB | Fast | Better | +//! | F16 | 6.2GB | Medium | Best | +//! | F32 | 12GB | Slow | Best | +//! +//! # References +//! +//! - [SmolLM3 Model Card](https://huggingface.co/HuggingFaceTB/SmolLM3-3B) +//! - [NoPE Paper](https://arxiv.org/abs/2410.01926) + +pub mod quantized_smollm3; +pub mod smollm3; diff --git a/candle-transformers/src/models/smol/quantized_smollm3.rs b/candle-transformers/src/models/smol/quantized_smollm3.rs new file mode 100644 index 0000000000..7bbc88f7c3 --- /dev/null +++ b/candle-transformers/src/models/smol/quantized_smollm3.rs @@ -0,0 +1,567 @@ +use crate::models::with_tracing::QMatMul; +use crate::quantized_var_builder::VarBuilder; +use candle::quantized::gguf_file; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::kv_cache::KvCache; +use candle_nn::Activation; +use std::io::Write; +use std::sync::Arc; + +const MAX_SEQ_LEN: usize = 4096; +use candle::IndexOp; + +// ===== RECONSTRUCTION FUNCTION ===== +fn reconstruct_qk_weights(gguf_weight: &Tensor, _num_heads: usize) -> Result { + let total_rows = gguf_weight.dim(0)?; + let half_rows = total_rows / 2; + let chunk_size = 128; + let chunks_per_half = half_rows / chunk_size; + + let mut heads = Vec::new(); + + // First half + for chunk_idx in 0..chunks_per_half { + let chunk_start = chunk_idx * chunk_size; + + // Even rows + let mut head_even = Vec::new(); + for i in (chunk_start..chunk_start + chunk_size).step_by(2) { + head_even.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_even, 0)?); + + // Odd rows + let mut head_odd = Vec::new(); + for i in (chunk_start + 1..chunk_start + chunk_size).step_by(2) { + head_odd.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_odd, 0)?); + } + + // Second half + for chunk_idx in 0..chunks_per_half { + let chunk_start = half_rows + chunk_idx * chunk_size; + + // Even rows + let mut head_even = Vec::new(); + for i in (chunk_start..chunk_start + chunk_size).step_by(2) { + head_even.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_even, 0)?); + + // Odd rows + let mut head_odd = Vec::new(); + for i in (chunk_start + 1..chunk_start + chunk_size).step_by(2) { + head_odd.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_odd, 0)?); + } + + Ok(Tensor::cat(&heads, 0)?) 
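+    // Net effect: the even/odd row interleaving used by the GGUF export is undone and the
+    // gathered blocks are concatenated back into a single contiguous Q/K weight matrix.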
+} + +#[derive(Debug, Clone)] +pub struct QuantizedConfig { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub rope_dimension_count: usize, + pub no_rope_layer_interval: Option, +} + +impl QuantizedConfig { + /// Load config from GGUF metadata + pub fn from_gguf(ct: &gguf_file::Content) -> Result { + let metadata = &ct.metadata; + + // Helper to get required metadata + let get_u32 = |key: &str| -> Result { + metadata + .get(key) + .and_then(|v| v.to_u32().ok()) + .map(|v| v as usize) + .ok_or_else(|| { + candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)) + }) + }; + + let get_f32 = |key: &str| -> Result { + metadata + .get(key) + .and_then(|v| v.to_f32().ok()) + .map(|v| v as f64) + .ok_or_else(|| { + candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)) + }) + }; + + Ok(Self { + vocab_size: get_u32("smollm3.vocab_size")?, + hidden_size: get_u32("smollm3.embedding_length")?, + intermediate_size: get_u32("smollm3.feed_forward_length")?, + num_hidden_layers: get_u32("smollm3.block_count")?, + num_attention_heads: get_u32("smollm3.attention.head_count")?, + num_key_value_heads: get_u32("smollm3.attention.head_count_kv")?, + max_position_embeddings: get_u32("smollm3.context_length").unwrap_or(MAX_SEQ_LEN), + rope_theta: get_f32("smollm3.rope.freq_base")?, + rms_norm_eps: get_f32("smollm3.attention.layer_norm_rms_epsilon")?, + rope_dimension_count: get_u32("smollm3.rope.dimension_count")?, + no_rope_layer_interval: Some(4), + }) + } + + pub fn should_skip_rope(&self, layer_idx: usize) -> bool { + if let Some(interval) = self.no_rope_layer_interval { + return (layer_idx + 1) % interval == 0; + } + false + } + + pub fn head_dim(&self) -> usize { + self.rope_dimension_count + } +} + +#[derive(Debug, Clone)] +struct RmsNorm { + weight: Tensor, + eps: f64, +} + +impl RmsNorm { + fn new(weight: Tensor, eps: f64) -> Self { + Self { weight, eps } + } + + fn forward(&self, x: &Tensor) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(candle::D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let result = x_normed.broadcast_mul(&self.weight)?; + result.to_dtype(x_dtype) + } +} + +#[derive(Debug, Clone)] +pub struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + pub fn new(dtype: DType, cfg: &QuantizedConfig, dev: &Device) -> Result { + let dim = cfg.head_dim(); + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? 
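+            // Positions 0..max_seq_len as a column vector; the matmul below forms the outer
+            // product with `inv_freq`, giving the (max_seq_len, dim / 2) angle table.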
+ .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, + }) + } + + pub fn apply_rotary_emb( + &self, + q: &Tensor, + k: &Tensor, + offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_, _, seq_len, _) = q.dims4()?; + let cos = self.cos.narrow(0, offset, seq_len)?; + let sin = self.sin.narrow(0, offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +fn repeat_kv(x: Tensor, n_rep: usize) -> Result { + if n_rep == 1 { + Ok(x) + } else { + let (b, n_kv_heads, seq_len, head_dim) = x.dims4()?; + x.unsqueeze(2)? + .expand(&[b, n_kv_heads, n_rep, seq_len, head_dim])? + .reshape(&[b, n_kv_heads * n_rep, seq_len, head_dim]) + } +} + +#[derive(Debug, Clone)] +struct QuantizedMLP { + gate_proj: QMatMul, + up_proj: QMatMul, + down_proj: QMatMul, +} + +impl QuantizedMLP { + fn new(vb: VarBuilder, _layer_idx: usize) -> Result { + // VarBuilder.get_no_shape() returns Arc which QMatMul::from_weights expects + let gate_proj = QMatMul::from_weights(vb.get_no_shape("ffn_gate.weight")?)?; + let up_proj = QMatMul::from_weights(vb.get_no_shape("ffn_up.weight")?)?; + let down_proj = QMatMul::from_weights(vb.get_no_shape("ffn_down.weight")?)?; + + Ok(Self { + gate_proj, + up_proj, + down_proj, + }) + } + + fn forward(&self, x: &Tensor) -> Result { + let gate = self.gate_proj.forward(x)?.apply(&Activation::Silu)?; + let up = self.up_proj.forward(x)?; + self.down_proj.forward(&(gate * up)?) + } +} + +#[derive(Debug, Clone)] +struct QuantizedAttention { + q_proj: QMatMul, + k_proj: QMatMul, + v_proj: QMatMul, + o_proj: QMatMul, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + rotary_emb: Option>, + skip_rope: bool, + kv_cache: KvCache, +} + +impl QuantizedAttention { + fn new( + vb: VarBuilder, + cfg: &QuantizedConfig, + layer_idx: usize, + rotary_emb: Option>, + ) -> Result { + let head_dim = cfg.head_dim(); + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + + // For v and o weights, use directly from VarBuilder (already quantized) + // VarBuilder.get_no_shape() returns Arc + let v_proj = QMatMul::from_weights(vb.get_no_shape("attn_v.weight")?)?; + let o_proj = QMatMul::from_weights(vb.get_no_shape("attn_output.weight")?)?; + + // For q and k weights, we need to dequantize, reconstruct, then re-quantize + // IMPORTANT: Do reconstruction on CPU to avoid VRAM exhaustion during model loading + let device = vb.device(); + let cpu = Device::Cpu; + + let q_weight_qtensor = vb.get_no_shape("attn_q.weight")?; + let q_weight_raw = q_weight_qtensor.dequantize(&cpu)?; // Dequantize to CPU + let q_weight = reconstruct_qk_weights(&q_weight_raw, num_heads)?; // Reconstruct on CPU + let q_weight = q_weight.to_device(device)?; // Move to GPU + + // Re-quantize (now on GPU) + use candle::quantized::{GgmlDType, QTensor}; + let q_weight_qtensor = QTensor::quantize(&q_weight, GgmlDType::Q8_0)?; + drop(q_weight_raw); // Explicitly free CPU memory + drop(q_weight); + + let k_weight_qtensor = vb.get_no_shape("attn_k.weight")?; + let k_weight_raw = k_weight_qtensor.dequantize(&cpu)?; // Dequantize to CPU + let k_weight = reconstruct_qk_weights(&k_weight_raw, num_kv_heads)?; // Reconstruct on CPU + let k_weight = k_weight.to_device(device)?; // Move to GPU + + // Re-quantize (now on GPU) + let k_weight_qtensor = 
QTensor::quantize(&k_weight, GgmlDType::Q8_0)?; + drop(k_weight_raw); // Explicitly free CPU memory + drop(k_weight); + + let q_proj = QMatMul::from_weights(Arc::new(q_weight_qtensor))?; + let k_proj = QMatMul::from_weights(Arc::new(k_weight_qtensor))?; + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups: num_heads / num_kv_heads, + head_dim, + rotary_emb, + skip_rope: cfg.should_skip_rope(layer_idx), + kv_cache: KvCache::new(2, 512), + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { + let (b, seq_len, _) = x.dims3()?; + + let q = self + .q_proj + .forward(x)? + .reshape((b, seq_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = self + .k_proj + .forward(x)? + .reshape((b, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = self + .v_proj + .forward(x)? + .reshape((b, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (q, k) = if self.skip_rope { + (q, k) + } else if let Some(rope) = &self.rotary_emb { + rope.apply_rotary_emb(&q, &k, offset)? + } else { + (q, k) + }; + + // can remove this continguous call if using ConcatKV-Cache https://github.com/huggingface/candle/pull/3143 + let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; + + let k = repeat_kv(k, self.num_kv_groups)?; + let v = repeat_kv(v, self.num_kv_groups)?; + + let scale = 1.0 / (self.head_dim as f64).sqrt(); + // Make q contiguous before matmul to avoid stride mismatch + let q = q.contiguous()?; + let attn_weights = (q.matmul(&k.t()?)? * scale)?; + + let mut attn_weights = match mask { + Some(mask) => attn_weights.broadcast_add(mask)?, + None => attn_weights, + }; + + attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&v)?; + + attn_output + .transpose(1, 2)? + .reshape((b, seq_len, self.num_heads * self.head_dim))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache.reset(); + } +} + +#[derive(Debug, Clone)] +struct QuantizedDecoderLayer { + self_attn: QuantizedAttention, + mlp: QuantizedMLP, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl QuantizedDecoderLayer { + fn new( + vb: VarBuilder, + cfg: &QuantizedConfig, + layer_idx: usize, + rotary_emb: Option>, + ) -> Result { + let attn_vb = vb.pp(&format!("blk.{layer_idx}")); + + Ok(Self { + self_attn: QuantizedAttention::new(attn_vb.clone(), cfg, layer_idx, rotary_emb)?, + mlp: QuantizedMLP::new(attn_vb.clone(), layer_idx)?, + input_layernorm: RmsNorm::new( + attn_vb + .get_no_shape("attn_norm.weight")? + .dequantize(vb.device())?, + cfg.rms_norm_eps, + ), + post_attention_layernorm: RmsNorm::new( + attn_vb + .get_no_shape("ffn_norm.weight")? 
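+                    // Norm weights are small, so they stay in full precision: dequantized on
+                    // the target device and wrapped in `RmsNorm` rather than a `QMatMul`.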
+ .dequantize(vb.device())?, + cfg.rms_norm_eps, + ), + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { + let residual = x; + let x = self.input_layernorm.forward(x)?; + let x = self.self_attn.forward(&x, mask, offset)?; + let x = (residual + x)?; + + let residual = &x; + let x = self.post_attention_layernorm.forward(&x)?; + let x = self.mlp.forward(&x)?; + residual + x + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct QuantizedModelForCausalLM { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: QMatMul, + device: Device, + config: QuantizedConfig, +} + +impl QuantizedModelForCausalLM { + pub fn from_gguf>(path: P, device: &Device) -> Result { + use candle::quantized::{GgmlDType, QTensor}; + + // Open file once to read metadata + let mut file = std::fs::File::open(path.as_ref())?; + let content = gguf_file::Content::read(&mut file)?; + let config = QuantizedConfig::from_gguf(&content)?; + + // Create VarBuilder for tensor loading + let vb = VarBuilder::from_gguf(path, device)?; + + // Load embedding tensor - dequantize on CPU first to save VRAM + // (will be used for both embed_tokens and lm_head - tied embeddings) + let cpu = Device::Cpu; + let embed_tensor = vb.get_no_shape("token_embd.weight")?.dequantize(&cpu)?; + let embed_tensor_gpu = embed_tensor.to_device(device)?; // Move to GPU for embedding layer + let embed_tokens = candle_nn::Embedding::new(embed_tensor_gpu, config.hidden_size); + + // Create rotary embedding if needed + let needs_rope = (0..config.num_hidden_layers).any(|i| !config.should_skip_rope(i)); + let rotary_emb = if needs_rope { + Some(Arc::new(RotaryEmbedding::new(DType::F32, &config, device)?)) + } else { + None + }; + + // Load decoder layers + let mut layers = Vec::with_capacity(config.num_hidden_layers); + println!("Loading {} decoder layers...", config.num_hidden_layers); + for layer_idx in 0..config.num_hidden_layers { + if layer_idx % 4 == 0 || layer_idx == config.num_hidden_layers - 1 { + print!( + " Layer {}/{}...\r", + layer_idx + 1, + config.num_hidden_layers + ); + std::io::stdout().flush().ok(); + } + layers.push(QuantizedDecoderLayer::new( + vb.clone(), + &config, + layer_idx, + rotary_emb.clone(), + )?); + } + println!( + " Layer {}/{} - Done! ", + config.num_hidden_layers, config.num_hidden_layers + ); + + // Load output norm + let norm = RmsNorm::new( + vb.get_no_shape("output_norm.weight")?.dequantize(device)?, + config.rms_norm_eps, + ); + + // Load LM head - move CPU embedding tensor to GPU, then quantize + let embed_tensor_for_lm = embed_tensor.to_device(device)?; + let embed_qtensor = QTensor::quantize(&embed_tensor_for_lm, GgmlDType::Q8_0)?; + let lm_head = QMatMul::from_weights(Arc::new(embed_qtensor))?; + drop(embed_tensor); // Free CPU memory + drop(embed_tensor_for_lm); + + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + device: device.clone(), + config, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor, offset: usize) -> Result { + let (batch_size, seq_len) = input_ids.dims2()?; + + // Embed tokens + let mut hidden_states = self.embed_tokens.forward(input_ids)?; + + // Create causal mask if needed + let mask = if seq_len > 1 { + Some(self.create_causal_mask(batch_size, seq_len, offset)?) 
+ } else { + None + }; + + // Forward through decoder layers + for layer in &mut self.layers { + hidden_states = layer.forward(&hidden_states, mask.as_ref(), offset)?; + } + + // Final norm + hidden_states = self.norm.forward(&hidden_states)?; + + // LM head (only last token for generation) + let last_hidden = hidden_states.narrow(1, seq_len - 1, 1)?; + let logits = last_hidden.apply(&self.lm_head)?; + + Ok(logits) + } + + fn create_causal_mask( + &self, + batch_size: usize, + tgt_len: usize, + offset: usize, + ) -> Result { + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len + offset).map(move |j| { + if j <= i + offset { + 0f32 + } else { + f32::NEG_INFINITY + } + }) + }) + .collect(); + + Tensor::from_slice( + &mask, + (batch_size, 1, tgt_len, tgt_len + offset), + &self.device, + ) + } + + pub fn clear_kv_cache(&mut self) { + for layer in &mut self.layers { + layer.clear_kv_cache(); + } + } + + pub fn config(&self) -> &QuantizedConfig { + &self.config + } +} diff --git a/candle-transformers/src/models/smol/smollm3.rs b/candle-transformers/src/models/smol/smollm3.rs new file mode 100644 index 0000000000..f006cdd797 --- /dev/null +++ b/candle-transformers/src/models/smol/smollm3.rs @@ -0,0 +1,470 @@ +use crate::{ + models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm}, + utils::repeat_kv, +}; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::{kv_cache::KvCache, Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub hidden_act: Activation, + // Optional fields + pub attention_bias: Option, + pub attention_dropout: Option, + pub mlp_bias: Option, + pub sliding_window: Option, + pub use_sliding_window: Option, + pub rope_scaling: Option, + pub bos_token_id: Option, + pub eos_token_id: Option, + pub pad_token_id: Option, + pub max_window_layers: Option, + // SmolLM3-specific: NoPE configuration + pub no_rope_layers: Option>, + pub no_rope_layer_interval: Option, +} + +impl Config { + pub fn should_skip_rope(&self, layer_idx: usize) -> bool { + // Method 1: Explicit array (some model variants may provide this) + if let Some(ref no_rope_layers) = self.no_rope_layers { + if layer_idx < no_rope_layers.len() { + // 0 = skip RoPE (NoPE), 1 = use RoPE + return no_rope_layers[layer_idx] == 0; + } + } + + // Method 2: Interval pattern (SmolLM3-3B uses this) + // With interval=4: layers 0,1,2 use RoPE; layer 3 skips RoPE (NoPE) + // Pattern: every 4th layer (3,7,11...) 
skips RoPE + if let Some(interval) = self.no_rope_layer_interval { + return (layer_idx + 1) % interval == 0; + } + + // Default: use RoPE on all layers (standard Llama behavior) + false + } + + /// Calculates head_dim from hidden_size and num_attention_heads + pub fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } +} + +#[derive(Debug, Clone)] +pub(crate) struct SmolLM3RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl SmolLM3RotaryEmbedding { + pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.head_dim(); + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, + }) + } + + /// Apply RoPE (q, k shape: B x H x L x D) + pub(crate) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + let (_, _, seq_len, _) = q.dims4()?; + let cos = self.cos.narrow(0, offset, seq_len)?; + let sin = self.sin.narrow(0, offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct SmolLM3MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl SmolLM3MLP { + pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result { + let mlp_bias = cfg.mlp_bias.unwrap_or(false); + Ok(Self { + gate_proj: linear_b( + cfg.hidden_size, + cfg.intermediate_size, + mlp_bias, + vb.pp("gate_proj"), + )?, + up_proj: linear_b( + cfg.hidden_size, + cfg.intermediate_size, + mlp_bias, + vb.pp("up_proj"), + )?, + down_proj: linear_b( + cfg.intermediate_size, + cfg.hidden_size, + mlp_bias, + vb.pp("down_proj"), + )?, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for SmolLM3MLP { + fn forward(&self, x: &Tensor) -> Result { + let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = x.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct SmolLM3Attention { + // projections + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + // hyper params + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + // utils + rotary_emb: Option>, + kv_cache: KvCache, + // NoPE flag + skip_rope: bool, +} + +impl SmolLM3Attention { + pub(crate) fn new( + cfg: &Config, + layer_idx: usize, + rotary_emb: Option>, + vb: VarBuilder, + ) -> Result { + let use_sliding_window = cfg.use_sliding_window.unwrap_or(false); + if use_sliding_window { + candle::bail!("sliding window is not supported") + } + + let head_dim = cfg.head_dim(); + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + + let attention_bias = cfg.attention_bias.unwrap_or(false); + + let q_proj = linear_b( + cfg.hidden_size, + num_heads * head_dim, + attention_bias, + vb.pp("q_proj"), + )?; + + let k_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + 
attention_bias, + vb.pp("k_proj"), + )?; + + let v_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + attention_bias, + vb.pp("v_proj"), + )?; + let o_proj = linear_b( + num_heads * head_dim, + cfg.hidden_size, + attention_bias, + vb.pp("o_proj"), + )?; + + // Necessary because the hidden_size in the config isn't always accurate + let hidden_size = head_dim * cfg.num_attention_heads; + + // Initialize KV cache with 512 tokens capacity to reduce initial memory allocation. + // The cache will grow in chunks of 512 tokens when needed. + let kv_cache = KvCache::new(2, 512); + + // Check if this layer should skip RoPE (NoPE) + let skip_rope = cfg.should_skip_rope(layer_idx); + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size, + rotary_emb, + kv_cache, + skip_rope, + }) + } + + pub(crate) fn forward( + &mut self, + x: &Tensor, + attn_mask: Option<&Tensor>, + offset: usize, + ) -> Result { + let (b, l, _) = x.dims3()?; + + // 1. Proj + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + // 2. Reshape: (B, L, H, D) -> (B, H, L, D) + let q = q + .reshape((b, l, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + // 3. RoPE - only apply if this layer should use RoPE (not NoPE) + let (q, k) = if self.skip_rope { + // NoPE: Skip rotary embeddings, but ensure tensors are contiguous + (q.contiguous()?, k.contiguous()?) + } else { + // Apply RoPE + if let Some(ref rope) = self.rotary_emb { + rope.apply(&q, &k, offset)? + } else { + (q, k) + } + }; + + // 4. Accumulate KV cache + // Reset KV cache if we're at the first position + if offset == 0 { + self.kv_cache.reset(); + } + let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; + + // 5. GQA repeat_kv + let k = repeat_kv(k, self.num_kv_groups)?; + let v = repeat_kv(v, self.num_kv_groups)?; + + // 6. Attention score + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + if let Some(m) = attn_mask { + scores = scores.broadcast_add(m)?; + } + let probs = candle_nn::ops::softmax_last_dim(&scores)?; + let ctx = probs.matmul(&v)?; // (B, H, L, D) + + // 7. Output proj + ctx.transpose(1, 2)? + .reshape((b, l, self.hidden_size))? 
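+            // hidden_size here equals num_heads * head_dim (see `new`), so the reshape above
+            // merges the heads; `o_proj` then projects back into the residual stream.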
+ .apply(&self.o_proj) + } + + pub fn clear_kv_cache(&mut self) { + self.kv_cache.reset(); + } +} + +#[derive(Debug, Clone)] +pub(crate) struct DecoderLayer { + self_attn: SmolLM3Attention, + mlp: SmolLM3MLP, + ln1: RmsNorm, + ln2: RmsNorm, +} + +impl DecoderLayer { + fn new( + cfg: &Config, + layer_idx: usize, + rotary: Option>, + vb: VarBuilder, + ) -> Result { + let self_attn = SmolLM3Attention::new(cfg, layer_idx, rotary, vb.pp("self_attn"))?; + let mlp = SmolLM3MLP::new(cfg, vb.pp("mlp"))?; + let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let ln2 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + ln1, + ln2, + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { + let h = self.ln1.forward(x)?; + let h = self.self_attn.forward(&h, mask, offset)?; + let x = (x + h)?; + let h2 = self.ln2.forward(&x)?; + let h2 = h2.apply(&self.mlp)?; + x + h2 + } + + pub fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct Model { + pub(crate) embed_tokens: candle_nn::Embedding, + pub(crate) layers: Vec, + pub(crate) norm: RmsNorm, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; + + // Only create rotary embedding if at least one layer uses RoPE + let needs_rope = (0..cfg.num_hidden_layers).any(|i| !cfg.should_skip_rope(i)); + let rotary = if needs_rope { + Some(Arc::new(SmolLM3RotaryEmbedding::new( + vb.dtype(), + cfg, + vb.device(), + )?)) + } else { + None + }; + + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("model.layers"); + for i in 0..cfg.num_hidden_layers { + layers.push(DecoderLayer::new(cfg, i, rotary.clone(), vb_l.pp(i))?); + } + Ok(Self { + embed_tokens, + layers, + norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn clear_kv_cache(&mut self) { + for l in &mut self.layers { + l.clear_kv_cache(); + } + } + + fn causal_mask( + &self, + b: usize, + tgt: usize, + offset: usize, + sw: Option, + ) -> Result { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| { + (0..(tgt + offset)).map(move |j| { + let past_ok = j <= i + offset; + let sw_ok = match sw { + Some(w) => (i + offset) as i64 - j as i64 <= w as i64, + None => true, + }; + if past_ok && sw_ok { + 0. + } else { + minf + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (b, l) = input.dims2()?; + + let mut h = self.embed_tokens.forward(input)?; + + let causal = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset, None)?) 
+ }; + + for layer in &mut self.layers { + h = layer.forward(&h, causal.as_ref(), offset)?; + } + self.norm.forward(&h) + } +} + +#[derive(Debug, Clone)] +pub struct ModelForCausalLM { + base: Model, + lm_head: Linear, +} + +impl ModelForCausalLM { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let base = Model::new(cfg, vb.clone())?; + let lm_head = if cfg.tie_word_embeddings { + Linear::from_weights(base.embed_tokens.embeddings().clone(), None) + } else { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + }; + Ok(Self { base, lm_head }) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (_, l) = input.dims2()?; + + self.base + .forward(input, offset)? + .narrow(1, l - 1, 1)? + .apply(&self.lm_head) + } + + pub fn clear_kv_cache(&mut self) { + self.base.clear_kv_cache(); + } +}