MiniLM sentence transformer implementation in Rust using Burn.
Supports two model variants from HuggingFace:
- `all-MiniLM-L6-v2`: 6 layers, faster
- `all-MiniLM-L12-v2`: 12 layers, better quality (default)
```rust
use burn::backend::ndarray::NdArray;
use minilm_burn::{mean_pooling, MiniLmModel};

type B = NdArray<f32>;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Default::default();

    // Load pretrained model and tokenizer (downloads from HuggingFace).
    // Use MiniLmVariant::L6 for faster inference, L12 for better quality.
    let (model, tokenizer) = MiniLmModel::<B>::pretrained(&device, Default::default(), None)?;

    // Tokenize the input to obtain `input_ids` and `attention_mask`,
    // then run inference.
    let output = model.forward(input_ids, attention_mask.clone(), None);
    let embeddings = mean_pooling(output.hidden_states, attention_mask);

    Ok(())
}
```

Crate features:

- `pretrained`: enables model download utilities (default)
- `ndarray`: NdArray backend (required for the inference example and tests)
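To clarify what `mean_pooling` computes, here is a minimal self-contained sketch of masked mean pooling over plain `Vec<f32>` data. It is an illustration only, not the crate's actual tensor-based implementation; the function name and shapes here are hypothetical.

```rust
// Illustration of masked mean pooling: average the token embeddings,
// counting only positions where the attention mask is 1 (real tokens),
// so padding does not dilute the sentence embedding.
// `hidden`: [seq_len][dim] token embeddings; `mask`: [seq_len] of 0/1.
fn mean_pool(hidden: &[Vec<f32>], mask: &[u8]) -> Vec<f32> {
    let dim = hidden[0].len();
    let mut sum = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (vec, &m) in hidden.iter().zip(mask) {
        if m == 1 {
            for (s, v) in sum.iter_mut().zip(vec) {
                *s += v;
            }
            count += 1.0;
        }
    }
    // Guard against division by zero for an all-padding sequence.
    let count = count.max(1e-9);
    sum.iter_mut().for_each(|s| *s /= count);
    sum
}

fn main() {
    let hidden = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![0.0, 0.0]];
    let mask = vec![1, 1, 0]; // last position is padding
    let pooled = mean_pool(&hidden, &mask);
    println!("{:?}", pooled); // prints [2.0, 3.0]
}
```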
Backend features for benchmarks:

- `wgpu`: WebGPU backend
- `cuda`: CUDA backend
- `tch-cpu`: LibTorch CPU backend
- `tch-gpu`: LibTorch GPU backend
Run the inference example:

```sh
cargo run --example inference --features ndarray --release
```

Unit tests:

```sh
cargo test --features ndarray
```

Integration tests (requires model download):

```sh
cargo test --features ndarray -- --ignored
```

Run benchmarks for each backend:

```sh
cargo bench --features ndarray
cargo bench --features wgpu
cargo bench --features cuda
cargo bench --features tch-cpu
```

Results are saved to `target/criterion/` for comparison across backends. View the HTML report:

```sh
open target/criterion/report/index.html
```

L6 vs L12 (single sentence):
| Variant | ndarray | wgpu | tch-cpu |
|---|---|---|---|
| L6 | 53 ms | 18 ms | 14 ms |
| L12 | 105 ms | 35 ms | 27 ms |
L12 batch scaling:
| Batch size | ndarray | wgpu | tch-cpu |
|---|---|---|---|
| 1 | 102 ms | 35 ms | 26 ms |
| 4 | 387 ms | 39 ms | 49 ms |
| 8 | 774 ms | 44 ms | 77 ms |
| 16 | 1.54 s | 73 ms | 130 ms |
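Sentence embeddings like the ones produced above are typically compared with cosine similarity. A minimal self-contained sketch (not part of the crate API; the function is illustrative):

```rust
// Cosine similarity between two embedding vectors of equal length:
// the dot product divided by the product of the vector norms.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Guard against division by zero for an all-zero vector.
    dot / (na * nb).max(1e-9)
}

fn main() {
    let a = [1.0f32, 0.0, 0.0];
    let b = [1.0f32, 0.0, 0.0];
    let c = [0.0f32, 1.0, 0.0];
    println!("{:.1}", cosine_similarity(&a, &b)); // identical: 1.0
    println!("{:.1}", cosine_similarity(&a, &c)); // orthogonal: 0.0
}
```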
Licensed under MIT OR Apache-2.0.