5 changes: 5 additions & 0 deletions candle-examples/Cargo.toml
@@ -93,6 +93,7 @@ snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]
buildtime-download = []
candle-transformers-provence-process = ["candle-transformers/provence-process"]

[[example]]
name = "llama_multiprocess"
@@ -161,3 +162,7 @@ required-features = ["symphonia"]
[[example]]
name = "bert_single_file_binary"
required-features = ["buildtime-download"]

[[example]]
name = "provence"
required-features = ["candle-transformers-provence-process"]
21 changes: 21 additions & 0 deletions candle-examples/examples/provence/README.md
@@ -0,0 +1,21 @@
## provence

This is a port of the [Provence](https://huggingface.co/naver/provence-reranker-debertav3-v1) model. Provence is based on DebertaV3.

> Provence is a lightweight context pruning model for retrieval-augmented generation, particularly optimized for question answering. Given a user question and a retrieved passage, Provence removes sentences from the passage that are not relevant to the user question. This speeds up generation and reduces context noise, in a plug-and-play manner for any LLM.

## Examples

Note that all examples here use the `metal` feature flag provided by the `candle-examples` crate. You may need to adjust this to match your environment.

Also, the `candle-transformers-provence-process` feature flag is required to enable the model's `process` helper function; enabling it pulls in additional dependencies such as `tokenizers`. If you only need `forward` and not `process`, you can omit the flag and skip those extra dependencies; a forward-only sketch follows below.
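
For reference, here is a minimal forward-only sketch mirroring the code in `main.rs`. The `forward_only` wrapper is only for illustration, and it assumes a `ProvenceModel` and `Tokenizer` already loaded as in the example:

```rust
use candle::Tensor;
use candle_transformers::models::provence::ProvenceModel;
use tokenizers::Tokenizer;

// Runs only the raw forward pass; no `provence-process` feature required.
fn forward_only(
    model: &ProvenceModel,
    tokenizer: &Tokenizer,
    question: &str,
    context: &str,
) -> anyhow::Result<()> {
    // Build the combined question/context input the model expects.
    let input_text = ProvenceModel::format_input(question, context);
    let encoding = tokenizer
        .encode(input_text, true)
        .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

    // Add a batch dimension: shape (1, seq_len).
    let input_ids = Tensor::new(encoding.get_ids(), &model.device)?.unsqueeze(0)?;
    let attention_mask =
        Tensor::new(encoding.get_attention_mask(), &model.device)?.unsqueeze(0)?;

    let output = model.forward(&input_ids, Some(attention_mask))?;
    dbg!(&output);
    Ok(())
}
```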

### Single Question and Context

```bash
cargo run --example provence --release --features=metal,candle-transformers-provence-process -- -q "What is used to thicken a classic béchamel sauce?" -c "Béchamel sauce. Basics. Béchamel is one of the five mother sauces of French cuisine. It is a simple white sauce made from a roux of butter and flour, to which milk is gradually added while whisking to avoid lumps. The roux acts as the thickening agent, giving the sauce a smooth, creamy texture. Variations. Some chefs add a pinch of nutmeg for flavor. In Italian cuisine, a similar sauce called besciamella is often used in lasagna. Modern adaptations may substitute butter with olive oil or milk with plant-based alternatives, but the thickening principle with flour remains the same." -t="0.35"
```

### Running on CPU

To run the example on CPU, supply the `--cpu` flag.
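
For instance, reusing the béchamel question from above (a sketch; the `metal` feature is dropped here since the model runs on the CPU):

```bash
cargo run --example provence --release --features=candle-transformers-provence-process -- --cpu -q "What is used to thicken a classic béchamel sauce?" -c "Béchamel is a simple white sauce made from a roux of butter and flour, to which milk is gradually added." -t="0.35"
```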
312 changes: 312 additions & 0 deletions candle-examples/examples/provence/main.rs
@@ -0,0 +1,312 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::fmt;
use std::path::PathBuf;

use anyhow::{bail, Context, Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{debertav2::Config as DebertaV2Config, provence::ProvenceModel};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{Encoding, PaddingParams, Tokenizer};

enum TaskType {
    Single(Box<ProvenceModel>),
}

#[derive(Parser, Debug, Clone, ValueEnum)]
enum ArgsTask {
    Single,
}

impl fmt::Display for ArgsTask {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ArgsTask::Single => write!(f, "single"),
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model id to use from HuggingFace
    #[arg(
        long,
        default_value = "naver/provence-reranker-debertav3-v1",
        group = "model_source",
        conflicts_with = "model_path"
    )]
    model_id: String,

    /// Local model path
    #[arg(long, group = "model_source", conflicts_with = "model_id")]
    model_path: Option<PathBuf>,

    /// Revision of the model to use (default: "main")
    #[arg(long, default_value = "main")]
    revision: String,

    /// Question
    #[arg(
        short,
        long,
        default_value = "What goes on the bottom of Shepherd's pie?"
    )]
    question: String,

    /// Context (either repeat the flag or provide a comma-separated list)
    #[arg(
        short,
        long,
        num_args = 1..,
        default_values = &[
            "Shepherd’s pie. History. In early cookery books, the dish was a means of using leftover roasted meat of any kind, and the pie dish was lined on the sides and bottom with mashed potato, as well as having a mashed potato crust on top. Variations and similar dishes. Other potato-topped pies include: The modern “Cumberland pie” is a version with either beef or lamb and a layer of breadcrumbs and cheese on top. In medieval times, and modern-day Cumbria, the pastry crust had a filling of meat with fruits and spices. In Quebec, a variation on the cottage pie is called “Pâté chinois”. It is made with ground beef on the bottom layer, canned corn in the middle, and mashed potato on top. The “shepherdess pie” is a vegetarian version made without meat, or a vegan version made without meat and dairy. In the Netherlands, a very similar dish called “philosopher’s stew” () often adds ingredients like beans, apples, prunes, or apple sauce. In Brazil, a dish called in refers to the fact that a manioc puree hides a layer of sun-dried meat.",
        ]
    )]
    context: Vec<String>,

    /// Threshold
    #[arg(short, long, default_value = "0.5")]
    threshold: f32,

    /// Which task to run
    #[arg(long, default_value_t = ArgsTask::Single)]
    task: ArgsTask,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> Result<(TaskType, DebertaV2Config, Tokenizer)> {
        let device = candle_examples::device(self.cpu)?;

        // Get files from either the HuggingFace API, or from a specified local directory.
        let (config_filename, tokenizer_filename, weights_filename) =
            get_model_files(&self.model_path, &self.model_id, &self.revision)?;

        let config = std::fs::read_to_string(config_filename)?;
        let config: DebertaV2Config = serde_json::from_str(&config)?;

        let id2label = if let Some(id2label) = &config.id2label {
            id2label.clone()
        } else {
            bail!("Id2Label is neither present in the model configuration nor specified as a parameter")
        };

        let mut tokenizer = Tokenizer::from_file(tokenizer_filename)
            .map_err(|e| candle::Error::Msg(format!("Tokenizer error: {e}")))?;

        tokenizer.with_padding(Some(PaddingParams::default()));

        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(
                &[weights_filename],
                candle_transformers::models::debertav2::DTYPE,
                &device,
            )?
        };

        // The checkpoint stores its weights under the "deberta" prefix.
        let vb = vb.set_prefix("deberta");

        match self.task {
            ArgsTask::Single => Ok((
                TaskType::Single(ProvenceModel::load(vb, &config, Some(id2label.clone()))?.into()),
                config,
                tokenizer,
            )),
        }
    }
}

fn main() -> Result<()> {
    let args = Args::parse();

    let model_load_time = std::time::Instant::now();
    let (task_type, _model_config, tokenizer) = args.build_model_and_tokenizer()?;

    println!(
        "Loaded model and tokenizer in {:?}",
        model_load_time.elapsed()
    );

    match task_type {
        TaskType::Single(model) => {
            let question = &args.question;
            let context = args.context.first().context("context can't be empty")?;

            // Forward only: tokenize the question/context pair and run the raw model.
            println!("Running forward pass only");

            let input_text = ProvenceModel::format_input(question, context);

            let encoding = tokenizer
                .encode(input_text, true)
                .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

            let input_ids = Tensor::new(encoding.get_ids(), &model.device)?.unsqueeze(0)?;
            let attention_mask =
                Tensor::new(encoding.get_attention_mask(), &model.device)?.unsqueeze(0)?;

            let output = model.forward(&input_ids, Some(attention_mask))?;

            println!("Forward pass output");
            dbg!(&output);

            // Simple usage: `process_single` tokenizes, prunes, and scores in one call.
            println!("Running process helper function");
            let result =
                model.process_single(&tokenizer, question, context, args.threshold, false, true)?;

            println!("Simple output");
            println!("Pruned: {}", result.pruned_context);
            println!("Score: {:.2}", result.reranking_score);
            println!("Compression: {:.1}%", result.compression_rate);

            // Detailed usage: inspect per-token keep/drop decisions.
            println!("Detailed output");
            let max_tokens = 80;

            let token_details = result.token_details.context("token details is none")?;

            println!("Ranking Score: {:.4}", result.reranking_score);
            println!(" (Higher = more relevant context for this query)\n");

            println!("Original Context Length (chars): {}", context.len());
            println!(
                "Pruned Context Length (chars): {}",
                result.pruned_context.len()
            );
            println!(
                "Compression Rate (context-only): {:.1}%",
                result.compression_rate
            );

            println!("\nQuestion:\n{}", question);
            println!("\nPruned Context:\n{}\n", result.pruned_context);

            println!("Token-level Analysis (first {} tokens)", max_tokens);
            for detail in token_details.iter().take(max_tokens) {
                println!(
                    "{:3}: {:20} prob={:.3} -> {}",
                    detail.index,
                    format!("'{}'", detail.token),
                    detail.probability,
                    detail.status
                );
            }
            println!("\nNOTE: With sentence rounding, entire sentences are kept/dropped together");
        }
    }

    Ok(())
}

fn get_model_files(
    model_path: &Option<PathBuf>,
    model_id: &str,
    revision: &str,
) -> Result<(PathBuf, PathBuf, PathBuf)> {
    let config_filename = "config.json";
    let tokenizer_filename = "tokenizer.json";
    let weights_filename = "model.safetensors";

    let config;
    let tokenizer;
    let weights;

    match model_path {
        Some(base_path) => {
            if !base_path.is_dir() {
                bail!("Model path {} is not a directory.", base_path.display())
            }

            config = base_path.join(config_filename);
            tokenizer = base_path.join(tokenizer_filename);
            weights = base_path.join(weights_filename);
        }
        None => {
            let repo =
                Repo::with_revision(model_id.to_owned(), RepoType::Model, revision.to_owned());

            let api = Api::new()?;
            let api = api.repo(repo);

            config = api.get(config_filename)?;
            tokenizer = api.get(tokenizer_filename)?;
            weights = api.get(weights_filename)?;
        }
    }

    Ok((config, tokenizer, weights))
}

// Tokenization helpers adapted from the xlm-roberta example.
#[derive(Debug)]
pub enum TokenizeInput<'a> {
    Single(&'a [String]),
    Pairs(&'a [(String, String)]),
}

pub fn tokenize_batch(
    tokenizer: &Tokenizer,
    input: TokenizeInput,
    device: &Device,
) -> anyhow::Result<Tensor> {
    let tokens = get_tokens(tokenizer, input)?;

    let token_ids = tokens
        .iter()
        .map(|tokens| {
            let tokens = tokens.get_ids().to_vec();
            Tensor::new(tokens.as_slice(), device)
        })
        .collect::<candle::Result<Vec<_>>>()?;

    Ok(Tensor::stack(&token_ids, 0)?)
}

pub fn get_attention_mask(
    tokenizer: &Tokenizer,
    input: TokenizeInput,
    device: &Device,
) -> anyhow::Result<Tensor> {
    let tokens = get_tokens(tokenizer, input)?;

    let attention_mask = tokens
        .iter()
        .map(|tokens| {
            let tokens = tokens.get_attention_mask().to_vec();
            Tensor::new(tokens.as_slice(), device)
        })
        .collect::<candle::Result<Vec<_>>>()?;

    Ok(Tensor::stack(&attention_mask, 0)?)
}

fn get_tokens(tokenizer: &Tokenizer, input: TokenizeInput) -> anyhow::Result<Vec<Encoding>> {
    let tokens = match input {
        TokenizeInput::Single(text_batch) => tokenizer
            .encode_batch(text_batch.to_vec(), true)
            .map_err(E::msg)?,
        TokenizeInput::Pairs(pairs) => tokenizer
            .encode_batch(pairs.to_vec(), true)
            .map_err(E::msg)?,
    };

    Ok(tokens)
}
2 changes: 2 additions & 0 deletions candle-transformers/Cargo.toml
@@ -23,6 +23,7 @@ rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tokenizers = { workspace = true, optional = true }
tracing = { workspace = true }

[features]
@@ -33,3 +34,4 @@ cudnn = ["candle/cudnn", "candle-nn/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
metal = ["candle/metal", "candle-nn/metal"]
provence-process = ["tokenizers"]
5 changes: 4 additions & 1 deletion candle-transformers/src/models/debertav2.rs
@@ -1216,7 +1216,10 @@ pub struct DebertaV2NERModel {
    classifier: candle_nn::Linear,
}

fn id2label_len(config: &Config, id2label: Option<HashMap<u32, String>>) -> Result<usize> {
pub(crate) fn id2label_len(
    config: &Config,
    id2label: Option<HashMap<u32, String>>,
) -> Result<usize> {
    let id2label_len = match (&config.id2label, id2label) {
        (None, None) => bail!("Id2Label is either not present in the model configuration or not passed into DebertaV2NERModel::load as a parameter"),
        (None, Some(id2label_p)) => id2label_p.len(),
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -80,6 +80,7 @@ pub mod persimmon;
pub mod phi;
pub mod phi3;
pub mod pixtral;
pub mod provence;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_gemma3;