diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs index 3f35b286e1..1ce0448ea2 100644 --- a/candle-transformers/src/models/quantized_qwen3.rs +++ b/candle-transformers/src/models/quantized_qwen3.rs @@ -426,4 +426,10 @@ impl ModelWeights { let last_hidden = h.narrow(1, l - 1, 1)?; self.lm_head.forward(&last_hidden)?.squeeze(1) } + + pub fn clear_kv_cache(&mut self) { + for layer in &mut self.layers { + layer.self_attn.kv_cache.reset(); + } + } } diff --git a/candle-wasm-examples/quant-qwen3/.cargo/config.toml b/candle-wasm-examples/quant-qwen3/.cargo/config.toml new file mode 100644 index 0000000000..24f4968402 --- /dev/null +++ b/candle-wasm-examples/quant-qwen3/.cargo/config.toml @@ -0,0 +1,5 @@ +[target.wasm32-unknown-unknown] +rustflags = [ + '--cfg', 'getrandom_backend="wasm_js"', + '-C', 'target-feature=+simd128', +] \ No newline at end of file diff --git a/candle-wasm-examples/quant-qwen3/Cargo.toml b/candle-wasm-examples/quant-qwen3/Cargo.toml new file mode 100644 index 0000000000..4ed0f5dcc1 --- /dev/null +++ b/candle-wasm-examples/quant-qwen3/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "candle-wasm-example-quant-qwen3" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = [] +#simd-flash-attn = ["candle-nn/simd-flash-attn", "candle-transformers/simd-flash-attn"] + +[dependencies] +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true} #, features = ["simd-flash-attn"] } +tokenizers = { workspace = true, features = ["unstable_wasm"] } +num-traits = { workspace = true } + +# App crates. +anyhow = { workspace = true } +byteorder = { workspace = true } +getrandom = { version = "0.3", features = ["wasm_js"], default-features = false } +image = { workspace = true } +log = { workspace = true } +safetensors = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +rayon = { workspace = true } +tracing = { workspace = true } +libc = "0.2" + +# Wasm specific crates. +console_error_panic_hook = "0.1.7" +wasm-bindgen = "0.2.87" +js-sys = "0.3.64" +web-sys = { version = "0.3.70", features = ["console", "Window", "Performance"] } + +[profile.release] +opt-level = "z" +lto = true +codegen-units = 1 +panic = "abort" +strip = true \ No newline at end of file diff --git a/candle-wasm-examples/quant-qwen3/README.md b/candle-wasm-examples/quant-qwen3/README.md new file mode 100644 index 0000000000..0ad640edb8 --- /dev/null +++ b/candle-wasm-examples/quant-qwen3/README.md @@ -0,0 +1,224 @@ +# Qwen3 WASM Text Generation + +A high-performance WebAssembly implementation of the Qwen3-0.6B language model running entirely in the browser. This project demonstrates efficient on-device inference using Rust, WASM, and the Candle ML framework with SIMD optimizations. 
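Browsers without SIMD128 cannot instantiate the module at all, since it is compiled with `-C target-feature=+simd128`. If you want to detect support up front rather than let instantiation fail, a small sketch using the third-party `wasm-feature-detect` package (an assumption here, not part of this example) looks like:

```javascript
import { simd } from "wasm-feature-detect";

// Gate the ~645MB weight download on WASM SIMD128 support; this module is
// compiled with `-C target-feature=+simd128` and will not instantiate
// without it.
simd().then((supported) => {
  if (!supported) {
    document.body.textContent = "This browser lacks WebAssembly SIMD support.";
  }
  // otherwise: fetch the GGUF weights and initialize the model as usual
});
```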
## Features

- **Pure Browser Inference**: No server required - runs 100% client-side
- **SIMD Optimized**: Leverages WebAssembly SIMD for faster inference
- **Quantized Models**: Supports Q8_0 and Q4_K_M GGUF quantization
- **Performance Profiling**: Built-in profiler for optimization analysis
- **Flexible CLI**: Automatic model downloads with progress tracking
- **Smart Caching**: Uses HuggingFace cache to avoid re-downloads

## Performance

Running on a modern CPU with WASM SIMD support:

| Quantization | Speed | Model Size | Quality |
|--------------|---------------|------------|---------|
| **Q8_0** (default) | **8.7 tok/s** | ~645MB | Best |
| Q4_K_M | 5.8 tok/s | ~380MB | Good |

*Q8_0 provides superior quality with better throughput despite its larger size, making it the recommended choice.*

**Performance Note**: Having browser DevTools/console open can significantly reduce inference speed (up to 50% slower). For best performance, close the console during generation and only open it when you need to view profiling stats.

## Requirements

### Python Dependencies
```bash
pip install huggingface-hub tqdm
```

### Build Tools
- Rust (latest stable)
- wasm-pack: `cargo install wasm-pack`

### Browser
- Modern browser with WebAssembly SIMD support (Chrome 91+, Firefox 89+, Safari 16.4+)

## Quick Start

### 1. Build the WASM Module
```bash
wasm-pack build --target web --release
```

### 2. Run the Server (Auto-downloads model)
```bash
./serve.py
```

The server will:
- Check for the model in the HuggingFace cache
- Download the Q8_0 model (~645MB) if not present
- Download tokenizer and config files
- Start serving at http://localhost:8080

### 3. Open Browser
Navigate to http://localhost:8080 and start generating text!

## CLI Usage

### Basic Usage
```bash
# Use default Q8_0 model
./serve.py

# Use smaller Q4_K_M model (faster download, lower quality)
./serve.py --model 0.6b-q4

# Change port
./serve.py --port 3000

# Use custom GGUF model file
./serve.py --path /path/to/custom-model.gguf
```

### Available Options
```bash
./serve.py --help
```

**Options:**
- `--model, -m`: Choose model variant (`0.6b-q8` or `0.6b-q4`)
- `--path, -p`: Path to custom GGUF model file
- `--port`: Server port (default: 8080)
- `--list-models`: Show available models and exit

### List Models
```bash
./serve.py --list-models
```

Output:
```
Available models:

  0.6b-q8:
    Size: ~645MB
    Description: 8-bit quantization (good quality and fastest)
    File: Qwen3-0.6B-Q8_0.gguf

  0.6b-q4:
    Size: ~380MB
    Description: 4-bit quantization (smaller, less accurate, slower in WASM SIMD)
    File: Qwen3-0.6B-Q4_K_M.gguf
```

## Project Structure
```
.
├── src/
│   ├── lib.rs           # WASM bindings
│   ├── m.rs             # Model implementation
│   └── profiler.rs      # Performance profiler
├── index.html           # Web interface
├── serve.py             # Development server with auto-download
├── Cargo.toml           # Rust dependencies
├── .cargo/
│   └── config.toml      # WASM build config (SIMD flags)
└── pkg/                 # Generated WASM (after build)
```

## Using the Interface

### Text Generation
1. Enter your prompt in the text field
2. Click **Generate** to start inference
3. The model will generate up to a set number of tokens (default 100) or until it reaches an end-of-sequence token
4. Click **Reset** to clear the output and KV cache for a fresh start (a sketch of the underlying calls follows below)
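Under the hood these buttons drive the wasm-bindgen exports defined in `src/m.rs` (shown later in this diff). A minimal sketch of the same flow from JavaScript, assuming the file names served by `serve.py` and the package path emitted by `wasm-pack`:

```javascript
import init, { Model } from "./pkg/candle_wasm_example_quant_qwen3.js";

async function generate(prompt) {
  await init();
  const bytes = async (url) =>
    new Uint8Array(await (await fetch(url)).arrayBuffer());

  // serve.py serves any *.gguf path plus tokenizer.json and config.json.
  const model = new Model(
    await bytes("/Qwen3-0.6B-Q8_0.gguf"),
    await bytes("/tokenizer.json"),
    await bytes("/config.json"),
  );

  // Arguments: temp, top_p, repeat_penalty, repeat_last_n, seed, enable_thinking.
  let text = model.init_with_prompt(prompt, 0.7, 0.9, 1.1, 64, 42, true);
  while (!model.is_eos() && model.get_token_count() < 100) {
    text += model.next_token(); // or model.generate_tokens(n) for batches
  }
  model.reset(); // what the Reset button calls: clears tokens and KV cache
  return text;
}
```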
### Performance Tools

The interface includes several tools for monitoring and debugging performance:

#### Show Stats
Prints detailed performance profiling data to the browser console, including:
- Time spent in each operation (model forward pass, tokenization, etc.)
- Call counts, average/min/max times
- Percentage of total time per operation

**When to use**: After generation, to analyze which operations are bottlenecks

#### Clear Stats
Resets all accumulated profiling data to start fresh measurements.

**When to use**: Before running a benchmark, or when you want to measure a specific generation without previous data

#### Update Memory
Refreshes the memory display showing:
- **JS Heap**: JavaScript heap memory usage (used/total/limit)
- **WASM Memory**: WebAssembly linear memory usage in MB and pages

**When to use**: To check current memory consumption, especially useful for:
- Monitoring memory growth during long generations
- Debugging potential memory leaks
- Understanding memory requirements for deployment

**Example workflow**:
1. Click **Clear Stats** to reset measurements
2. Generate text
3. Click **Show Stats** and open the console to see the timing breakdown
4. Click **Update Memory** to see memory usage
5. Repeat to compare different prompts or parameters

## Technical Details

### WASM SIMD
The project uses WebAssembly SIMD128 instructions for accelerated matrix operations. The SIMD feature is enabled in `.cargo/config.toml`:
```toml
[target.wasm32-unknown-unknown]
rustflags = [
    '-C', 'target-feature=+simd128',
]
```

### Quantization
Models use the GGUF format with different quantization schemes:
- **Q8_0**: 8-bit quantization, minimal quality loss
- **Q4_K_M**: 4-bit K-quants, good balance of size and quality

### Model Architecture
- **Base Model**: Qwen3-0.6B by Alibaba Cloud's Qwen Team
- **Framework**: Candle (Rust ML framework)
- **Format**: GGUF (quantized weights)
- **Context**: Supports variable context length with KV cache

## Development

### Debug Build
```bash
wasm-pack build --target web --dev
```

### Profile Performance
Open the browser console after generation to see a detailed timing breakdown:
```javascript
// In browser console
showProfile()   // Print performance stats
clearProfile()  // Reset profiler
updateMemory()  // Check memory usage
```

## Credits

- **Qwen3 Model**: Developed by the [Qwen Team at Alibaba Cloud](https://github.com/QwenLM/Qwen)
- **Candle Framework**: Rust ML framework by Hugging Face
- **GGUF Quantization**: Models from [unsloth/Qwen3-0.6B-GGUF](https://huggingface.co/unsloth/Qwen3-0.6B-GGUF)

## License

This implementation is provided as-is. Please refer to the original Qwen3 license for model usage terms.

## Links

- **Qwen Project**: https://github.com/QwenLM/Qwen
- **Original Model**: https://huggingface.co/Qwen/Qwen3-0.6B
- **Quantized Models**: https://huggingface.co/unsloth/Qwen3-0.6B-GGUF
- **Example GitHub**: https://github.com/DrJesseGlass

---

Built using Rust, WebAssembly, and the Candle framework
\ No newline at end of file
diff --git a/candle-wasm-examples/quant-qwen3/index.html b/candle-wasm-examples/quant-qwen3/index.html
new file mode 100644
index 0000000000..41917c4b71
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/index.html
@@ -0,0 +1,609 @@
[The 609 added lines of index.html did not survive this capture: the markup was stripped, leaving only text fragments. What remains is the page title "Qwen3 WASM Text Generation"; a "Qwen3 Text Generation" heading; attribution lines crediting the Qwen team at Alibaba Cloud, quantized models from the Unsloth team, and the implementation by Jesse Glass using the Candle framework; an "Initializing..." status line; and a "Performance Tools" panel with a "Loading memory info..." display.]
\ No newline at end of file
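The page's control wiring is not recoverable here, but the README's description and the wasm exports pin down its shape; a hypothetical reconstruction (element ids are invented, export names are taken from `src/m.rs` and `src/profiler.rs` below):

```javascript
import init, {
  Model,
  profile_print_stats,
  profile_clear,
  get_memory_info,
  get_wasm_memory_info,
} from "./pkg/candle_wasm_example_quant_qwen3.js";

let model; // constructed once init() and the weight download complete

document.getElementById("reset").onclick = () => model.reset();
document.getElementById("show-stats").onclick = () => profile_print_stats();
document.getElementById("clear-stats").onclick = () => profile_clear();
document.getElementById("update-memory").onclick = () => {
  document.getElementById("memory-info").textContent =
    `${get_memory_info()} | ${get_wasm_memory_info()}`;
};
```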
diff --git a/candle-wasm-examples/quant-qwen3/serve.py b/candle-wasm-examples/quant-qwen3/serve.py
new file mode 100755
index 0000000000..c401d19730
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/serve.py
@@ -0,0 +1,237 @@
#!/usr/bin/env python3
import os
import sys
import argparse
from pathlib import Path
from http.server import HTTPServer, SimpleHTTPRequestHandler

try:
    from huggingface_hub import hf_hub_download
    from tqdm import tqdm
except ImportError:
    print("Error: Required packages not installed", file=sys.stderr)
    print("Install with: pip install huggingface-hub tqdm", file=sys.stderr)
    sys.exit(1)

HOME = Path.home()
HF_CACHE = HOME / '.cache/huggingface/hub'

# Model configurations
MODELS = {
    '0.6b-q8': {
        'repo': 'unsloth/Qwen3-0.6B-GGUF',
        'filename': 'Qwen3-0.6B-Q8_0.gguf',
        'size': '~645MB',
        'description': '8-bit quantization (good quality and fastest)'
    },
    '0.6b-q4': {
        'repo': 'unsloth/Qwen3-0.6B-GGUF',
        'filename': 'Qwen3-0.6B-Q4_K_M.gguf',
        'size': '~380MB',
        'description': '4-bit quantization (smaller, less accurate, slower in WASM SIMD)'
    }
}

TOKENIZER_REPO = 'Qwen/Qwen3-0.6B'


def download_with_progress(repo_id, filename, cache_dir):
    """Download a file from HuggingFace with progress bar"""
    print(f"\nDownloading {filename} from {repo_id}...")
    try:
        path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            resume_download=True
        )
        print(f"Downloaded to: {path}")
        return Path(path)
    except Exception as e:
        print(f"Error downloading {filename}: {e}", file=sys.stderr)
        sys.exit(1)


def find_or_download_model(model_key, custom_path=None):
    """Find model in cache or download it"""
    if custom_path:
        custom_path = Path(custom_path)
        if not custom_path.exists():
            print(f"Error: Custom path does not exist: {custom_path}", file=sys.stderr)
            sys.exit(1)
        print(f"Using custom model: {custom_path}")
        return custom_path

    model_config = MODELS[model_key]
    repo_id = model_config['repo']
    filename = model_config['filename']

    # Check cache first
    repo_cache = HF_CACHE / f"models--{repo_id.replace('/', '--')}"
    if repo_cache.exists():
        snapshots = list((repo_cache / 'snapshots').glob('*'))
        if snapshots:
            model_path = snapshots[0] / filename
            if model_path.exists():
                print(f"Found model in cache: {model_path}")
                return model_path

    # Download if not found
    print("Model not found in cache")
    print(f"Size: {model_config['size']} - {model_config['description']}")
    return download_with_progress(repo_id, filename, HF_CACHE)


def find_or_download_tokenizer():
    """Find tokenizer files or download them"""
    repo_cache = HF_CACHE / f"models--{TOKENIZER_REPO.replace('/', '--')}"

    if repo_cache.exists():
        snapshots = list((repo_cache / 'snapshots').glob('*'))
        if snapshots:
            tokenizer_path = snapshots[0] / 'tokenizer.json'
            config_path = snapshots[0] / 'config.json'
            if tokenizer_path.exists() and config_path.exists():
                print(f"Found tokenizer in cache: {snapshots[0]}")
                return snapshots[0]

    print("Tokenizer not found in cache")
    print("Downloading tokenizer and config...")

    tokenizer_path = download_with_progress(TOKENIZER_REPO, 'tokenizer.json', HF_CACHE)
    config_path = download_with_progress(TOKENIZER_REPO, 'config.json', HF_CACHE)

    return tokenizer_path.parent


class CustomHandler(SimpleHTTPRequestHandler):
    model_path = None
    tokenizer_dir = None

    extensions_map = {
**SimpleHTTPRequestHandler.extensions_map, + '.wasm': 'application/wasm', + } + + def end_headers(self): + self.send_header('Access-Control-Allow-Origin', '*') + self.send_header('Cross-Origin-Opener-Policy', 'same-origin') + self.send_header('Cross-Origin-Embedder-Policy', 'require-corp') + SimpleHTTPRequestHandler.end_headers(self) + + def do_GET(self): + # Serve model file + if self.path.endswith('.gguf'): + self.send_file(self.model_path, 'application/octet-stream') + elif self.path == '/tokenizer.json': + self.send_file(self.tokenizer_dir / 'tokenizer.json', 'application/json') + elif self.path == '/config.json': + self.send_file(self.tokenizer_dir / 'config.json', 'application/json') + else: + SimpleHTTPRequestHandler.do_GET(self) + + def send_file(self, filepath, content_type): + try: + with open(filepath, 'rb') as f: + content = f.read() + self.send_response(200) + self.send_header('Content-Type', content_type) + self.send_header('Content-Length', len(content)) + self.end_headers() + self.wfile.write(content) + except Exception as e: + self.send_error(404, f"File not found: {e}") + + def log_message(self, format, *args): + # Suppress default logging for cleaner output + pass + + +def main(): + parser = argparse.ArgumentParser( + description='Serve Qwen3 WASM model with automatic downloads', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use default Q8_0 model + %(prog)s + + # Use Q4 model (smaller, less accurate, slower in WASM SIMD) + %(prog)s --model 0.6b-q4 + + # Use custom model file + %(prog)s --path /path/to/model.gguf + + # Change port + %(prog)s --port 3000 + """ + ) + + parser.add_argument( + '--model', '-m', + choices=list(MODELS.keys()), + default='0.6b-q8', + help='Model to use (default: 0.6b-q8)' + ) + + parser.add_argument( + '--path', '-p', + type=str, + help='Path to custom GGUF model file' + ) + + parser.add_argument( + '--port', + type=int, + default=8080, + help='Server port (default: 8080)' + ) + + parser.add_argument( + '--list-models', + action='store_true', + help='List available models and exit' + ) + + args = parser.parse_args() + + if args.list_models: + print("\nAvailable models:") + for key, config in MODELS.items(): + print(f"\n {key}:") + print(f" Size: {config['size']}") + print(f" Description: {config['description']}") + print(f" File: {config['filename']}") + return + + print("=" * 60) + print("Qwen3 WASM Server") + print("=" * 60) + + # Find or download model + model_path = find_or_download_model(args.model, args.path) + tokenizer_dir = find_or_download_tokenizer() + + # Set paths for handler + CustomHandler.model_path = model_path + CustomHandler.tokenizer_dir = tokenizer_dir + + print("\n" + "=" * 60) + print(f"Model: {model_path.name}") + print(f"Tokenizer: {tokenizer_dir}") + print(f"Serving from: {os.getcwd()}") + print(f"Port: {args.port}") + print("=" * 60) + print(f"\n Server running at http://localhost:{args.port}") + print("Press Ctrl+C to stop\n") + + try: + server = HTTPServer(('', args.port), CustomHandler) + server.serve_forever() + except KeyboardInterrupt: + print("\n\nShutting down server...") + server.shutdown() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/candle-wasm-examples/quant-qwen3/src/lib.rs b/candle-wasm-examples/quant-qwen3/src/lib.rs new file mode 100644 index 0000000000..60d29f77a8 --- /dev/null +++ b/candle-wasm-examples/quant-qwen3/src/lib.rs @@ -0,0 +1,15 @@ +use wasm_bindgen::prelude::*; + +#[wasm_bindgen] +extern "C" { + 
#[wasm_bindgen(js_namespace = console)]
    pub fn log(s: &str);
}

#[macro_export]
macro_rules! console_log {
    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}

pub mod m;
pub mod profiler;
diff --git a/candle-wasm-examples/quant-qwen3/src/m.rs b/candle-wasm-examples/quant-qwen3/src/m.rs
new file mode 100644
index 0000000000..f9ad841419
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/src/m.rs
@@ -0,0 +1,271 @@
use candle::quantized::gguf_file;
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use js_sys::Date;
use std::io::Cursor;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;

use crate::console_log;
use crate::profiler::ProfileGuard;
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3;

#[wasm_bindgen]
pub struct Model {
    model: QuantizedQwen3,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    tokens: Vec<u32>,
    repeat_penalty: f32,
    repeat_last_n: usize,
    eos_token: u32,
    enable_thinking: bool,
}

#[wasm_bindgen]
impl Model {
    #[wasm_bindgen(constructor)]
    pub fn load(
        weights: Vec<u8>,
        tokenizer: Vec<u8>,
        _config: Vec<u8>, // Not used for GGUF, but kept for compatibility
    ) -> Result<Model, JsError> {
        let _prof = ProfileGuard::new("total_load");
        console_error_panic_hook::set_once();

        let device = Device::Cpu;

        // Tokenizer loading
        {
            let _prof = ProfileGuard::new("load_tokenizer");
            console_log!("Loading tokenizer...");
            let tokenizer =
                Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;

            // Get EOS token
            let eos_token = match tokenizer.get_vocab(true).get("<|endoftext|>") {
                Some(&token) => token,
                None => match tokenizer.get_vocab(true).get("<|im_end|>") {
                    Some(&token) => token,
                    None => {
                        console_log!("Warning: no EOS token found, using 0");
                        0
                    }
                },
            };

            let start = Date::now();
            console_log!(
                "Weights size: {} bytes ({:.2} MB)",
                weights.len(),
                weights.len() as f64 / 1_048_576.0
            );

            // Load GGUF quantized model with SIMD optimizations
            let model = {
                let _prof = ProfileGuard::new("parse_gguf");

                let mut cursor = Cursor::new(weights);
                let content = gguf_file::Content::read(&mut cursor)
                    .map_err(|e| JsError::new(&format!("Failed to read GGUF: {}", e)))?;

                console_log!("GGUF file parsed, loading model weights...");

                // Use the new integrated API with optimizations
                QuantizedQwen3::from_gguf(content, &mut cursor, &device)?
            };

            let load_time = (Date::now() - start) / 1000.0;
            console_log!("Quantized model loaded in {:.2}s", load_time);

            let logits_processor = LogitsProcessor::new(299792458, None, None);

            Ok(Self {
                model,
                tokenizer,
                tokens: vec![],
                logits_processor,
                repeat_penalty: 1.,
                repeat_last_n: 64,
                eos_token,
                enable_thinking: true,
            })
        }
    }

    #[wasm_bindgen]
    pub fn init_with_prompt(
        &mut self,
        prompt: String,
        temp: f64,
        top_p: f64,
        repeat_penalty: f32,
        repeat_last_n: usize,
        seed: f64,
        enable_thinking: bool,
    ) -> Result<String, JsError> {
        let _prof = ProfileGuard::new("init_with_prompt");

        self.enable_thinking = enable_thinking;

        // Clear KV cache
        {
            let _prof = ProfileGuard::new("clear_kv_cache");
            self.model.clear_kv_cache();
        }

        let temp = if temp <= 0. { None } else { Some(temp) };
        let top_p = if top_p <= 0. || top_p >= 1. {
            None
        } else {
            Some(top_p)
        };

        let seed = seed as u64;
        self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
        self.repeat_penalty = repeat_penalty;
        self.repeat_last_n = repeat_last_n;
        self.tokens.clear();

        let formatted_prompt = format_prompt(&prompt, enable_thinking);

        let tokens = {
            let _prof = ProfileGuard::new("tokenize_prompt");
            self.tokenizer
                .encode(formatted_prompt, true)
                .map_err(|m| JsError::new(&m.to_string()))?
                .get_ids()
                .to_vec()
        };

        console_log!("Prompt encoded to {} tokens", tokens.len());

        let text = self
            .process(&tokens)
            .map_err(|m| JsError::new(&m.to_string()))?;

        Ok(text)
    }

    #[wasm_bindgen]
    pub fn next_token(&mut self) -> Result<String, JsError> {
        let _prof = ProfileGuard::new("next_token");

        let last_token = *self.tokens.last().unwrap();
        let text = self
            .process(&[last_token])
            .map_err(|m| JsError::new(&m.to_string()))?;
        Ok(text)
    }

    #[wasm_bindgen]
    pub fn is_eos(&self) -> bool {
        self.tokens.last().map_or(false, |&t| t == self.eos_token)
    }

    #[wasm_bindgen]
    pub fn get_token_count(&self) -> usize {
        self.tokens.len()
    }

    #[wasm_bindgen]
    pub fn reset(&mut self) {
        let _prof = ProfileGuard::new("reset_model");
        self.tokens.clear();
        self.model.clear_kv_cache();
    }

    #[wasm_bindgen]
    pub fn generate_tokens(&mut self, count: usize) -> Result<String, JsError> {
        let _prof = ProfileGuard::new("generate_tokens_batch");

        let mut result = String::new();

        for _ in 0..count {
            if self.is_eos() {
                break;
            }

            let last_token = *self.tokens.last().unwrap();
            let text = self
                .process(&[last_token])
                .map_err(|m| JsError::new(&m.to_string()))?;
            result.push_str(&text);
        }

        Ok(result)
    }
}

fn format_prompt(prompt: &str, enable_thinking: bool) -> String {
    // Set reasoning mode based on the thinking flag
    let reasoning_mode = if enable_thinking {
        "/think"
    } else {
        "/no_think"
    };

    format!(
        "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n\n{}",
        reasoning_mode,
        prompt,
        // With thinking disabled, pre-fill an empty think block so the model
        // skips its reasoning phase (Qwen3 chat-template convention).
        if !enable_thinking { "<think>\n\n</think>" } else { "" }
    )
}

impl Model {
    fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
        let _prof = ProfileGuard::new("process_token");

        let dev = Device::Cpu;

        let input = {
            let _prof = ProfileGuard::new("create_input_tensor");
            Tensor::new(tokens, &dev)?.unsqueeze(0)?
        };

        // Calculate offset (position in sequence)
        let offset = self.tokens.len();

        // Forward pass - this is where most time is spent
        let logits = {
            let _prof = ProfileGuard::new("model_forward");
            self.model.forward(&input, offset)?
        };

        let logits = {
            let _prof = ProfileGuard::new("logits_post_process");
            logits.squeeze(0)?.to_dtype(DType::F32)?
        };

        // Apply repeat penalty if enabled
        let logits = if self.repeat_penalty == 1. {
            logits
        } else {
            let _prof = ProfileGuard::new("apply_repeat_penalty");
            let start_at = self.tokens.len().saturating_sub(self.repeat_last_n);
            let context = &self.tokens[start_at..];
            candle_transformers::utils::apply_repeat_penalty(&logits, self.repeat_penalty, context)?
        };

        let next_token = {
            let _prof = ProfileGuard::new("sample_token");
            self.logits_processor.sample(&logits)?
        };

        self.tokens.push(next_token);

        let token = {
            let _prof = ProfileGuard::new("decode_token");
            match self.tokenizer.decode(&[next_token], false) {
                Ok(token) => token,
                Err(e) => {
                    console_log!("Error decoding token: {:?}", e);
                    "".to_string()
                }
            }
        };

        Ok(token)
    }
}
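The `ProfileGuard` scopes above feed the profiler defined in the next file, `src/profiler.rs`. Its data comes back out through wasm exports; called from the page or the DevTools console, the flow is roughly (package path assumed from `wasm-pack`'s output):

```javascript
import {
  profile_clear,
  profile_print_stats,
  profile_get_stats,
} from "./pkg/candle_wasm_example_quant_qwen3.js";

profile_clear();        // drop counters from earlier runs
// ... run a generation ...
profile_print_stats();  // per-scope count/total/avg/min/max table in console

// The same data as structured JSON, via the ProfileStats.json getter:
console.table(JSON.parse(profile_get_stats().json));
```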
diff --git a/candle-wasm-examples/quant-qwen3/src/profiler.rs b/candle-wasm-examples/quant-qwen3/src/profiler.rs
new file mode 100644
index 0000000000..5508ce6409
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/src/profiler.rs
@@ -0,0 +1,312 @@
//! Performance profiler for WASM
//!
//! Tracks timing and memory usage across different parts of the model.

use std::cell::RefCell;
use std::collections::HashMap;
use wasm_bindgen::prelude::*;

thread_local! {
    static PROFILER: RefCell<Profiler> = RefCell::new(Profiler::new());
}

#[derive(Debug, Clone, serde::Serialize)]
pub struct ProfileEntry {
    pub name: String,
    pub count: usize,
    pub total_ms: f64,
    pub min_ms: f64,
    pub max_ms: f64,
    pub avg_ms: f64,
    pub last_ms: f64,
}

pub struct Profiler {
    entries: HashMap<String, ProfileData>,
    enabled: bool,
    stack: Vec<(String, f64)>,
}

#[derive(Debug, Clone)]
struct ProfileData {
    count: usize,
    total_ms: f64,
    min_ms: f64,
    max_ms: f64,
    last_ms: f64,
}

impl Profiler {
    fn new() -> Self {
        Self {
            entries: HashMap::new(),
            enabled: true,
            stack: Vec::new(),
        }
    }

    fn start(&mut self, name: &str) {
        if !self.enabled {
            return;
        }
        let time = js_sys::Date::now();
        self.stack.push((name.to_string(), time));
    }

    fn end(&mut self, name: &str) {
        if !self.enabled {
            return;
        }

        let end_time = js_sys::Date::now();

        if let Some((start_name, start_time)) = self.stack.pop() {
            if start_name != name {
                web_sys::console::warn_1(
                    &format!(
                        "Profiler mismatch: expected '{}', got '{}'",
                        start_name, name
                    )
                    .into(),
                );
                return;
            }

            let elapsed = end_time - start_time;

            let entry = self.entries.entry(name.to_string()).or_insert(ProfileData {
                count: 0,
                total_ms: 0.0,
                min_ms: f64::INFINITY,
                max_ms: 0.0,
                last_ms: 0.0,
            });

            entry.count += 1;
            entry.total_ms += elapsed;
            entry.min_ms = entry.min_ms.min(elapsed);
            entry.max_ms = entry.max_ms.max(elapsed);
            entry.last_ms = elapsed;
        }
    }

    fn get_entries(&self) -> Vec<ProfileEntry> {
        let mut entries: Vec<_> = self
            .entries
            .iter()
            .map(|(name, data)| ProfileEntry {
                name: name.clone(),
                count: data.count,
                total_ms: data.total_ms,
                min_ms: data.min_ms,
                max_ms: data.max_ms,
                avg_ms: data.total_ms / data.count as f64,
                last_ms: data.last_ms,
            })
            .collect();

        entries.sort_by(|a, b| b.total_ms.partial_cmp(&a.total_ms).unwrap());
        entries
    }

    fn reset(&mut self) {
        self.entries.clear();
        self.stack.clear();
    }

    fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }
}

// Public API
pub fn profile_start(name: &str) {
    PROFILER.with(|p| p.borrow_mut().start(name));
}

pub fn profile_end(name: &str) {
    PROFILER.with(|p| p.borrow_mut().end(name));
}

pub fn profile_reset() {
    PROFILER.with(|p| p.borrow_mut().reset());
}

pub fn profile_set_enabled(enabled: bool) {
    PROFILER.with(|p| p.borrow_mut().set_enabled(enabled));
}

// RAII guard for automatic profiling
pub struct ProfileGuard {
    name: String,
}

impl ProfileGuard {
    pub fn new(name: &str) -> Self {
        profile_start(name);
        Self {
            name: name.to_string(),
        }
    }
}

impl Drop for ProfileGuard {
    fn drop(&mut self) {
        profile_end(&self.name);
    }
}

// Macro for easy profiling
#[macro_export]
macro_rules! profile_scope {
    ($name:expr) => {
        let _guard = $crate::profiler::ProfileGuard::new($name);
    };
}

// WASM exports
#[wasm_bindgen]
pub struct ProfileStats {
    entries: Vec<ProfileEntry>,
}

#[wasm_bindgen]
impl ProfileStats {
    #[wasm_bindgen(getter)]
    pub fn json(&self) -> String {
        serde_json::to_string(&self.entries).unwrap_or_default()
    }
}

#[wasm_bindgen]
pub fn profile_get_stats() -> ProfileStats {
    let entries = PROFILER.with(|p| p.borrow().get_entries());
    ProfileStats { entries }
}

#[wasm_bindgen]
pub fn profile_print_stats() {
    let entries = PROFILER.with(|p| p.borrow().get_entries());

    web_sys::console::log_1(&"".into());
    web_sys::console::log_1(&"═══════════════════════════════════════════════════════".into());
    web_sys::console::log_1(&"                  PERFORMANCE PROFILE                  ".into());
    web_sys::console::log_1(&"═══════════════════════════════════════════════════════".into());

    if entries.is_empty() {
        web_sys::console::log_1(&"No profiling data collected.".into());
        return;
    }

    let total_time: f64 = entries.iter().map(|e| e.total_ms).sum();

    web_sys::console::log_1(
        &format!(
            "{:<30} {:>8} {:>10} {:>10} {:>10} {:>10}",
            "Section", "Count", "Total(ms)", "Avg(ms)", "Min(ms)", "Max(ms)"
        )
        .into(),
    );
    web_sys::console::log_1(&"───────────────────────────────────────────────────────".into());

    for entry in &entries {
        let percent = (entry.total_ms / total_time) * 100.0;
        web_sys::console::log_1(
            &format!(
                "{:<30} {:>8} {:>10.2} {:>10.3} {:>10.3} {:>10.3} ({:.1}%)",
                entry.name,
                entry.count,
                entry.total_ms,
                entry.avg_ms,
                entry.min_ms,
                entry.max_ms,
                percent
            )
            .into(),
        );
    }

    web_sys::console::log_1(&"───────────────────────────────────────────────────────".into());
    web_sys::console::log_1(&format!("TOTAL TIME: {:.2}ms", total_time).into());
    web_sys::console::log_1(&"═══════════════════════════════════════════════════════".into());
}

#[wasm_bindgen]
pub fn profile_enable(enabled: bool) {
    profile_set_enabled(enabled);
    if enabled {
        web_sys::console::log_1(&"✅ Profiler ENABLED".into());
    } else {
        web_sys::console::log_1(&"❌ Profiler DISABLED".into());
    }
}

#[wasm_bindgen]
pub fn profile_clear() {
    profile_reset();
    web_sys::console::log_1(&"Profiler CLEARED".into());
}

// Memory tracking
#[wasm_bindgen]
pub fn get_memory_info() -> String {
    let memory = web_sys::window()
        .and_then(|w| w.performance())
        .and_then(|p| js_sys::Reflect::get(&p, &"memory".into()).ok())
        .and_then(|m| {
            let used = js_sys::Reflect::get(&m, &"usedJSHeapSize".into())
                .ok()
                .and_then(|v| v.as_f64())
                .unwrap_or(0.0);
            let total = js_sys::Reflect::get(&m, &"totalJSHeapSize".into())
                .ok()
                .and_then(|v| v.as_f64())
                .unwrap_or(0.0);
            let limit = js_sys::Reflect::get(&m, &"jsHeapSizeLimit".into())
                .ok()
                .and_then(|v| v.as_f64())
                .unwrap_or(0.0);
            Some((used, total, limit))
        });

    if let Some((used, total, limit)) = memory {
        format!(
            "Used: {:.2} MB / Total: {:.2} MB / Limit: {:.2} MB ({:.1}%)",
            used / 1_048_576.0,
            total / 1_048_576.0,
            limit / 1_048_576.0,
            (used / limit) * 100.0
        )
    } else {
        "Memory info not available".to_string()
    }
}

#[wasm_bindgen]
pub fn log_memory() {
    let info = get_memory_info();
    web_sys::console::log_1(&format!("Memory: {}", info).into());
}

#[wasm_bindgen]
pub fn get_wasm_memory_info() -> String {
    #[cfg(target_arch = "wasm32")]
    {
        let pages = core::arch::wasm32::memory_size::<0>();
        let bytes = pages as f64 *
65536.0; + let mb = bytes / (1024.0 * 1024.0); + + format!("WASM Memory: {:.2} MB ({} pages of 64KB)", mb, pages) + } + + #[cfg(not(target_arch = "wasm32"))] + { + "Not WASM".to_string() + } +} + +#[wasm_bindgen] +pub fn log_wasm_memory() { + let info = get_wasm_memory_info(); + web_sys::console::log_1(&format!("{}", info).into()); +}