Skip to content

Commit f3d0ed2

Browse files
authored
Merge PR #1014: abundance-aware fragment-length-distribution training
feat(quant): abundance-aware fragment-length-distribution training
2 parents 6d5d062 + 4da312d commit f3d0ed2

2 files changed

Lines changed: 87 additions & 72 deletions

File tree

crates/salmon-quant/src/lib.rs

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -362,15 +362,19 @@ pub fn quantify(opts: &QuantOptions) -> Result<QuantResult> {
362362
// bias models are weighted by abundance-aware posteriors (salmon's online
363363
// phase), not score-only weights. The offline EM still gives the final
364364
// point estimate.
365-
let online = (opts.seq_bias || opts.gc_bias || opts.pos_bias).then(|| {
365+
// The online estimate runs unconditionally: besides weighting the observed
366+
// bias models (when bias correction is on), it provides the abundance-aware
367+
// posterior used to train the fragment-length distribution (salmon's
368+
// `r < exp(aln.logProb)` acceptance). The offline EM does not read it.
369+
let online = {
366370
let ref_lens: Vec<u64> = (0..num_refs).map(|t| salmon.ref_len(t)).collect();
367-
salmon_infer::OnlineInference::new(
371+
Some(salmon_infer::OnlineInference::new(
368372
&ref_lens,
369373
0.05,
370374
opts.forgetting_factor,
371375
opts.num_aux_model_samples,
372-
)
373-
});
376+
))
377+
};
374378

375379
// ---- parallel mapping pass (borrows the accumulators) -------------------
376380
{

crates/salmon-quant/src/processor.rs

Lines changed: 79 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,24 @@ const LOG_EPSILON: f64 = -23.998_158_637_57; // (0.375e-10f64).ln()
165165
/// posterior only after this many fragments have been assigned.
166166
pub(crate) const NUM_PRE_BURNIN: u64 = 5000;
167167

168+
/// Seed of the first FLD-acceptance stream handed out. Kept equal to the
/// historical fixed seed so the first thread that draws reproduces the sequence
/// the single shared seed used to produce.
const FLD_RNG_BASE_SEED: u64 = 0x2545_F491_4F6C_DD1D;

/// Number of per-thread FLD-acceptance streams handed out so far; each thread
/// claims one index the first time it touches `FLD_RNG`.
static FLD_RNG_STREAMS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

/// Derive a distinct, non-zero xorshift seed for the calling thread.
///
/// Seeding every thread with the same constant makes all workers draw the
/// identical pseudo-random sequence, so their accept/reject decisions are
/// correlated instead of independent. Each thread therefore claims a stream
/// index and scrambles it with the splitmix64 finalizer (a bijection, so
/// distinct indices give distinct seeds).
fn fld_rng_seed() -> u64 {
    let stream = FLD_RNG_STREAMS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    if stream == 0 {
        return FLD_RNG_BASE_SEED;
    }
    let mut z = FLD_RNG_BASE_SEED.wrapping_add(stream.wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;
    // xorshift is stuck at zero forever, so never hand out a zero state.
    if z == 0 {
        FLD_RNG_BASE_SEED
    } else {
        z
    }
}

thread_local! {
    /// Per-thread PRNG state for stochastic FLD-sample acceptance (mirrors
    /// salmon's per-thread RNG used when training the fragment-length model).
    /// Initialized lazily with a seed unique to the thread.
    static FLD_RNG: std::cell::Cell<u64> = std::cell::Cell::new(fld_rng_seed());
}

/// Draw a pseudo-random value in `[0, 1)` from the per-thread xorshift state.
///
/// Advances the calling thread's state by one xorshift64 step (shifts 13, 7, 17)
/// and maps the top 53 bits of the new state onto the unit interval, so the
/// result is always strictly below `1.0`.
fn fld_rng_u01() -> f64 {
    FLD_RNG.with(|state| {
        let mut x = state.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        state.set(x);
        // 53 bits is the full f64 mantissa width; dividing by 2^53 is exact.
        ((x >> 11) as f64) * (1.0 / ((1u64 << 53) as f64))
    })
}
185+
168186
/// Collect the orientation-aware 5'/3' sequence-bias contexts for one mapping,
169187
/// weighted by `weight` (the fragment-transcript posterior). A forward read's 5'
170188
/// context feeds the forward model; a reverse read's 5' context
@@ -326,50 +344,53 @@ fn record(
326344
// also advances the online masses by `logForgettingMass + log(posterior)`;
327345
// without it they fall back to the normalized aux (score) weights.
328346
let collecting = seqbias.is_some() || gcbias.is_some() || posbias.is_some();
329-
let bias_w: Vec<f64> = if collecting {
330-
if let Some(online) = sh.online {
331-
// Per-mapping log auxiliary probability (salmon's
332-
// `auxProb + startPosProb`, abundance-independent):
333-
// logFragCov = ln(score weight)
334-
// startPosProb = proper pair -> -ln(refLen - flen + 1) (flen<=refLen,
335-
// else LOG_EPSILON); otherwise -ln(refLen)
336-
// logFragProb = proper pair (after pre-burn-in) -> live FLD pmf(flen);
337-
// unexpected orphan in a paired library -> LOG_EPSILON.
338-
// The FLD term discriminates isoforms/paralogs whose implied insert
339-
// size differs (e.g. alternative splicing), which the length norm alone
340-
// cannot capture.
341-
let use_aux = online.num_assigned() >= sh.pre_burnin;
342-
let mm: Vec<(u32, f64)> = compat
343-
.iter()
344-
.map(|(m, w)| {
345-
let rl = sh.salmon.ref_len(m.tid as usize).max(1) as f64;
346-
let proper = m.status == MateStatus::PairedEndPaired && m.fragment_len > 0;
347-
let flen = m.fragment_len as f64;
348-
let start_pos_prob = if proper {
349-
if flen <= rl {
350-
-((rl - flen + 1.0).ln())
351-
} else {
352-
LOG_EPSILON
353-
}
347+
// Abundance-aware online posterior, computed once per fragment within the
348+
// model-training window (`o.collecting()`), advancing the online masses.
349+
// Mirrors salmon's `aln.logProb = transcriptLogCount + auxProb + startPosProb`
350+
// normalized over the fragment's mappings, where
351+
// logFragCov = ln(score weight)
352+
// startPosProb = proper pair -> -ln(refLen - flen + 1) (flen<=refLen, else
353+
// LOG_EPSILON); otherwise -ln(refLen)
354+
// logFragProb = proper pair (after pre-burn-in) -> live FLD pmf(flen);
355+
// unexpected orphan in a paired library -> LOG_EPSILON.
356+
// Used for abundance-aware bias collection and abundance-aware FLD training.
357+
let online_post: Option<Vec<f64>> = sh.online.filter(|o| o.collecting()).map(|online| {
358+
let use_aux = online.num_assigned() >= sh.pre_burnin;
359+
let mm: Vec<(u32, f64)> = compat
360+
.iter()
361+
.map(|(m, w)| {
362+
let rl = sh.salmon.ref_len(m.tid as usize).max(1) as f64;
363+
let proper = m.status == MateStatus::PairedEndPaired && m.fragment_len > 0;
364+
let flen = m.fragment_len as f64;
365+
let start_pos_prob = if proper {
366+
if flen <= rl {
367+
-((rl - flen + 1.0).ln())
354368
} else {
355-
-(rl.ln())
356-
};
357-
let log_frag_prob = if proper {
358-
if use_aux {
359-
sh.fld.pmf(m.fragment_len as usize)
360-
} else {
361-
0.0
362-
}
363-
} else if sh.paired_lib {
364-
LOG_EPSILON // unexpected orphan
369+
LOG_EPSILON
370+
}
371+
} else {
372+
-(rl.ln())
373+
};
374+
let log_frag_prob = if proper {
375+
if use_aux {
376+
sh.fld.pmf(m.fragment_len as usize)
365377
} else {
366378
0.0
367-
};
368-
let log_cov = if *w > 0.0 { w.ln() } else { f64::NEG_INFINITY };
369-
(m.tid, log_cov + start_pos_prob + log_frag_prob)
370-
})
371-
.collect();
372-
online.assign_fragment(&mm, log_fm)
379+
}
380+
} else if sh.paired_lib {
381+
LOG_EPSILON // unexpected orphan
382+
} else {
383+
0.0
384+
};
385+
let log_cov = if *w > 0.0 { w.ln() } else { f64::NEG_INFINITY };
386+
(m.tid, log_cov + start_pos_prob + log_frag_prob)
387+
})
388+
.collect();
389+
online.assign_fragment(&mm, log_fm)
390+
});
391+
let bias_w: Vec<f64> = if collecting {
392+
if let Some(post) = &online_post {
393+
post.clone()
373394
} else {
374395
let wsum: f64 = compat.iter().map(|(_, w)| *w).sum();
375396
compat
@@ -424,36 +445,26 @@ fn record(
424445
})
425446
.collect();
426447

427-
// Observe a fragment length from the best concordant compatible pair,
428-
// weighted by that pair's posterior confidence among the concordant pairs.
429-
// salmon trains its FLD stochastically (it accepts an alignment with
430-
// probability proportional to exp(aln.logProb)), so ambiguous multimappers
431-
// contribute little; adding every best pair at full weight overdisperses the
432-
// FLD on paralog/near-duplicate-rich inputs. Down-weighting by confidence is
433-
// the deterministic analog.
434-
let mut conc_sum = 0.0_f64;
435-
let mut best_concordant: Option<&ScoredMapping> = None;
436-
for m in maps.iter() {
437-
if m.status == MateStatus::PairedEndPaired
438-
&& m.fragment_len > 0
439-
&& sh
440-
.expected_format
441-
.is_none_or(|exp| is_compatible(exp, m.format, m.is_fw, m.status))
442-
{
443-
conc_sum += m.weight;
444-
if best_concordant.is_none_or(|b| m.weight > b.weight) {
445-
best_concordant = Some(m);
448+
// Abundance-aware FLD training: accept each concordant compatible pair's
449+
// fragment length with probability = its abundance-aware online posterior
450+
// (salmon's `if (r < exp(aln.logProb)) fragLengthDist.addVal(...)`, where
451+
// aln.logProb includes transcriptLogCount). For reads shared between
452+
// near-duplicates this preferentially samples the dominant transcript's
453+
// implied length, concentrating the FLD as salmon does (vs adding every best
454+
// pair at full weight, which overdisperses it). Frozen after the training
455+
// window (`online_post` is `None`).
456+
if let Some(post) = &online_post {
457+
for (i, (m, _)) in compat.iter().enumerate() {
458+
let conc = m.status == MateStatus::PairedEndPaired
459+
&& m.fragment_len > 0
460+
&& sh
461+
.expected_format
462+
.is_none_or(|exp| is_compatible(exp, m.format, m.is_fw, m.status));
463+
if conc && fld_rng_u01() < post[i] {
464+
sh.fld.add_val(m.fragment_len as usize, 0.0);
446465
}
447466
}
448467
}
449-
if let Some(best) = best_concordant {
450-
let conf = if conc_sum > 0.0 {
451-
best.weight / conc_sum
452-
} else {
453-
1.0
454-
};
455-
sh.fld.add_val(best.fragment_len as usize, conf.ln());
456-
}
457468

458469
// Build the equivalence class: sorted, de-duplicated transcript ids + weights.
459470
pairs.sort_by_key(|p| p.0);

0 commit comments

Comments
 (0)