diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs
index 3f35b286e1..1ce0448ea2 100644
--- a/candle-transformers/src/models/quantized_qwen3.rs
+++ b/candle-transformers/src/models/quantized_qwen3.rs
@@ -426,4 +426,10 @@ impl ModelWeights {
         let last_hidden = h.narrow(1, l - 1, 1)?;
         self.lm_head.forward(&last_hidden)?.squeeze(1)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in &mut self.layers {
+            layer.self_attn.kv_cache.reset();
+        }
+    }
 }
diff --git a/candle-wasm-examples/quant-qwen3/.cargo/config.toml b/candle-wasm-examples/quant-qwen3/.cargo/config.toml
new file mode 100644
index 0000000000..24f4968402
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/.cargo/config.toml
@@ -0,0 +1,5 @@
+[target.wasm32-unknown-unknown]
+rustflags = [
+    '--cfg', 'getrandom_backend="wasm_js"',
+    '-C', 'target-feature=+simd128',
+]
\ No newline at end of file
diff --git a/candle-wasm-examples/quant-qwen3/Cargo.toml b/candle-wasm-examples/quant-qwen3/Cargo.toml
new file mode 100644
index 0000000000..4ed0f5dcc1
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/Cargo.toml
@@ -0,0 +1,49 @@
+[package]
+name = "candle-wasm-example-quant-qwen3"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[lib]
+crate-type = ["cdylib", "rlib"]
+
+[features]
+default = []
+#simd-flash-attn = ["candle-nn/simd-flash-attn", "candle-transformers/simd-flash-attn"]
+
+[dependencies]
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true } #, features = ["simd-flash-attn"] }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
+num-traits = { workspace = true }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+getrandom = { version = "0.3", features = ["wasm_js"], default-features = false }
+image = { workspace = true }
+log = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+rayon = { workspace = true }
+tracing = { workspace = true }
+libc = "0.2"
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+wasm-bindgen = "0.2.87"
+js-sys = "0.3.64"
+web-sys = { version = "0.3.70", features = ["console", "Window", "Performance"] }
+
+[profile.release]
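+# Size-focused settings: opt-level "z" favors small code, and LTO with a
+# single codegen unit lets the linker drop more dead code from the .wasm.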
+opt-level = "z"
+lto = true
+codegen-units = 1
+panic = "abort"
+strip = true
\ No newline at end of file
diff --git a/candle-wasm-examples/quant-qwen3/README.md b/candle-wasm-examples/quant-qwen3/README.md
new file mode 100644
index 0000000000..0ad640edb8
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/README.md
@@ -0,0 +1,224 @@
+# Qwen3 WASM Text Generation
+
+A high-performance WebAssembly implementation of the Qwen3-0.6B language model running entirely in the browser. This project demonstrates efficient on-device inference using Rust, WASM, and the Candle ML framework with SIMD optimizations.
+
+## Features
+
+- **Pure Browser Inference**: No server required - runs 100% client-side
+- **SIMD Optimized**: Leverages WebAssembly SIMD for faster inference
+- **Quantized Models**: Supports Q8_0 and Q4_K_M GGUF quantization
+- **Performance Profiling**: Built-in profiler for optimization analysis
+- **Flexible CLI**: Automatic model downloads with progress tracking
+- **Smart Caching**: Uses HuggingFace cache to avoid re-downloads
+
+## Performance
+
+Running on a modern CPU with WASM SIMD support:
+
+| Quantization | Speed | Model Size | Quality |
+|--------------|---------------|------------|---------|
+| **Q8_0** (default) | **8.7 tok/s** | ~645MB | Best |
+| Q4_K_M | 5.8 tok/s | ~380MB | Good |
+
+*Q8_0 provides superior quality with better throughput despite larger size, making it the recommended choice.*
+
+**Performance Note**: Having browser DevTools/console open can significantly reduce inference speed (up to 50% slower). For best performance, close the console during generation and only open it when you need to view profiling stats.
+
+## Requirements
+### Python Dependencies
+```bash
+pip install huggingface-hub tqdm
+```
+
+### Build Tools
+- Rust (latest stable)
+- wasm-pack: `cargo install wasm-pack`
+
+### Browser
+- Modern browser with WebAssembly SIMD support (Chrome 91+, Firefox 89+, Safari 16.4+)
+
+## Quick Start
+
+### 1. Build the WASM Module
+```bash
+wasm-pack build --target web --release
+```
+
+### 2. Run the Server (Auto-downloads model)
+```bash
+./serve.py
+```
+
+The server will:
+- Check for the model in HuggingFace cache
+- Download Q8_0 model (~645MB) if not present
+- Download tokenizer and config files
+- Start serving at http://localhost:8080
+
+### 3. Open Browser
+Navigate to http://localhost:8080 and start generating text!
+
+## CLI Usage
+
+### Basic Usage
+```bash
+# Use default Q8_0 model
+./serve.py
+
+# Use smaller Q4_K_M model (faster download, lower quality)
+./serve.py --model 0.6b-q4
+
+# Change port
+./serve.py --port 3000
+
+# Use custom GGUF model file
+./serve.py --path /path/to/custom-model.gguf
+```
+
+### Available Options
+```bash
+./serve.py --help
+```
+
+**Options:**
+- `--model, -m`: Choose model variant (`0.6b-q8` or `0.6b-q4`)
+- `--path, -p`: Path to custom GGUF model file
+- `--port`: Server port (default: 8080)
+- `--list-models`: Show available models and exit
+
+### List Models
+```bash
+./serve.py --list-models
+```
+
+Output:
+```
+Available models:
+
+  0.6b-q8:
+    Size: ~645MB
+    Description: 8-bit quantization (good quality and fastest)
+    File: Qwen3-0.6B-Q8_0.gguf
+
+  0.6b-q4:
+    Size: ~380MB
+    Description: 4-bit quantization (smaller, less accurate, slower in WASM SIMD)
+    File: Qwen3-0.6B-Q4_K_M.gguf
+```
+
+## Project Structure
+```
+.
+├── src/
+│   ├── lib.rs          # WASM bindings
+│   ├── m.rs            # Model implementation
+│   └── profiler.rs     # Performance profiler
+├── index.html          # Web interface
+├── serve.py            # Development server with auto-download
+├── Cargo.toml          # Rust dependencies
+├── .cargo/
+│   └── config.toml     # WASM build config (SIMD flags)
+└── pkg/                # Generated WASM (after build)
+```
+
+
+## Using the Interface
+
+### Text Generation
+1. Enter your prompt in the text field
+2. Click **Generate** to start inference
+3. The model generates up to the configured maximum number of tokens (default: 100) or until it reaches an end-of-sequence token (see the sketch below)
+4. Click **Reset** to clear the output and KV cache for a fresh start
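+
+The four steps above drive the `wasm-bindgen` API exported by `src/m.rs`. A minimal sketch of the same loop in JavaScript (assuming wasm-pack's default output name for this crate; the sampling parameters are illustrative):
+
+```javascript
+import init, { Model } from "./pkg/candle_wasm_example_quant_qwen3.js";
+
+async function run(prompt) {
+  await init();
+
+  // Fetch the artifacts that serve.py exposes, as raw bytes.
+  const bytes = async (f) => new Uint8Array(await (await fetch(f)).arrayBuffer());
+  const model = new Model(
+    await bytes("Qwen3-0.6B-Q8_0.gguf"),
+    await bytes("tokenizer.json"),
+    await bytes("config.json"), // unused for GGUF, kept for compatibility
+  );
+
+  // prompt, temp, top_p, repeat_penalty, repeat_last_n, seed, enable_thinking
+  let text = model.init_with_prompt(prompt, 0.7, 0.9, 1.1, 64, Date.now(), true);
+  for (let i = 0; i < 100 && !model.is_eos(); i++) {
+    text += model.next_token(); // one token per call, KV cache reused
+  }
+  model.reset(); // clear tokens and KV cache for the next prompt
+  return text;
+}
+```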
+
+### Performance Tools
+
+The interface includes several tools for monitoring and debugging performance:
+
+#### Show Stats
+Prints detailed performance profiling data to the browser console, including:
+- Time spent in each operation (model forward pass, tokenization, etc.)
+- Call counts, average/min/max times
+- Percentage of total time per operation
+
+**When to use**: After generation to analyze which operations are bottlenecks
+
+#### Clear Stats
+Resets all accumulated profiling data to start fresh measurements.
+
+**When to use**: Before running a benchmark or when you want to measure a specific generation without previous data
+
+#### Update Memory
+Refreshes the memory display showing:
+- **JS Heap**: JavaScript heap memory usage (used/total/limit)
+- **WASM Memory**: WebAssembly linear memory usage in MB and pages
+
+**When to use**: To check current memory consumption, especially useful for:
+- Monitoring memory growth during long generations
+- Debugging potential memory leaks
+- Understanding memory requirements for deployment
+
+**Example workflow** (a scripted version follows the list):
+1. Click **Clear Stats** to reset measurements
+2. Generate text
+3. Click **Show Stats** and open console to see timing breakdown
+4. Click **Update Memory** to see memory usage
+5. Repeat to compare different prompts or parameters
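+
+These buttons map onto functions exported by `src/profiler.rs`, so the workflow can also be scripted (a sketch; the import names match the `#[wasm_bindgen]` exports, and `get_memory_info` relies on Chrome's non-standard `performance.memory`):
+
+```javascript
+import {
+  profile_clear,
+  profile_print_stats,
+  profile_get_stats,
+  get_memory_info,
+  get_wasm_memory_info,
+} from "./pkg/candle_wasm_example_quant_qwen3.js";
+
+profile_clear();                      // reset measurements
+// ... generate text ...
+profile_print_stats();                // timing table in the console
+const entries = JSON.parse(profile_get_stats().json);
+const fwd = entries.find((e) => e.name === "model_forward");
+if (fwd) {
+  console.log(`forward: ${fwd.avg_ms.toFixed(1)} ms/token over ${fwd.count} calls`);
+}
+console.log(get_memory_info());       // JS heap usage
+console.log(get_wasm_memory_info());  // WASM linear memory
+```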
+
+## Technical Details
+
+### WASM SIMD
+The project uses WebAssembly SIMD128 instructions for accelerated matrix operations. The SIMD feature is enabled in `config.toml`:
+```toml
+[target.wasm32-unknown-unknown]
+rustflags = [
+    '-C', 'target-feature=+simd128',
+]
+```
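+
+Browsers without SIMD support will fail to instantiate the module. To detect this up front, one option is the third-party `wasm-feature-detect` package (not used by this example):
+
+```javascript
+import { simd } from "https://unpkg.com/wasm-feature-detect?module";
+
+if (!(await simd())) {
+  document.body.textContent = "This demo requires a browser with WebAssembly SIMD.";
+}
+```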
+
+### Quantization
+Models use GGUF format with different quantization schemes:
+- **Q8_0**: 8-bit quantization, minimal quality loss
+- **Q4_K_M**: 4-bit K-quants, good balance of size and quality
+
+### Model Architecture
+- **Base Model**: Qwen3-0.6B by Alibaba Cloud's Qwen Team
+- **Framework**: Candle (Rust ML framework)
+- **Format**: GGUF (quantized weights)
+- **Context**: Supports variable context length with KV cache
+
+## Development
+
+### Debug Build
+```bash
+wasm-pack build --target web --dev
+```
+
+### Profile Performance
+Open browser console after generation to see detailed timing breakdown:
+```javascript
+// In browser console
+showProfile() // Print performance stats
+clearProfile() // Reset profiler
+updateMemory() // Check memory usage
+```
+
+## Credits
+
+- **Qwen3 Model**: Developed by the [Qwen Team at Alibaba Cloud](https://github.com/QwenLM/Qwen)
+- **Candle Framework**: Rust ML framework by Hugging Face
+- **GGUF Quantization**: Models from [unsloth/Qwen3-0.6B-GGUF](https://huggingface.co/unsloth/Qwen3-0.6B-GGUF)
+
+## License
+
+This implementation is provided as-is. Please refer to the original Qwen3 license for model usage terms.
+
+## Links
+
+- **Qwen Project**: https://github.com/QwenLM/Qwen
+- **Original Model**: https://huggingface.co/Qwen/Qwen3-0.6B
+- **Quantized Models**: https://huggingface.co/unsloth/Qwen3-0.6B-GGUF
+- **Example GitHub**: https://github.com/DrJesseGlass
+
+---
+
+Built using Rust, WebAssembly, and the Candle framework
\ No newline at end of file
diff --git a/candle-wasm-examples/quant-qwen3/index.html b/candle-wasm-examples/quant-qwen3/index.html
new file mode 100644
index 0000000000..41917c4b71
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/index.html
@@ -0,0 +1,609 @@
+<!-- 609-line interface markup elided: a "Qwen3 WASM Text Generation" page
+     with a prompt input, Generate and Reset buttons, an "Initializing..."
+     status line, and a "Performance Tools" panel (Show Stats / Clear Stats /
+     Update Memory) above a "Loading memory info..." display. -->
\ No newline at end of file
diff --git a/candle-wasm-examples/quant-qwen3/serve.py b/candle-wasm-examples/quant-qwen3/serve.py
new file mode 100755
index 0000000000..c401d19730
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/serve.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python3
+import os
+import sys
+import argparse
+from pathlib import Path
+from http.server import HTTPServer, SimpleHTTPRequestHandler
+
+try:
+    from huggingface_hub import hf_hub_download
+    from tqdm import tqdm
+except ImportError:
+    print("Error: Required packages not installed", file=sys.stderr)
+    print("Install with: pip install huggingface-hub tqdm", file=sys.stderr)
+    sys.exit(1)
+
+HOME = Path.home()
+HF_CACHE = HOME / '.cache/huggingface/hub'
+
+# Model configurations
+MODELS = {
+    '0.6b-q8': {
+        'repo': 'unsloth/Qwen3-0.6B-GGUF',
+        'filename': 'Qwen3-0.6B-Q8_0.gguf',
+        'size': '~645MB',
+        'description': '8-bit quantization (good quality and fastest)'
+    },
+    '0.6b-q4': {
+        'repo': 'unsloth/Qwen3-0.6B-GGUF',
+        'filename': 'Qwen3-0.6B-Q4_K_M.gguf',
+        'size': '~380MB',
+        'description': '4-bit quantization (smaller, less accurate, slower in WASM SIMD)'
+    }
+}
+
+TOKENIZER_REPO = 'Qwen/Qwen3-0.6B'
+
+
+def download_with_progress(repo_id, filename, cache_dir):
+    """Download a file from HuggingFace with progress bar"""
+    print(f"\nDownloading {filename} from {repo_id}...")
+    try:
+        path = hf_hub_download(
+            repo_id=repo_id,
+            filename=filename,
+            cache_dir=cache_dir,
+            resume_download=True
+        )
+        print(f"Downloaded to: {path}")
+        return Path(path)
+    except Exception as e:
+        print(f"Error downloading {filename}: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def find_or_download_model(model_key, custom_path=None):
+    """Find model in cache or download it"""
+    if custom_path:
+        custom_path = Path(custom_path)
+        if not custom_path.exists():
+            print(f"Error: Custom path does not exist: {custom_path}", file=sys.stderr)
+            sys.exit(1)
+        print(f"Using custom model: {custom_path}")
+        return custom_path
+
+    model_config = MODELS[model_key]
+    repo_id = model_config['repo']
+    filename = model_config['filename']
+
+    # Check cache first
+    repo_cache = HF_CACHE / f"models--{repo_id.replace('/', '--')}"
+    if repo_cache.exists():
+        snapshots = list((repo_cache / 'snapshots').glob('*'))
+        if snapshots:
+            model_path = snapshots[0] / filename
+            if model_path.exists():
+                print(f"Found model in cache: {model_path}")
+                return model_path
+
+    # Download if not found
+    print("Model not found in cache")
+    print(f"Size: {model_config['size']} - {model_config['description']}")
+    return download_with_progress(repo_id, filename, HF_CACHE)
+
+
+def find_or_download_tokenizer():
+    """Find tokenizer files or download them"""
+    repo_cache = HF_CACHE / f"models--{TOKENIZER_REPO.replace('/', '--')}"
+
+    if repo_cache.exists():
+        snapshots = list((repo_cache / 'snapshots').glob('*'))
+        if snapshots:
+            tokenizer_path = snapshots[0] / 'tokenizer.json'
+            config_path = snapshots[0] / 'config.json'
+            if tokenizer_path.exists() and config_path.exists():
+                print(f"Found tokenizer in cache: {snapshots[0]}")
+                return snapshots[0]
+
+    print("Tokenizer not found in cache")
+    print("Downloading tokenizer and config...")
+
+    tokenizer_path = download_with_progress(TOKENIZER_REPO, 'tokenizer.json', HF_CACHE)
+    config_path = download_with_progress(TOKENIZER_REPO, 'config.json', HF_CACHE)
+
+    return tokenizer_path.parent
+
+
+class CustomHandler(SimpleHTTPRequestHandler):
+    model_path = None
+    tokenizer_dir = None
+
+    extensions_map = {
+        **SimpleHTTPRequestHandler.extensions_map,
+        '.wasm': 'application/wasm',
+    }
+
+    def end_headers(self):
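+        # CORS for local development, plus COOP/COEP to make the page
+        # cross-origin isolated (needed for SharedArrayBuffer if WASM threads
+        # are ever enabled; harmless for the current single-threaded build).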
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.send_header('Cross-Origin-Opener-Policy', 'same-origin')
+        self.send_header('Cross-Origin-Embedder-Policy', 'require-corp')
+        SimpleHTTPRequestHandler.end_headers(self)
+
+    def do_GET(self):
+        # Serve model file
+        if self.path.endswith('.gguf'):
+            self.send_file(self.model_path, 'application/octet-stream')
+        elif self.path == '/tokenizer.json':
+            self.send_file(self.tokenizer_dir / 'tokenizer.json', 'application/json')
+        elif self.path == '/config.json':
+            self.send_file(self.tokenizer_dir / 'config.json', 'application/json')
+        else:
+            SimpleHTTPRequestHandler.do_GET(self)
+
+    def send_file(self, filepath, content_type):
+        try:
+            with open(filepath, 'rb') as f:
+                content = f.read()
+            self.send_response(200)
+            self.send_header('Content-Type', content_type)
+            self.send_header('Content-Length', len(content))
+            self.end_headers()
+            self.wfile.write(content)
+        except Exception as e:
+            self.send_error(404, f"File not found: {e}")
+
+    def log_message(self, format, *args):
+        # Suppress default logging for cleaner output
+        pass
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Serve Qwen3 WASM model with automatic downloads',
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # Use default Q8_0 model
+  %(prog)s
+
+  # Use Q4 model (smaller, less accurate, slower in WASM SIMD)
+  %(prog)s --model 0.6b-q4
+
+  # Use custom model file
+  %(prog)s --path /path/to/model.gguf
+
+  # Change port
+  %(prog)s --port 3000
+    """
+    )
+
+    parser.add_argument(
+        '--model', '-m',
+        choices=list(MODELS.keys()),
+        default='0.6b-q8',
+        help='Model to use (default: 0.6b-q8)'
+    )
+
+    parser.add_argument(
+        '--path', '-p',
+        type=str,
+        help='Path to custom GGUF model file'
+    )
+
+    parser.add_argument(
+        '--port',
+        type=int,
+        default=8080,
+        help='Server port (default: 8080)'
+    )
+
+    parser.add_argument(
+        '--list-models',
+        action='store_true',
+        help='List available models and exit'
+    )
+
+    args = parser.parse_args()
+
+    if args.list_models:
+        print("\nAvailable models:")
+        for key, config in MODELS.items():
+            print(f"\n  {key}:")
+            print(f"    Size: {config['size']}")
+            print(f"    Description: {config['description']}")
+            print(f"    File: {config['filename']}")
+        return
+
+    print("=" * 60)
+    print("Qwen3 WASM Server")
+    print("=" * 60)
+
+    # Find or download model
+    model_path = find_or_download_model(args.model, args.path)
+    tokenizer_dir = find_or_download_tokenizer()
+
+    # Set paths for handler
+    CustomHandler.model_path = model_path
+    CustomHandler.tokenizer_dir = tokenizer_dir
+
+    print("\n" + "=" * 60)
+    print(f"Model: {model_path.name}")
+    print(f"Tokenizer: {tokenizer_dir}")
+    print(f"Serving from: {os.getcwd()}")
+    print(f"Port: {args.port}")
+    print("=" * 60)
+    print(f"\nServer running at http://localhost:{args.port}")
+    print("Press Ctrl+C to stop\n")
+
+    try:
+        server = HTTPServer(('', args.port), CustomHandler)
+        server.serve_forever()
+    except KeyboardInterrupt:
+        print("\n\nShutting down server...")
+        server.shutdown()
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/candle-wasm-examples/quant-qwen3/src/lib.rs b/candle-wasm-examples/quant-qwen3/src/lib.rs
new file mode 100644
index 0000000000..60d29f77a8
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/src/lib.rs
@@ -0,0 +1,15 @@
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+extern "C" {
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
+
+pub mod m;
+pub mod profiler;
diff --git a/candle-wasm-examples/quant-qwen3/src/m.rs b/candle-wasm-examples/quant-qwen3/src/m.rs
new file mode 100644
index 0000000000..f9ad841419
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/src/m.rs
@@ -0,0 +1,271 @@
+use candle::quantized::gguf_file;
+use candle::{DType, Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use js_sys::Date;
+use std::io::Cursor;
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+
+use crate::console_log;
+use crate::profiler::ProfileGuard;
+use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3;
+
+#[wasm_bindgen]
+pub struct Model {
+    model: QuantizedQwen3,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
+    tokens: Vec<u32>,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    eos_token: u32,
+    enable_thinking: bool,
+}
+
+#[wasm_bindgen]
+impl Model {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        _config: Vec<u8>, // Not used for GGUF, but keep for compatibility
+    ) -> Result<Model, JsError> {
+        let _prof = ProfileGuard::new("total_load");
+        console_error_panic_hook::set_once();
+
+        let device = Device::Cpu;
+
+        // Tokenizer loading
+        {
+            let _prof = ProfileGuard::new("load_tokenizer");
+            console_log!("Loading tokenizer...");
+            let tokenizer =
+                Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+
+            // Get EOS token
+            let eos_token = match tokenizer.get_vocab(true).get("<|endoftext|>") {
+                Some(&token) => token,
+                None => match tokenizer.get_vocab(true).get("<|im_end|>") {
+                    Some(&token) => token,
+                    None => {
+                        console_log!("Warning: no EOS token found, using 0");
+                        0
+                    }
+                },
+            };
+
+            let start = Date::now();
+            console_log!(
+                "Weights size: {} bytes ({:.2} MB)",
+                weights.len(),
+                weights.len() as f64 / 1_048_576.0
+            );
+
+            // Load GGUF quantized model with SIMD optimizations
+            let model = {
+                let _prof = ProfileGuard::new("parse_gguf");
+
+                let mut cursor = Cursor::new(weights);
+                let content = gguf_file::Content::read(&mut cursor)
+                    .map_err(|e| JsError::new(&format!("Failed to read GGUF: {}", e)))?;
+
+                console_log!("GGUF file parsed, loading model weights...");
+
+                // Use the new integrated API with optimizations
+                QuantizedQwen3::from_gguf(content, &mut cursor, &device)?
+            };
+
+            let load_time = (Date::now() - start) / 1000.0;
+            console_log!("Quantized model loaded in {:.2}s", load_time);
+
+            let logits_processor = LogitsProcessor::new(299792458, None, None);
+
+            Ok(Self {
+                model,
+                tokenizer,
+                tokens: vec![],
+                logits_processor,
+                repeat_penalty: 1.,
+                repeat_last_n: 64,
+                eos_token,
+                enable_thinking: true,
+            })
+        }
+    }
+
+    #[wasm_bindgen]
+    pub fn init_with_prompt(
+        &mut self,
+        prompt: String,
+        temp: f64,
+        top_p: f64,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        seed: f64,
+        enable_thinking: bool,
+    ) -> Result<String, JsError> {
+        let _prof = ProfileGuard::new("init_with_prompt");
+
+        self.enable_thinking = enable_thinking;
+
+        // Clear KV cache
+        {
+            let _prof = ProfileGuard::new("clear_kv_cache");
+            self.model.clear_kv_cache();
+        }
+
+        let temp = if temp <= 0. { None } else { Some(temp) };
+        let top_p = if top_p <= 0. || top_p >= 1. {
+            None
+        } else {
+            Some(top_p)
+        };
+
+        let seed = seed as u64;
+        self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        self.repeat_penalty = repeat_penalty;
+        self.repeat_last_n = repeat_last_n;
+        self.tokens.clear();
+
+        let formatted_prompt = format_prompt(&prompt, enable_thinking);
+
+        let tokens = {
+            let _prof = ProfileGuard::new("tokenize_prompt");
+            self.tokenizer
+                .encode(formatted_prompt, true)
+                .map_err(|m| JsError::new(&m.to_string()))?
+                .get_ids()
+                .to_vec()
+        };
+
+        console_log!("Prompt encoded to {} tokens", tokens.len());
+
+        let text = self
+            .process(&tokens)
+            .map_err(|m| JsError::new(&m.to_string()))?;
+
+        Ok(text)
+    }
+
+    #[wasm_bindgen]
+    pub fn next_token(&mut self) -> Result<String, JsError> {
+        let _prof = ProfileGuard::new("next_token");
+
+        let last_token = *self.tokens.last().unwrap();
+        let text = self
+            .process(&[last_token])
+            .map_err(|m| JsError::new(&m.to_string()))?;
+        Ok(text)
+    }
+
+    #[wasm_bindgen]
+    pub fn is_eos(&self) -> bool {
+        self.tokens.last().map_or(false, |&t| t == self.eos_token)
+    }
+
+    #[wasm_bindgen]
+    pub fn get_token_count(&self) -> usize {
+        self.tokens.len()
+    }
+
+    #[wasm_bindgen]
+    pub fn reset(&mut self) {
+        let _prof = ProfileGuard::new("reset_model");
+        self.tokens.clear();
+        self.model.clear_kv_cache();
+    }
+
+    #[wasm_bindgen]
+    pub fn generate_tokens(&mut self, count: usize) -> Result<String, JsError> {
+        let _prof = ProfileGuard::new("generate_tokens_batch");
+
+        let mut result = String::new();
+
+        for _ in 0..count {
+            if self.is_eos() {
+                break;
+            }
+
+            let last_token = *self.tokens.last().unwrap();
+            let text = self
+                .process(&[last_token])
+                .map_err(|m| JsError::new(&m.to_string()))?;
+            result.push_str(&text);
+        }
+
+        Ok(result)
+    }
+}
+
+fn format_prompt(prompt: &str, enable_thinking: bool) -> String {
+    // Set reasoning mode based on thinking flag
+    let reasoning_mode = if enable_thinking {
+        "/think"
+    } else {
+        "/no_think"
+    };
+
+    format!(
+        "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n\n{}",
+        reasoning_mode,
+        prompt,
+        if !enable_thinking { "\n\n" } else { "" }
+    )
+}
+
+impl Model {
+    fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
+        let _prof = ProfileGuard::new("process_token");
+
+        let dev = Device::Cpu;
+
+        let input = {
+            let _prof = ProfileGuard::new("create_input_tensor");
+            Tensor::new(tokens, &dev)?.unsqueeze(0)?
+        };
+
+        // Calculate offset (position in sequence)
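+        // (self.tokens holds every token already fed to the model, so its
+        // length is the position offset forward() needs for the new tokens)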
+        let offset = self.tokens.len();
+
+        // Forward pass - this is where most time is spent
+        let logits = {
+            let _prof = ProfileGuard::new("model_forward");
+            self.model.forward(&input, offset)?
+        };
+
+        let logits = {
+            let _prof = ProfileGuard::new("logits_post_process");
+            logits.squeeze(0)?.to_dtype(DType::F32)?
+        };
+
+        // Apply repeat penalty if enabled
+        let logits = if self.repeat_penalty == 1. {
+            logits
+        } else {
+            let _prof = ProfileGuard::new("apply_repeat_penalty");
+            let start_at = self.tokens.len().saturating_sub(self.repeat_last_n);
+            let context = &self.tokens[start_at..];
+            candle_transformers::utils::apply_repeat_penalty(&logits, self.repeat_penalty, context)?
+        };
+
+        let next_token = {
+            let _prof = ProfileGuard::new("sample_token");
+            self.logits_processor.sample(&logits)?
+        };
+
+        self.tokens.push(next_token);
+
+        let token = {
+            let _prof = ProfileGuard::new("decode_token");
+            match self.tokenizer.decode(&[next_token], false) {
+                Ok(token) => token,
+                Err(e) => {
+                    console_log!("Error decoding token: {:?}", e);
+                    "".to_string()
+                }
+            }
+        };
+
+        Ok(token)
+    }
+}
diff --git a/candle-wasm-examples/quant-qwen3/src/profiler.rs b/candle-wasm-examples/quant-qwen3/src/profiler.rs
new file mode 100644
index 0000000000..5508ce6409
--- /dev/null
+++ b/candle-wasm-examples/quant-qwen3/src/profiler.rs
@@ -0,0 +1,312 @@
+//! Performance profiler for WASM
+//!
+//! Tracks timing and memory usage across different parts of the model.
+
+use std::cell::RefCell;
+use std::collections::HashMap;
+use wasm_bindgen::prelude::*;
+
+thread_local! {
+    static PROFILER: RefCell<Profiler> = RefCell::new(Profiler::new());
+}
+
+#[derive(Debug, Clone, serde::Serialize)]
+pub struct ProfileEntry {
+    pub name: String,
+    pub count: usize,
+    pub total_ms: f64,
+    pub min_ms: f64,
+    pub max_ms: f64,
+    pub avg_ms: f64,
+    pub last_ms: f64,
+}
+
+pub struct Profiler {
+    entries: HashMap<String, ProfileData>,
+    enabled: bool,
+    stack: Vec<(String, f64)>,
+}
+
+#[derive(Debug, Clone)]
+struct ProfileData {
+    count: usize,
+    total_ms: f64,
+    min_ms: f64,
+    max_ms: f64,
+    last_ms: f64,
+}
+
+impl Profiler {
+    fn new() -> Self {
+        Self {
+            entries: HashMap::new(),
+            enabled: true,
+            stack: Vec::new(),
+        }
+    }
+
+    fn start(&mut self, name: &str) {
+        if !self.enabled {
+            return;
+        }
+        let time = js_sys::Date::now();
+        self.stack.push((name.to_string(), time));
+    }
+
+    fn end(&mut self, name: &str) {
+        if !self.enabled {
+            return;
+        }
+
+        let end_time = js_sys::Date::now();
+
+        if let Some((start_name, start_time)) = self.stack.pop() {
+            if start_name != name {
+                web_sys::console::warn_1(
+                    &format!(
+                        "Profiler mismatch: expected '{}', got '{}'",
+                        start_name, name
+                    )
+                    .into(),
+                );
+                return;
+            }
+
+            let elapsed = end_time - start_time;
+
+            let entry = self.entries.entry(name.to_string()).or_insert(ProfileData {
+                count: 0,
+                total_ms: 0.0,
+                min_ms: f64::INFINITY,
+                max_ms: 0.0,
+                last_ms: 0.0,
+            });
+
+            entry.count += 1;
+            entry.total_ms += elapsed;
+            entry.min_ms = entry.min_ms.min(elapsed);
+            entry.max_ms = entry.max_ms.max(elapsed);
+            entry.last_ms = elapsed;
+        }
+    }
+
+    fn get_entries(&self) -> Vec<ProfileEntry> {
+        let mut entries: Vec<_> = self
+            .entries
+            .iter()
+            .map(|(name, data)| ProfileEntry {
+                name: name.clone(),
+                count: data.count,
+                total_ms: data.total_ms,
+                min_ms: data.min_ms,
+                max_ms: data.max_ms,
+                avg_ms: data.total_ms / data.count as f64,
+                last_ms: data.last_ms,
+            })
+            .collect();
+
+        entries.sort_by(|a, b| b.total_ms.partial_cmp(&a.total_ms).unwrap());
+        entries
+    }
+
+    fn reset(&mut self) {
+        self.entries.clear();
+        self.stack.clear();
+    }
+
+    fn set_enabled(&mut self, enabled: bool) {
+        self.enabled = enabled;
+    }
+}
+
+// Public API
+pub fn profile_start(name: &str) {
+    PROFILER.with(|p| p.borrow_mut().start(name));
+}
+
+pub fn profile_end(name: &str) {
+    PROFILER.with(|p| p.borrow_mut().end(name));
+}
+
+pub fn profile_reset() {
+    PROFILER.with(|p| p.borrow_mut().reset());
+}
+
+pub fn profile_set_enabled(enabled: bool) {
+    PROFILER.with(|p| p.borrow_mut().set_enabled(enabled));
+}
+
+// RAII guard for automatic profiling
+pub struct ProfileGuard {
+    name: String,
+}
+
+impl ProfileGuard {
+    pub fn new(name: &str) -> Self {
+        profile_start(name);
+        Self {
+            name: name.to_string(),
+        }
+    }
+}
+
+impl Drop for ProfileGuard {
+    fn drop(&mut self) {
+        profile_end(&self.name);
+    }
+}
+
+// Macro for easy profiling
+#[macro_export]
+macro_rules! profile_scope {
+    ($name:expr) => {
+        let _guard = $crate::profiler::ProfileGuard::new($name);
+    };
+}
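+//
+// Example: `profile_scope!("step");` at the top of a block records the elapsed
+// time for that block when the guard is dropped at scope end.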
+
+// WASM exports
+#[wasm_bindgen]
+pub struct ProfileStats {
+    entries: Vec<ProfileEntry>,
+}
+
+#[wasm_bindgen]
+impl ProfileStats {
+    #[wasm_bindgen(getter)]
+    pub fn json(&self) -> String {
+        serde_json::to_string(&self.entries).unwrap_or_default()
+    }
+}
+
+#[wasm_bindgen]
+pub fn profile_get_stats() -> ProfileStats {
+    let entries = PROFILER.with(|p| p.borrow().get_entries());
+    ProfileStats { entries }
+}
+
+#[wasm_bindgen]
+pub fn profile_print_stats() {
+    let entries = PROFILER.with(|p| p.borrow().get_entries());
+
+    web_sys::console::log_1(&"".into());
+    web_sys::console::log_1(&"═══════════════════════════════════════════════════════".into());
+    web_sys::console::log_1(&"                  PERFORMANCE PROFILE                  ".into());
+    web_sys::console::log_1(&"═══════════════════════════════════════════════════════".into());
+
+    if entries.is_empty() {
+        web_sys::console::log_1(&"No profiling data collected.".into());
+        return;
+    }
+
+    let total_time: f64 = entries.iter().map(|e| e.total_ms).sum();
+
+    web_sys::console::log_1(
+        &format!(
+            "{:<30} {:>8} {:>10} {:>10} {:>10} {:>10}",
+            "Section", "Count", "Total(ms)", "Avg(ms)", "Min(ms)", "Max(ms)"
+        )
+        .into(),
+    );
+    web_sys::console::log_1(&"───────────────────────────────────────────────────────".into());
+
+    for entry in &entries {
+        let percent = (entry.total_ms / total_time) * 100.0;
+        web_sys::console::log_1(
+            &format!(
+                "{:<30} {:>8} {:>10.2} {:>10.3} {:>10.3} {:>10.3} ({:.1}%)",
+                entry.name,
+                entry.count,
+                entry.total_ms,
+                entry.avg_ms,
+                entry.min_ms,
+                entry.max_ms,
+                percent
+            )
+            .into(),
+        );
+    }
+
+    web_sys::console::log_1(&"───────────────────────────────────────────────────────".into());
+    web_sys::console::log_1(&format!("TOTAL TIME: {:.2}ms", total_time).into());
+    web_sys::console::log_1(&"═══════════════════════════════════════════════════════".into());
+}
+
+#[wasm_bindgen]
+pub fn profile_enable(enabled: bool) {
+    profile_set_enabled(enabled);
+    if enabled {
+        web_sys::console::log_1(&"✅ Profiler ENABLED".into());
+    } else {
+        web_sys::console::log_1(&"❌ Profiler DISABLED".into());
+    }
+}
+
+#[wasm_bindgen]
+pub fn profile_clear() {
+    profile_reset();
+    web_sys::console::log_1(&"Profiler CLEARED".into());
+}
+
+// Memory tracking
+#[wasm_bindgen]
+pub fn get_memory_info() -> String {
+    let memory = web_sys::window()
+        .and_then(|w| w.performance())
+        .and_then(|p| js_sys::Reflect::get(&p, &"memory".into()).ok())
+        .and_then(|m| {
+            let used = js_sys::Reflect::get(&m, &"usedJSHeapSize".into())
+                .ok()
+                .and_then(|v| v.as_f64())
+                .unwrap_or(0.0);
+            let total = js_sys::Reflect::get(&m, &"totalJSHeapSize".into())
+                .ok()
+                .and_then(|v| v.as_f64())
+                .unwrap_or(0.0);
+            let limit = js_sys::Reflect::get(&m, &"jsHeapSizeLimit".into())
+                .ok()
+                .and_then(|v| v.as_f64())
+                .unwrap_or(0.0);
+            Some((used, total, limit))
+        });
+
+    if let Some((used, total, limit)) = memory {
+        format!(
+            "Used: {:.2} MB / Total: {:.2} MB / Limit: {:.2} MB ({:.1}%)",
+            used / 1_048_576.0,
+            total / 1_048_576.0,
+            limit / 1_048_576.0,
+            (used / limit) * 100.0
+        )
+    } else {
+        "Memory info not available".to_string()
+    }
+}
+
+#[wasm_bindgen]
+pub fn log_memory() {
+    let info = get_memory_info();
+    web_sys::console::log_1(&format!("Memory: {}", info).into());
+}
+
+#[wasm_bindgen]
+pub fn get_wasm_memory_info() -> String {
+    #[cfg(target_arch = "wasm32")]
+    {
+        let pages = core::arch::wasm32::memory_size(0);
+        let bytes = pages as f64 * 65536.0;
+        let mb = bytes / (1024.0 * 1024.0);
+
+        format!("WASM Memory: {:.2} MB ({} pages of 64KB)", mb, pages)
+    }
+
+    #[cfg(not(target_arch = "wasm32"))]
+    {
+        "Not WASM".to_string()
+    }
+}
+
+#[wasm_bindgen]
+pub fn log_wasm_memory() {
+    let info = get_wasm_memory_info();
+    web_sys::console::log_1(&info.into());
+}