13 changes: 0 additions & 13 deletions .claude/settings.local.json

This file was deleted.

92 changes: 92 additions & 0 deletions PERFORMANCE_FIX_SUMMARY.md
@@ -0,0 +1,92 @@
# Reranker Performance Fix Summary

## Problem Identified

Your reranker implementation was running ~2.5x slower than the reference implementation (46.7s vs 18.9s) due to unnecessary overhead from KV caching and causal masking.

## Root Causes

1. **KV Cache Reset on Every Forward Pass**
- The model was resetting the KV cache every time (`offset == 0`)
- Since `forward()` was always called with `offset = 0`, the cache was cleared for every document
- This added overhead without any caching benefit

2. **Unnecessary Causal Masking**
- The model created causal attention masks for each forward pass
- Reranking processes the full sequence at once and doesn't need causal masking
- Creating these masks added computational overhead

3. **KV Cache Memory Operations**
- Even without reuse, the KV cache operations (`append`, `contiguous()`) added overhead
- For single-pass inference (like reranking), KV caching is unnecessary

## Changes Made

### 1. Removed KV Cache Usage in Attention
In `AttentionWeights::forward()`:
```rust
// Before:
if offset == 0 {
self.kv_cache.reset();
}
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

// After:
// For reranking, we don't need KV cache as each document is processed independently
// Directly use k and v without caching
let k = k.contiguous()?;
let v = v.contiguous()?;
```

### 2. Removed Causal Masking
In `ModelWeights::forward()`:
```rust
// Before:
let causal_mask = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset, None)?)
};

// After:
// For reranking, we don't need causal masking as we process the full sequence
let causal_mask = None;
```
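
For context on what is being skipped: a causal mask is an additive `(l, l)` bias with `-inf` above the diagonal, rebuilt on every forward pass. A minimal sketch of such a constructor, assumed for illustration rather than taken from the crate:

```rust
use candle_core::{Device, Result, Tensor};

// Builds an (l, l) additive causal mask: 0.0 on and below the diagonal,
// -inf above it. Allocating and broadcasting this on every forward pass
// is the per-call overhead that `causal_mask = None` avoids.
fn build_causal_mask(l: usize, device: &Device) -> Result<Tensor> {
    let data: Vec<f32> = (0..l)
        .flat_map(|i| (0..l).map(move |j| if j <= i { 0.0 } else { f32::NEG_INFINITY }))
        .collect();
    Tensor::from_vec(data, (l, l), device)
}
```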

### 3. Added Configuration Flag (Optional)
Added `use_kv_cache` flag to allow reverting to old behavior for comparison:
```rust
pub struct Qwen3RerankModel {
// ...
use_kv_cache: bool, // Default: false
}
```
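
A hypothetical sketch of how this flag could gate the two paths inside `AttentionWeights::forward()`; only the before/after snippets above come from the actual change, the gating below is illustrative:

```rust
// Hypothetical: `use_kv_cache` selects between the old cached path and the
// new single-pass path. It defaults to false for reranking workloads.
let (k, v) = if self.use_kv_cache {
    // Old behavior: reset at the start of a sequence, then append to the cache.
    if offset == 0 {
        self.kv_cache.reset();
    }
    self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
} else {
    // Reranking path: each document is one independent forward pass.
    (k.contiguous()?, v.contiguous()?)
};
```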

## Expected Performance Improvement

Based on the changes:
- **Before**: ~46.7s for 4 documents
- **After**: Expected ~20-25s (similar to reference implementation)
- **Speedup**: ~2x faster

The optimized implementation should now perform similarly to the reference implementation since both:
- Process documents in a single forward pass
- Don't use KV caching
- Don't apply causal masking

## How to Test

Run your reranker example again:
```bash
cargo run --release --example reranker
```

You should see significantly improved performance, with the total time reduced by approximately 50%.

## Why This Works

1. **Reranking is Not Autoregressive**: Unlike text generation, reranking processes the full query-document pair in one pass
2. **No Sequential Dependencies**: Each document is scored independently, so caching previous computations doesn't help
3. **Simplified Attention**: Without causal masking, the attention computation itself is leaner, as shown in the sketch below
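
To make point 3 concrete, here is a minimal sketch of mask-free scaled dot-product attention in candle. It illustrates the idea only; it is not the crate's actual attention code, and the `(batch, heads, seq, head_dim)` shapes are assumed:

```rust
use candle_core::{Result, Tensor, D};

/// Full (bidirectional) attention: softmax(QK^T / sqrt(d)) V.
/// With no causal mask there is no mask tensor to build or broadcast-add,
/// which is exactly the work the fix removes.
fn full_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let d = q.dim(D::Minus1)? as f64;
    let scores = (q.matmul(&k.t()?)? / d.sqrt())?;
    let probs = candle_nn::ops::softmax_last_dim(&scores)?;
    probs.matmul(v)
}
```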

These optimizations align your implementation with the reference code's approach, which is optimal for the reranking use case.
89 changes: 89 additions & 0 deletions benchmark_reranker.rs
@@ -0,0 +1,89 @@
use std::time::Instant;
use transformers::models::implementations::Qwen3RerankSize;
use transformers::pipelines::reranker_pipeline::*;
use transformers::pipelines::utils::BasePipelineBuilder;
use transformers::pipelines::utils::DeviceSelectable;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Test documents
let documents = vec![
"Machine learning is a subset of artificial intelligence.",
"Mathematics is the study of numbers and patterns.",
"Physics is the fundamental science of the universe.",
"Cooking is both an art and a science.",
];
let query = "How do neural networks work?";

println!("=== Reranker Performance Benchmark ===\n");

// Benchmark 1: Pipeline approach
println!("1. Testing pipeline approach...");
let start = Instant::now();

let rerank_pipe = RerankPipelineBuilder::qwen3(Qwen3RerankSize::Size0_6B)
.cpu()
.build()
.await?;

let pipeline_build_time = start.elapsed();
println!(" Pipeline build time: {:?}", pipeline_build_time);

let start = Instant::now();
    let results = rerank_pipe.rerank(query, &documents).await?;
let pipeline_rerank_time = start.elapsed();
println!(" Pipeline rerank time: {:?}", pipeline_rerank_time);
println!(" Total pipeline time: {:?}", pipeline_build_time + pipeline_rerank_time);

// Print results
println!("\n Results:");
for (i, res) in results.iter().enumerate() {
println!(" {}. Doc {} - Score: {:.4}", i+1, res.index, res.score);
}

// Benchmark 2: Direct model approach (sync in spawn_blocking)
println!("\n2. Testing direct model approach (in spawn_blocking)...");
let start = Instant::now();

let results2 = tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<(usize, f32)>> {
use transformers::models::implementations::qwen3_reranker::{Qwen3RerankModel, Qwen3RerankSize};
use transformers::pipelines::reranker_pipeline::model::RerankModel;
use candle_core::Device;

// Load model directly
let start_load = Instant::now();
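        // futures::executor::block_on drives the async loader to completion on
        // this dedicated blocking thread, keeping tokio's async workers free.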
let mut model = futures::executor::block_on(
Qwen3RerankModel::from_hf(&Device::Cpu, Qwen3RerankSize::Size0_6B)
)?;
let tokenizer = futures::executor::block_on(model.get_tokenizer())?;
let load_time = start_load.elapsed();
println!(" Model load time: {:?}", load_time);

// Rerank
let start_rerank = Instant::now();
        let results = model.rerank_documents(&tokenizer, query, &documents)?;
let rerank_time = start_rerank.elapsed();
println!(" Direct rerank time: {:?}", rerank_time);

Ok(results.into_iter().map(|r| (r.index, r.score)).collect())
}).await??;

let direct_total_time = start.elapsed();
println!(" Total direct time: {:?}", direct_total_time);

println!("\n Results:");
for (i, (idx, score)) in results2.iter().enumerate() {
println!(" {}. Doc {} - Score: {:.4}", i+1, idx, score);
}

// Compare
println!("\n=== Performance Comparison ===");
println!("Pipeline approach: {:?}", pipeline_build_time + pipeline_rerank_time);
println!("Direct approach: {:?}", direct_total_time);
    // Ratio > 1.0 means the direct approach finished faster than the pipeline.
    let speedup = (pipeline_build_time + pipeline_rerank_time).as_secs_f64() / direct_total_time.as_secs_f64();
    println!("Direct is {:.2}x faster than the pipeline", speedup);

Ok(())
}
84 changes: 84 additions & 0 deletions examples/benchmark_reranker.rs
@@ -0,0 +1,84 @@
use std::time::Instant;
use transformers::models::implementations::Qwen3RerankSize;
use transformers::pipelines::reranker_pipeline::*;
use transformers::pipelines::utils::BasePipelineBuilder;
use transformers::pipelines::utils::DeviceSelectable;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Test documents
let documents = vec![
"Machine learning is a subset of artificial intelligence.",
"Mathematics is the study of numbers and patterns.",
"Physics is the fundamental science of the universe.",
"Cooking is both an art and a science.",
];
let query = "How do neural networks work?";

println!("=== Reranker Performance Benchmark ===\n");

// Benchmark 1: Pipeline approach
println!("1. Testing pipeline approach...");
let start = Instant::now();

let rerank_pipe = RerankPipelineBuilder::qwen3(Qwen3RerankSize::Size0_6B)
.cpu()
.build()
.await?;

let pipeline_build_time = start.elapsed();
println!(" Pipeline build time: {:?}", pipeline_build_time);

// Warm up
    let _ = rerank_pipe.rerank(query, &documents).await?;

let start = Instant::now();
    for _ in 0..5 {
        let _ = rerank_pipe.rerank(query, &documents).await?;
    }
let pipeline_rerank_time = start.elapsed();
println!(" Pipeline rerank time (5 iterations): {:?}", pipeline_rerank_time);
println!(" Average per iteration: {:?}", pipeline_rerank_time / 5);

// Benchmark 2: Direct model approach (sync in spawn_blocking)
println!("\n2. Testing direct model approach (in spawn_blocking)...");
let start = Instant::now();

tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
use transformers::models::implementations::qwen3_reranker::{Qwen3RerankModel, Qwen3RerankSize};
use transformers::pipelines::reranker_pipeline::model::RerankModel;
use candle_core::Device;

// Load model directly
let start_load = Instant::now();
let mut model = futures::executor::block_on(
Qwen3RerankModel::from_hf(&Device::Cpu, Qwen3RerankSize::Size0_6B)
)?;
let tokenizer = futures::executor::block_on(model.get_tokenizer())?;
let load_time = start_load.elapsed();
println!(" Model load time: {:?}", load_time);

// Warm up
        let _ = model.rerank_documents(&tokenizer, query, &documents)?;

// Benchmark
let start_rerank = Instant::now();
        for _ in 0..5 {
            let _ = model.rerank_documents(&tokenizer, query, &documents)?;
        }
let rerank_time = start_rerank.elapsed();
println!(" Direct rerank time (5 iterations): {:?}", rerank_time);
println!(" Average per iteration: {:?}", rerank_time / 5);

Ok(())
}).await??;

let direct_total_time = start.elapsed();
println!(" Total direct time: {:?}", direct_total_time);

Ok(())
}
48 changes: 48 additions & 0 deletions examples/check_model_layers.rs
@@ -0,0 +1,48 @@
use candle_core::quantized::gguf_file;
use std::fs::File;

fn main() -> anyhow::Result<()> {
    // Path to a locally cached GGUF file; the blob hash below is machine-specific,
    // so adjust it to match your own Hugging Face cache.
let path = std::env::home_dir().unwrap()
.join(".cache/huggingface/hub/models--Mungert--Qwen3-Reranker-0.6B-GGUF/blobs/66867f47323e058f9dbfe24a13268859a84d9e9a8bb89ad0789c7c52131267e2");

println!("Opening GGUF file: {}", path.display());
let mut file = File::open(&path)?;

// Read GGUF content metadata only
let start = std::time::Instant::now();
let content = gguf_file::Content::read(&mut file)?;
println!("Read GGUF metadata in: {:?}", start.elapsed());

// Print metadata
if let Some(layers) = content.metadata.get("qwen3.block_count") {
println!("Number of layers: {:?}", layers);
}

// Count tensors
println!("\nTotal tensors in file: {}", content.tensor_infos.len());

// List first 20 tensor names
println!("\nFirst 20 tensor names:");
    for (i, (name, _)) in content.tensor_infos.iter().take(20).enumerate() {
        println!("  {}: {}", i, name);
    }

// Count tensors by prefix
let mut layer_tensors = 0;
let mut other_tensors = 0;
for (name, _) in &content.tensor_infos {
if name.starts_with("blk.") {
layer_tensors += 1;
} else {
other_tensors += 1;
}
}

println!("\nTensor breakdown:");
println!(" Layer tensors (blk.*): {}", layer_tensors);
println!(" Other tensors: {}", other_tensors);

Ok(())
}
50 changes: 50 additions & 0 deletions examples/compare_reranker.rs
@@ -0,0 +1,50 @@
use anyhow::Result;
use std::time::Instant;

#[tokio::main]
async fn main() -> Result<()> {
let documents = vec![
"Machine learning is a subset of artificial intelligence.",
"Mathematics is the study of numbers and patterns.",
];
let query = "How do neural networks work?";

println!("=== Testing Direct Model Usage (like minimal example) ===");
let start = Instant::now();

// Direct usage like minimal example
use candle_core::Device;
use transformers::models::implementations::qwen3_reranker::{Qwen3RerankModel, Qwen3RerankSize};

let mut model = Qwen3RerankModel::from_hf(&Device::Cpu, Qwen3RerankSize::Size0_6B).await?;
let tokenizer = model.get_tokenizer().await?;
println!("Model loading time: {:?}", start.elapsed());

let start = Instant::now();
let results = model.rerank_documents(&tokenizer, query, &documents)?;
println!("Direct rerank time: {:?}", start.elapsed());
for res in &results {
println!(" Doc {} - Score: {:.4}", res.index, res.score);
}

println!("\n=== Testing Pipeline Usage ===");
let start = Instant::now();

use transformers::pipelines::reranker_pipeline::*;
use transformers::pipelines::utils::{BasePipelineBuilder, DeviceSelectable};

let rerank_pipe = RerankPipelineBuilder::qwen3(Qwen3RerankSize::Size0_6B)
.cpu()
.build()
.await?;
println!("Pipeline build time: {:?}", start.elapsed());

let start = Instant::now();
let results2 = rerank_pipe.rerank(query, &documents).await?;
println!("Pipeline rerank time: {:?}", start.elapsed());
for res in &results2 {
println!(" Doc {} - Score: {:.4}", res.index, res.score);
}

Ok(())
}