Skip to content

Commit 42a4edc

Browse files
authored
Mamba2 implementation (huggingface#3264)
1 parent 54131f1 commit 42a4edc

4 files changed

Lines changed: 1030 additions & 0 deletions

File tree

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# candle-mamba2: Mamba2 implementation
2+
3+
Candle implementation of _Mamba2_ [1] inference. Mamba2 introduces the State Space
4+
Duality (SSD) framework which unifies structured SSMs and attention variants.
5+
6+
- [1]. [Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality](https://arxiv.org/abs/2405.21060)
7+
8+
## Running the example
9+
10+
```bash
11+
cargo run --example mamba2 --release -- --prompt "Mamba is the"
12+
```
13+
14+
## Supported models
15+
16+
| Model | HuggingFace ID |
17+
|-------|----------------|
18+
| Mamba2-130m | `AntonV/mamba2-130m-hf` |
19+
| Mamba2-370m | `AntonV/mamba2-370m-hf` |
20+
| Mamba2-780m | `AntonV/mamba2-780m-hf` |
21+
| Mamba2-1.3b | `AntonV/mamba2-1.3b-hf` |
22+
| Mamba2-2.7b | `AntonV/mamba2-2.7b-hf` |
23+
24+
## Verification
25+
26+
Outputs match the PyTorch transformers `Mamba2ForCausalLM` reference implementation.
27+
28+
### mamba2-130m
29+
30+
```bash
31+
cargo run --example mamba2 --release -- \
32+
--prompt "Mamba is the" \
33+
--which mamba2-130m \
34+
--sample-len 20 \
35+
--repeat-penalty 1.0
36+
```
37+
38+
Expected output:
39+
```
40+
Mamba is the most popular and popular game in the world. It is a game where you can play with your friends
41+
```
42+
43+
### mamba2-370m
44+
45+
```bash
46+
cargo run --example mamba2 --release -- \
47+
--prompt "Mamba is the" \
48+
--which mamba2-370m \
49+
--sample-len 20 \
50+
--repeat-penalty 1.0
51+
```
52+
53+
Expected output:
54+
```
55+
Mamba is the first game in the series to feature a new character, the Mamba, who is a female version
56+
```
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
#[cfg(feature = "mkl")]
2+
extern crate intel_mkl_src;
3+
4+
#[cfg(feature = "accelerate")]
5+
extern crate accelerate_src;
6+
7+
use anyhow::{Error as E, Result};
8+
use clap::{Parser, ValueEnum};
9+
10+
use candle_transformers::models::mamba2::{Config, Model, State};
11+
12+
use candle::{DType, Device, Tensor};
13+
use candle_examples::token_output_stream::TokenOutputStream;
14+
use candle_nn::VarBuilder;
15+
use candle_transformers::generation::LogitsProcessor;
16+
use hf_hub::{api::sync::Api, Repo, RepoType};
17+
use tokenizers::Tokenizer;
18+
19+
struct TextGeneration {
20+
model: Model,
21+
config: Config,
22+
device: Device,
23+
tokenizer: TokenOutputStream,
24+
logits_processor: LogitsProcessor,
25+
repeat_penalty: f32,
26+
repeat_last_n: usize,
27+
use_prefill: bool,
28+
chunk_size: usize,
29+
}
30+
31+
impl TextGeneration {
32+
#[allow(clippy::too_many_arguments)]
33+
fn new(
34+
model: Model,
35+
config: Config,
36+
tokenizer: Tokenizer,
37+
seed: u64,
38+
temp: Option<f64>,
39+
top_p: Option<f64>,
40+
repeat_penalty: f32,
41+
repeat_last_n: usize,
42+
use_prefill: bool,
43+
chunk_size: usize,
44+
device: &Device,
45+
) -> Self {
46+
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
47+
Self {
48+
model,
49+
config,
50+
tokenizer: TokenOutputStream::new(tokenizer),
51+
logits_processor,
52+
repeat_penalty,
53+
repeat_last_n,
54+
use_prefill,
55+
chunk_size,
56+
device: device.clone(),
57+
}
58+
}
59+
60+
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
61+
use std::io::Write;
62+
self.tokenizer.clear();
63+
let dtype = self.model.dtype();
64+
let mut tokens = self
65+
.tokenizer
66+
.tokenizer()
67+
.encode(prompt, true)
68+
.map_err(E::msg)?
69+
.get_ids()
70+
.to_vec();
71+
let mut generated_tokens = 0usize;
72+
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
73+
Some(token) => token,
74+
None => anyhow::bail!("cannot find the <|endoftext|> token"),
75+
};
76+
let mut state = State::new(1, &self.config, dtype, &self.device)?;
77+
let mut next_logits = None;
78+
79+
if self.use_prefill && tokens.len() > 1 {
80+
let prefill_start = std::time::Instant::now();
81+
// Prefill mode: process all tokens at once
82+
let input = Tensor::new(&tokens[..], &self.device)?.unsqueeze(0)?;
83+
let logits = self
84+
.model
85+
.forward_prefill(&input, &mut state, self.chunk_size)?;
86+
// Get logits for last position
87+
next_logits = Some(logits.narrow(1, tokens.len() - 1, 1)?.squeeze(1)?);
88+
for &t in tokens.iter() {
89+
if let Some(t) = self.tokenizer.next_token(t)? {
90+
print!("{t}")
91+
}
92+
}
93+
println!(
94+
"\n[Prefill {} tokens in {:.2}ms]",
95+
tokens.len(),
96+
prefill_start.elapsed().as_secs_f64() * 1000.0
97+
);
98+
} else {
99+
// Step-by-step mode
100+
for &t in tokens.iter() {
101+
let input = Tensor::new(&[t], &self.device)?;
102+
let logits = self.model.forward(&input, &mut state)?;
103+
next_logits = Some(logits);
104+
if let Some(t) = self.tokenizer.next_token(t)? {
105+
print!("{t}")
106+
}
107+
}
108+
}
109+
std::io::stdout().flush()?;
110+
111+
let start_gen = std::time::Instant::now();
112+
for _ in 0..sample_len {
113+
let logits = match next_logits.as_ref() {
114+
Some(logits) => logits,
115+
None => anyhow::bail!("cannot work on an empty prompt"),
116+
};
117+
let logits = logits.squeeze(0)?.to_dtype(dtype)?;
118+
let logits = if self.repeat_penalty == 1. {
119+
logits
120+
} else {
121+
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
122+
candle_transformers::utils::apply_repeat_penalty(
123+
&logits,
124+
self.repeat_penalty,
125+
&tokens[start_at..],
126+
)?
127+
};
128+
let next_token = self.logits_processor.sample(&logits)?;
129+
tokens.push(next_token);
130+
generated_tokens += 1;
131+
if next_token == eos_token {
132+
break;
133+
}
134+
if let Some(t) = self.tokenizer.next_token(next_token)? {
135+
print!("{t}");
136+
std::io::stdout().flush()?;
137+
}
138+
139+
let input = Tensor::new(&[next_token], &self.device)?;
140+
next_logits = Some(self.model.forward(&input, &mut state)?)
141+
}
142+
let dt = start_gen.elapsed();
143+
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
144+
print!("{rest}");
145+
}
146+
std::io::stdout().flush()?;
147+
println!(
148+
"\n{generated_tokens} tokens generated ({:.2} token/s)",
149+
generated_tokens as f64 / dt.as_secs_f64(),
150+
);
151+
Ok(())
152+
}
153+
}
154+
155+
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
156+
enum Which {
157+
Mamba2_130m,
158+
Mamba2_370m,
159+
Mamba2_780m,
160+
Mamba2_1_3b,
161+
Mamba2_2_7b,
162+
}
163+
164+
impl std::fmt::Display for Which {
165+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166+
write!(f, "{self:?}")
167+
}
168+
}
169+
170+
impl Which {
171+
fn model_id(&self) -> &'static str {
172+
match self {
173+
Self::Mamba2_130m => "AntonV/mamba2-130m-hf",
174+
Self::Mamba2_370m => "AntonV/mamba2-370m-hf",
175+
Self::Mamba2_780m => "AntonV/mamba2-780m-hf",
176+
Self::Mamba2_1_3b => "AntonV/mamba2-1.3b-hf",
177+
Self::Mamba2_2_7b => "AntonV/mamba2-2.7b-hf",
178+
}
179+
}
180+
}
181+
182+
#[derive(Parser, Debug)]
183+
#[command(author, version, about, long_about = None)]
184+
struct Args {
185+
/// Run on CPU rather than on GPU.
186+
#[arg(long)]
187+
cpu: bool,
188+
189+
/// Enable tracing (generates a trace-timestamp.json file).
190+
#[arg(long)]
191+
tracing: bool,
192+
193+
#[arg(long)]
194+
prompt: String,
195+
196+
/// The temperature used to generate samples.
197+
#[arg(long)]
198+
temperature: Option<f64>,
199+
200+
/// Nucleus sampling probability cutoff.
201+
#[arg(long)]
202+
top_p: Option<f64>,
203+
204+
/// The seed to use when generating random samples.
205+
#[arg(long, default_value_t = 299792458)]
206+
seed: u64,
207+
208+
/// The length of the sample to generate (in tokens).
209+
#[arg(long, short = 'n', default_value_t = 5000)]
210+
sample_len: usize,
211+
212+
#[arg(long, default_value = "mamba2-130m")]
213+
which: Which,
214+
215+
#[arg(long)]
216+
model_id: Option<String>,
217+
218+
#[arg(long)]
219+
tokenizer_file: Option<String>,
220+
221+
#[arg(long)]
222+
weight_files: Option<String>,
223+
224+
#[arg(long)]
225+
config_file: Option<String>,
226+
227+
#[arg(long, default_value = "f32")]
228+
dtype: String,
229+
230+
/// Penalty to be applied for repeating tokens, 1. means no penalty.
231+
#[arg(long, default_value_t = 1.1)]
232+
repeat_penalty: f32,
233+
234+
/// The context size to consider for the repeat penalty.
235+
#[arg(long, default_value_t = 64)]
236+
repeat_last_n: usize,
237+
238+
/// Use chunked prefill for processing the initial prompt.
239+
#[arg(long)]
240+
use_prefill: bool,
241+
242+
/// Chunk size for prefill (default 256).
243+
#[arg(long, default_value_t = 256)]
244+
chunk_size: usize,
245+
}
246+
247+
fn main() -> Result<()> {
248+
use std::str::FromStr;
249+
use tracing_chrome::ChromeLayerBuilder;
250+
use tracing_subscriber::prelude::*;
251+
252+
let args = Args::parse();
253+
let _guard = if args.tracing {
254+
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
255+
tracing_subscriber::registry().with(chrome_layer).init();
256+
Some(guard)
257+
} else {
258+
None
259+
};
260+
println!(
261+
"avx: {}, neon: {}, simd128: {}, f16c: {}",
262+
candle::utils::with_avx(),
263+
candle::utils::with_neon(),
264+
candle::utils::with_simd128(),
265+
candle::utils::with_f16c()
266+
);
267+
println!(
268+
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
269+
args.temperature.unwrap_or(0.),
270+
args.repeat_penalty,
271+
args.repeat_last_n
272+
);
273+
274+
let start = std::time::Instant::now();
275+
let api = Api::new()?;
276+
let model_id = args
277+
.model_id
278+
.unwrap_or_else(|| args.which.model_id().to_string());
279+
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
280+
let tokenizer_filename = match args.tokenizer_file {
281+
Some(file) => std::path::PathBuf::from(file),
282+
None => repo.get("tokenizer.json")?,
283+
};
284+
let config_filename = match args.config_file {
285+
Some(file) => std::path::PathBuf::from(file),
286+
None => repo.get("config.json")?,
287+
};
288+
let filenames = match args.weight_files {
289+
Some(files) => files
290+
.split(',')
291+
.map(std::path::PathBuf::from)
292+
.collect::<Vec<_>>(),
293+
None => {
294+
vec![repo.get("model.safetensors")?]
295+
}
296+
};
297+
println!("retrieved the files in {:?}", start.elapsed());
298+
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
299+
300+
let start = std::time::Instant::now();
301+
// Config contains `Infinity` which is not valid JSON, replace with a large number
302+
let config_str = std::fs::read_to_string(config_filename)?;
303+
let config_str = config_str.replace("Infinity", "1e30");
304+
let config: Config = serde_json::from_str(&config_str)?;
305+
let device = candle_examples::device(args.cpu)?;
306+
let dtype = DType::from_str(&args.dtype)?;
307+
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
308+
let model = Model::new(&config, vb.pp("backbone"))?;
309+
println!("loaded the model in {:?}", start.elapsed());
310+
311+
let mut pipeline = TextGeneration::new(
312+
model,
313+
config,
314+
tokenizer,
315+
args.seed,
316+
args.temperature,
317+
args.top_p,
318+
args.repeat_penalty,
319+
args.repeat_last_n,
320+
args.use_prefill,
321+
args.chunk_size,
322+
&device,
323+
);
324+
pipeline.run(&args.prompt, args.sample_len)?;
325+
Ok(())
326+
}

0 commit comments

Comments
 (0)