Skip to content

Commit d7d975e

Browse files
rob-pclaude
andcommitted
perf(map): reuse ksw2 scratch buffers and add a two-MEM chaining fast path
Cut allocator traffic in the alignment/chaining hot path: - align.rs: add a per-thread `KswScratch` (a reusable `ksw2rs::Aligner` plus DNA5-encoded query/target buffers) and route `ksw2_gap_global`, `ksw2_flank_extend`, and `ksw2_align_score` through it via new `dna5_encode_into` / `dna5_encode_rev_into` helpers. This replaces the per-call `Vec` encode allocations and the `Extz::default()` workspace with reused buffers; the score-only ksw2 configuration is unchanged. - chain.rs: add exact single-MEM and two-MEM fast paths to `chain_mems`, avoiding the general DP's n-sized scratch buffers for the overwhelmingly common small cases. `chain_two_mems` returns `None` on tie cases so the caller falls back to the general implementation, preserving its ordering. These are score/result-preserving refactors; the salmon-map suite (53 tests, incl. alignment-validation and end-to-end mapping) passes, fmt and clippy (`-D warnings`) are clean. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent de54265 commit d7d975e

2 files changed

Lines changed: 196 additions & 86 deletions

File tree

crates/salmon-map/src/align.rs

Lines changed: 133 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,26 @@ thread_local! {
163163
/// — an unbounded thread-local here cost tens of GB on a 36M-read library.
164164
static GAP_CACHE: std::cell::RefCell<ahash::AHashMap<(Box<[u8]>, Box<[u8]>), i32>> =
165165
std::cell::RefCell::new(ahash::AHashMap::new());
166+
/// Per-thread ksw2 scratch. This reuses both the ksw DP workspace/result and
167+
/// the DNA5-encoded query/target buffers between small gap/flank DPs.
168+
static KSW_SCRATCH: std::cell::RefCell<KswScratch> =
169+
std::cell::RefCell::new(KswScratch::default());
170+
}
171+
172+
struct KswScratch {
173+
aligner: ksw2rs::Aligner,
174+
query: Vec<u8>,
175+
target: Vec<u8>,
176+
}
177+
178+
impl Default for KswScratch {
179+
fn default() -> Self {
180+
Self {
181+
aligner: ksw2rs::Aligner::new(),
182+
query: Vec::new(),
183+
target: Vec::new(),
184+
}
185+
}
166186
}
167187

168188
/// Maximum number of cached (query, ref) DP scores per thread. The cache exists
@@ -260,38 +280,43 @@ fn dna5_mat(cfg: &AlignConfig) -> [i8; 25] {
260280
/// Global DP score of an inter-MEM gap (both substrings fully consumed). The
261281
/// band is widened to guarantee the alignment can reach the corner.
262282
fn ksw2_gap_global(qg: &[u8], tg: &[u8], cfg: &AlignConfig) -> i32 {
263-
use ksw2rs::{extz2, Extz, Extz2Input, KSW_EZ_RIGHT, KSW_EZ_SCORE_ONLY, KSW_NEG_INF};
283+
use ksw2rs::{Extz2Input, KSW_EZ_RIGHT, KSW_EZ_SCORE_ONLY, KSW_NEG_INF};
264284
if qg.is_empty() {
265285
return -(cfg.gap_open_pen as i32 + tg.len() as i32 * cfg.gap_extend_pen as i32);
266286
}
267287
if tg.is_empty() {
268288
return -(cfg.gap_open_pen as i32 + qg.len() as i32 * cfg.gap_extend_pen as i32);
269289
}
270-
let q = dna5_encode(qg);
271-
let t = dna5_encode(tg);
272290
let mat = dna5_mat(cfg);
273291
let w = cfg
274292
.bandwidth
275293
.max((qg.len() as i32 - tg.len() as i32).abs() + 4);
276-
let input = Extz2Input {
277-
query: &q,
278-
target: &t,
279-
m: 5,
280-
mat: &mat,
281-
q: cfg.gap_open_pen,
282-
e: cfg.gap_extend_pen,
283-
w,
284-
zdrop: -1,
285-
end_bonus: 0,
286-
// Mapping only needs the score, not the CIGAR — score-only mode skips the
287-
// O(qlen·tlen) traceback-matrix fill and backtrack pass.
288-
flag: KSW_EZ_RIGHT | KSW_EZ_SCORE_ONLY,
289-
};
290-
let mut ez = Extz::default();
291-
ez.reset();
292-
extz2(&input, &mut ez);
293-
if ez.score > KSW_NEG_INF {
294-
ez.score
294+
let score = KSW_SCRATCH.with(|cell| {
295+
let KswScratch {
296+
aligner,
297+
query,
298+
target,
299+
} = &mut *cell.borrow_mut();
300+
dna5_encode_into(qg, query);
301+
dna5_encode_into(tg, target);
302+
let input = Extz2Input {
303+
query,
304+
target,
305+
m: 5,
306+
mat: &mat,
307+
q: cfg.gap_open_pen,
308+
e: cfg.gap_extend_pen,
309+
w,
310+
zdrop: -1,
311+
end_bonus: 0,
312+
// Mapping only needs the score, not the CIGAR — score-only mode skips
313+
// the O(qlen·tlen) traceback-matrix fill and backtrack pass.
314+
flag: KSW_EZ_RIGHT | KSW_EZ_SCORE_ONLY,
315+
};
316+
aligner.align(&input).score
317+
});
318+
if score > KSW_NEG_INF {
319+
score
295320
} else {
296321
// Corner unreachable: approximate by matching the common prefix length.
297322
let common = qg.len().min(tg.len()) as i32;
@@ -304,38 +329,43 @@ fn ksw2_gap_global(qg: &[u8], tg: &[u8], cfg: &AlignConfig) -> i32 {
304329
/// 5' flank (`anchor_right`) the sequences are reversed so the anchor is on the
305330
/// left, matching ksw2's left-anchored extension.
306331
fn ksw2_flank_extend(qf: &[u8], tf: &[u8], cfg: &AlignConfig, anchor_right: bool) -> i32 {
307-
use ksw2rs::{extz2, Extz, Extz2Input, KSW_EZ_RIGHT, KSW_EZ_SCORE_ONLY, KSW_NEG_INF};
332+
use ksw2rs::{Extz2Input, KSW_EZ_RIGHT, KSW_EZ_SCORE_ONLY, KSW_NEG_INF};
308333
if qf.is_empty() {
309334
return 0;
310335
}
311336
if tf.is_empty() {
312337
return -(cfg.gap_open_pen as i32 + qf.len() as i32 * cfg.gap_extend_pen as i32);
313338
}
314-
let (q, t): (Vec<u8>, Vec<u8>) = if anchor_right {
315-
(
316-
dna5_encode(&qf.iter().rev().copied().collect::<Vec<u8>>()),
317-
dna5_encode(&tf.iter().rev().copied().collect::<Vec<u8>>()),
318-
)
319-
} else {
320-
(dna5_encode(qf), dna5_encode(tf))
321-
};
322339
let mat = dna5_mat(cfg);
323-
let input = Extz2Input {
324-
query: &q,
325-
target: &t,
326-
m: 5,
327-
mat: &mat,
328-
q: cfg.gap_open_pen,
329-
e: cfg.gap_extend_pen,
330-
w: cfg.bandwidth.max(qf.len() as i32),
331-
zdrop: -1,
332-
end_bonus: 0,
333-
// Score-only: we read `mqe`/`max`, never the CIGAR (see ksw2_gap_global).
334-
flag: KSW_EZ_RIGHT | KSW_EZ_SCORE_ONLY,
335-
};
336-
let mut ez = Extz::default();
337-
ez.reset();
338-
extz2(&input, &mut ez);
340+
let (mqe, mte, max_score) = KSW_SCRATCH.with(|cell| {
341+
let KswScratch {
342+
aligner,
343+
query,
344+
target,
345+
} = &mut *cell.borrow_mut();
346+
if anchor_right {
347+
dna5_encode_rev_into(qf, query);
348+
dna5_encode_rev_into(tf, target);
349+
} else {
350+
dna5_encode_into(qf, query);
351+
dna5_encode_into(tf, target);
352+
}
353+
let input = Extz2Input {
354+
query,
355+
target,
356+
m: 5,
357+
mat: &mat,
358+
q: cfg.gap_open_pen,
359+
e: cfg.gap_extend_pen,
360+
w: cfg.bandwidth.max(qf.len() as i32),
361+
zdrop: -1,
362+
end_bonus: 0,
363+
// Score-only: we read `mqe`/`max`, never the CIGAR (see ksw2_gap_global).
364+
flag: KSW_EZ_RIGHT | KSW_EZ_SCORE_ONLY,
365+
};
366+
let ez = aligner.align(&input);
367+
(ez.mqe, ez.mte, ez.max as i32)
368+
});
339369
// Soft-clip semantics mirror PuffAligner's flank `part_score` (mqe = read
340370
// flank fully consumed, reference free to overhang; mte = reference fully
341371
// consumed, read free to overhang the transcript end):
@@ -345,23 +375,23 @@ fn ksw2_flank_extend(qf: &[u8], tf: &[u8], cfg: &AlignConfig, anchor_right: bool
345375
// `--softclip` implies `--softclipOverhangs`. `ez.max` is the fallback when
346376
// neither extension is valid (e.g. band-clipped), matching the prior guard.
347377
if cfg.softclip {
348-
let s = ez.mqe.max(ez.mte);
378+
let s = mqe.max(mte);
349379
if s > KSW_NEG_INF {
350380
s.max(0)
351381
} else {
352-
(ez.max as i32).max(0)
382+
max_score.max(0)
353383
}
354384
} else if cfg.softclip_overhangs {
355-
let s = ez.mqe.max(ez.mte);
385+
let s = mqe.max(mte);
356386
if s > KSW_NEG_INF {
357387
s
358388
} else {
359-
ez.max as i32
389+
max_score
360390
}
361-
} else if ez.mqe > KSW_NEG_INF {
362-
ez.mqe
391+
} else if mqe > KSW_NEG_INF {
392+
mqe
363393
} else {
364-
ez.max as i32
394+
max_score
365395
}
366396
}
367397

@@ -401,17 +431,28 @@ fn cached_flank_score(qf: &[u8], tf: &[u8], cfg: &AlignConfig, anchor_right: boo
401431
})
402432
}
403433

404-
/// 2-bit/DNA5 encode (A=0,C=1,G=2,T=3,N/other=4) for ksw2.
405-
fn dna5_encode(seq: &[u8]) -> Vec<u8> {
406-
seq.iter()
407-
.map(|&b| match b {
408-
b'A' | b'a' => 0,
409-
b'C' | b'c' => 1,
410-
b'G' | b'g' => 2,
411-
b'T' | b't' => 3,
412-
_ => 4,
413-
})
414-
.collect()
434+
fn dna5_encode_into(seq: &[u8], out: &mut Vec<u8>) {
435+
out.clear();
436+
out.reserve(seq.len());
437+
out.extend(seq.iter().map(|&b| match b {
438+
b'A' | b'a' => 0,
439+
b'C' | b'c' => 1,
440+
b'G' | b'g' => 2,
441+
b'T' | b't' => 3,
442+
_ => 4,
443+
}));
444+
}
445+
446+
fn dna5_encode_rev_into(seq: &[u8], out: &mut Vec<u8>) {
447+
out.clear();
448+
out.reserve(seq.len());
449+
out.extend(seq.iter().rev().map(|&b| match b {
450+
b'A' | b'a' => 0,
451+
b'C' | b'c' => 1,
452+
b'G' | b'g' => 2,
453+
b'T' | b't' => 3,
454+
_ => 4,
455+
}));
415456
}
416457

417458
/// ksw2 score, matching C++ salmon's banded `ksw_extz2_sse` configuration
@@ -421,37 +462,43 @@ fn dna5_encode(seq: &[u8]) -> Vec<u8> {
421462
/// diagonal. This narrow band is what makes salmon reject off-diagonal /
422463
/// large-indel placements that a wide-band aligner would accept.
423464
fn ksw2_align_score(query: &[u8], rwin: &[u8], cfg: &AlignConfig) -> i32 {
424-
use ksw2rs::{extz2, Extz, Extz2Input, KSW_EZ_RIGHT, KSW_EZ_SCORE_ONLY, KSW_NEG_INF};
425-
let q_enc = dna5_encode(query);
426-
let t_enc = dna5_encode(rwin);
465+
use ksw2rs::{Extz2Input, KSW_EZ_RIGHT, KSW_EZ_SCORE_ONLY, KSW_NEG_INF};
427466
// 5x5 DNA5 scoring matrix: match on the diagonal, mismatch off, N-N = 0.
428467
let mut mat = [-cfg.mismatch_pen; 25];
429468
for i in 0..5 {
430469
mat[i * 5 + i] = cfg.match_score;
431470
}
432471
mat[24] = 0;
433-
let input = Extz2Input {
434-
query: &q_enc,
435-
target: &t_enc,
436-
m: 5,
437-
mat: &mat,
438-
q: cfg.gap_open_pen,
439-
e: cfg.gap_extend_pen,
440-
w: cfg.bandwidth,
441-
zdrop: -1,
442-
end_bonus: 0,
443-
flag: KSW_EZ_RIGHT | KSW_EZ_SCORE_ONLY,
444-
};
445-
let mut ez = Extz::default();
446-
ez.reset(); // initialize mqe/mte/score to KSW_NEG_INF as ksw2 expects
447-
extz2(&input, &mut ez);
472+
let (mqe, max_score) = KSW_SCRATCH.with(|cell| {
473+
let KswScratch {
474+
aligner,
475+
query: q_enc,
476+
target: t_enc,
477+
} = &mut *cell.borrow_mut();
478+
dna5_encode_into(query, q_enc);
479+
dna5_encode_into(rwin, t_enc);
480+
let input = Extz2Input {
481+
query: q_enc,
482+
target: t_enc,
483+
m: 5,
484+
mat: &mat,
485+
q: cfg.gap_open_pen,
486+
e: cfg.gap_extend_pen,
487+
w: cfg.bandwidth,
488+
zdrop: -1,
489+
end_bonus: 0,
490+
flag: KSW_EZ_RIGHT | KSW_EZ_SCORE_ONLY,
491+
};
492+
let ez = aligner.align(&input);
493+
(ez.mqe, ez.max as i32)
494+
});
448495
// `mqe` is the best score with the entire query consumed (the read aligned
449496
// fully). Fall back to the local max if the query end was never reached
450497
// within the band.
451-
if ez.mqe > KSW_NEG_INF {
452-
ez.mqe
498+
if mqe > KSW_NEG_INF {
499+
mqe
453500
} else {
454-
ez.max as i32
501+
max_score
455502
}
456503
}
457504

crates/salmon-map/src/chain.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,61 @@ fn gap_cost(gap: i32, seed_len: i32, log2_lut: &[f32]) -> f32 {
161161
}
162162
}
163163

164+
/// Exact `chain_mems` fast path for the common two-anchor case. Returns `None`
165+
/// for tie cases where matching `sort_unstable_by`'s arbitrary equal-key/order
166+
/// behavior would be brittle; the caller then uses the general implementation.
167+
fn chain_two_mems(m0: Mem, m1: Mem, is_fw: bool, cfg: &ChainConfig) -> Option<Vec<MemChain>> {
168+
let k0 = (m0.ref_start, m0.read_start);
169+
let k1 = (m1.ref_start, m1.read_start);
170+
let (a, b) = if k0 < k1 {
171+
(m0, m1)
172+
} else if k1 < k0 {
173+
(m1, m0)
174+
} else {
175+
return None;
176+
};
177+
178+
let f0 = a.len as f32;
179+
let mut f1 = b.len as f32;
180+
let mut p1 = usize::MAX;
181+
182+
let dr = b.ref_start - a.ref_start;
183+
if dr <= cfg.max_gap {
184+
let dq = b.read_start - a.read_start;
185+
if dr > 0 && dq > 0 && dq <= cfg.max_gap {
186+
let gap = (dr - dq).abs();
187+
let gain = dq.min(dr).min(b.len) as f32;
188+
let sc = f0 + gain - gap_cost(gap, cfg.seed_len, &LOG2_LUT);
189+
if sc > f1 {
190+
f1 = sc;
191+
p1 = 0;
192+
}
193+
}
194+
}
195+
196+
if f0 == f1 {
197+
return None;
198+
}
199+
200+
let mut chains = Vec::with_capacity(2);
201+
if f1 > f0 {
202+
if p1 == 0 {
203+
chains.push(MemChain::new(vec![a, b], f1, is_fw));
204+
} else {
205+
chains.push(MemChain::new(vec![b], f1, is_fw));
206+
chains.push(MemChain::new(vec![a], f0, is_fw));
207+
}
208+
} else {
209+
chains.push(MemChain::new(vec![a], f0, is_fw));
210+
chains.push(MemChain::new(vec![b], f1, is_fw));
211+
}
212+
213+
let best = f0.max(f1);
214+
let cutoff = best * cfg.chain_subopt_thresh;
215+
chains.retain(|c| c.score >= cutoff);
216+
Some(chains)
217+
}
218+
164219
/// Reusable per-thread scratch for [`chain_mems`]. The chaining DP allocates a
165220
/// handful of `n`-sized buffers per (tid, orientation) group; with many groups
166221
/// per read this dominated the mapper's allocator traffic. We keep one set of
@@ -190,6 +245,14 @@ pub fn chain_mems(mems: &[Mem], is_fw: bool, cfg: &ChainConfig) -> Vec<MemChain>
190245
if mems.is_empty() {
191246
return Vec::new();
192247
}
248+
if let [mem] = mems {
249+
return vec![MemChain::new(vec![*mem], mem.len as f32, is_fw)];
250+
}
251+
if let [m0, m1] = mems {
252+
if let Some(chains) = chain_two_mems(*m0, *m1, is_fw, cfg) {
253+
return chains;
254+
}
255+
}
193256

194257
CHAIN_SCRATCH.with(|cell| {
195258
// Split-borrow each scratch field independently.

0 commit comments

Comments
 (0)