diff --git a/Cargo.lock b/Cargo.lock index 395f2ee..888dffb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -364,6 +364,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -727,6 +738,15 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "predicates" version = "3.1.3" @@ -781,6 +801,36 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + [[package]] name = "rayon" version = "1.11.0" @@ -856,6 +906,7 @@ dependencies = [ "memmap2", "noodles", "predicates", + "rand", "rayon", "tempfile", "thiserror", @@ -963,7 +1014,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", - "getrandom", + "getrandom 0.3.4", "once_cell", "rustix", "windows-sys", @@ -1016,6 +1067,12 @@ dependencies = [ "libc", ] +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + [[package]] name = "wasip2" version = "1.0.2+wasi-0.2.9" @@ -1143,3 +1200,23 @@ name = "wit-bindgen" version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 6e19675..acd50a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ flate2 = "1" rayon = "1" dashmap = "6" chrono = "0.4" +rand = "0.8" [dev-dependencies] tempfile = "3" diff --git a/src/align/read_align.rs b/src/align/read_align.rs index c53daef..f3cf514 100644 --- a/src/align/read_align.rs +++ b/src/align/read_align.rs @@ -1,14 +1,44 @@ /// Read alignment driver function use crate::align::score::{AlignmentScorer, SpliceMotif}; use crate::align::seed::Seed; -use crate::align::stitch::{ - cluster_seeds, stitch_seeds_with_jdb_debug, -}; +use crate::align::stitch::{cluster_seeds, stitch_seeds_with_jdb_debug}; use crate::align::transcript::Transcript; use crate::error::Error; use crate::index::GenomeIndex; use crate::params::{IntronMotifFilter, IntronStrandFilter, Parameters}; use crate::stats::UnmappedReason; +use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom}; +use std::hash::{DefaultHasher, Hash, Hasher}; + +/// Derive a deterministic per-read RNG seed from `run_rng_seed` + the read name. +/// +/// STAR seeds `std::mt19937` once per chunk/thread (`runRNGseed*(iChunk+1)`), +/// then advances the state sequentially per read. ruSTAR parallelises per-read +/// via rayon, so we instead fold the read name into the seed — this keeps tie +/// breaks reproducible regardless of thread count while still honoring the +/// user's `--runRNGseed` value. +fn per_read_seed(run_rng_seed: u64, read_name: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + read_name.hash(&mut hasher); + run_rng_seed.wrapping_mul(hasher.finish().wrapping_add(1)) +} + +/// Shuffle the prefix of `items` whose `score_fn` equals the first element's score. +/// +/// Mirrors STAR's `ReadAlign_multMapSelect` / `funPrimaryAlignMark`: best-scoring +/// alignments are randomized so primary selection (index 0) is not biased by +/// upstream sort order. Non-tied elements are left alone. +fn shuffle_tied_prefix(items: &mut [T], score_fn: impl Fn(&T) -> i32, seed: u64) { + let Some(first) = items.first() else { + return; + }; + let best = score_fn(first); + let tied = items.iter().take_while(|t| score_fn(t) == best).count(); + if tied < 2 { + return; + } + items[..tied].shuffle(&mut StdRng::seed_from_u64(seed)); +} /// Result of aligning a single read: (transcripts, chimeric_alignments, n_for_mapq, unmapped_reason) pub type AlignReadResult = ( @@ -301,6 +331,13 @@ pub fn align_read( .then_with(|| a.is_reverse.cmp(&b.is_reverse)) }); + // Randomize primary among best-scoring ties (ReadAlign_multMapSelect.cpp:71-79). + shuffle_tied_prefix( + &mut transcripts, + |t| t.score, + per_read_seed(params.run_rng_seed, read_name), + ); + // Score-range filter: keep only alignments within outFilterMultimapScoreRange of the best. // (STAR's multMapSelect step — must run before quality filters.) if !transcripts.is_empty() { @@ -621,7 +658,10 @@ pub fn align_paired_read( // This correctly represents mate2 on - strand for FR pairs without explicit RC handling. let debug_name: &str = if debug_pe { read_name } else { "" }; let mut mate1_transcripts: Vec = Vec::new(); - for cluster in mate1_clusters.iter().take(params.align_windows_per_read_nmax) { + for cluster in mate1_clusters + .iter() + .take(params.align_windows_per_read_nmax) + { let ts = stitch_seeds_with_jdb_debug( cluster, mate1_seq, @@ -635,7 +675,10 @@ pub fn align_paired_read( } let mut mate2_transcripts: Vec = Vec::new(); - for cluster in mate2_clusters.iter().take(params.align_windows_per_read_nmax) { + for cluster in mate2_clusters + .iter() + .take(params.align_windows_per_read_nmax) + { let ts = stitch_seeds_with_jdb_debug( cluster, mate2_seq, @@ -656,14 +699,9 @@ pub fn align_paired_read( let mut joint_pairs: Vec = Vec::new(); for t1 in &mate1_transcripts { for t2 in &mate2_transcripts { - if let Some(pair) = try_pair_transcripts( - t1, - t2, - len1, - len2, - params, - combined_score_threshold, - ) { + if let Some(pair) = + try_pair_transcripts(t1, t2, len1, len2, params, combined_score_threshold) + { joint_pairs.push(pair); } } @@ -826,6 +864,13 @@ pub fn align_paired_read( }) }); + // Randomize primary among best-scoring pairs (STAR's funPrimaryAlignMark). + shuffle_tied_prefix( + &mut joint_pairs, + |pa| pa.combined_wt_score, + per_read_seed(params.run_rng_seed, read_name), + ); + // Step 4: quality filter (mappedFilter). filter_paired_transcripts(&mut joint_pairs, params); @@ -840,12 +885,10 @@ pub fn align_paired_read( // Half-mapped fallback: report the best-scoring single-mate transcript. // Apply per-mate quality threshold (outFilterScoreMinOverLread * (len - 1)). - let thresh1 = - ((params.out_filter_score_min_over_lread * (len1 as f64 - 1.0)) as i32) - .max(params.out_filter_score_min); - let thresh2 = - ((params.out_filter_score_min_over_lread * (len2 as f64 - 1.0)) as i32) - .max(params.out_filter_score_min); + let thresh1 = ((params.out_filter_score_min_over_lread * (len1 as f64 - 1.0)) as i32) + .max(params.out_filter_score_min); + let thresh2 = ((params.out_filter_score_min_over_lread * (len2 as f64 - 1.0)) as i32) + .max(params.out_filter_score_min); let best_m1 = mate1_transcripts .into_iter() @@ -1041,9 +1084,7 @@ fn filter_paired_transcripts(paired_alns: &mut Vec, params: &Pa // the read is unmapped. // // Step 1: find the best pair and check quality thresholds on it. - let best_pa = paired_alns - .iter() - .max_by_key(|pa| pa.combined_wt_score); + let best_pa = paired_alns.iter().max_by_key(|pa| pa.combined_wt_score); if let Some(best) = best_pa { let mate1_len = (best.mate1_region.1 - best.mate1_region.0) as f64; let mate2_len = (best.mate2_region.1 - best.mate2_region.0) as f64; @@ -1059,8 +1100,7 @@ fn filter_paired_transcripts(paired_alns: &mut Vec, params: &Pa // per-mate spans are small (penalty ≈ −2 each → sum penalty −4 vs combined penalty −2). let combined_score = best.combined_wt_score; if combined_score < params.out_filter_score_min - || combined_score - < (params.out_filter_score_min_over_lread * combined_lread_m1) as i32 + || combined_score < (params.out_filter_score_min_over_lread * combined_lread_m1) as i32 { paired_alns.clear(); return; @@ -1074,8 +1114,7 @@ fn filter_paired_transcripts(paired_alns: &mut Vec, params: &Pa // which directly mirrors STAR's joint transcript nMatch without extension inflation. let combined_match = best.combined_n_match; if combined_match < params.out_filter_match_nmin - || combined_match - < (params.out_filter_match_nmin_over_lread * combined_lread_m1) as u32 + || combined_match < (params.out_filter_match_nmin_over_lread * combined_lread_m1) as u32 { paired_alns.clear(); return; @@ -1750,4 +1789,55 @@ mod tests { assert!(mate1_is_mapped); } } + + #[test] + fn shuffle_tied_prefix_is_deterministic() { + // Same seed + same input → same permutation on reruns. + let items: Vec<(i32, u32)> = (0..8).map(|i| (100, i)).collect(); + let mut a = items.clone(); + let mut b = items.clone(); + shuffle_tied_prefix(&mut a, |t| t.0, 12345); + shuffle_tied_prefix(&mut b, |t| t.0, 12345); + assert_eq!(a, b); + } + + #[test] + fn shuffle_tied_prefix_respects_ties() { + // Only the top-score prefix gets shuffled; lower-scored tail is left alone. + let mut items = vec![(100, 0u32), (100, 1), (100, 2), (50, 3), (40, 4)]; + shuffle_tied_prefix(&mut items, |t| t.0, 777); + // Last two elements (non-tied) stay in place. + assert_eq!(items[3], (50, 3)); + assert_eq!(items[4], (40, 4)); + // Tied prefix contains the original three items in some order. + let mut top: Vec = items[..3].iter().map(|t| t.1).collect(); + top.sort(); + assert_eq!(top, vec![0, 1, 2]); + } + + #[test] + fn shuffle_tied_prefix_different_seeds_can_diverge() { + // Probabilistic: for a tied set of 8, at least two seeds should disagree + // on the chosen primary. (Exhaustive over a small seed range is fine.) + let base: Vec<(i32, u32)> = (0..8).map(|i| (100, i)).collect(); + let mut firsts = std::collections::HashSet::new(); + for seed in 0..32u64 { + let mut v = base.clone(); + shuffle_tied_prefix(&mut v, |t| t.0, seed); + firsts.insert(v[0].1); + } + assert!( + firsts.len() >= 2, + "expected different seeds to pick different primaries, got {:?}", + firsts + ); + } + + #[test] + fn shuffle_tied_prefix_noop_when_no_ties() { + let mut items = vec![(100, 0u32), (90, 1), (80, 2)]; + let before = items.clone(); + shuffle_tied_prefix(&mut items, |t| t.0, 42); + assert_eq!(items, before); + } } diff --git a/src/align/stitch.rs b/src/align/stitch.rs index e2ceeff..77ac2ec 100644 --- a/src/align/stitch.rs +++ b/src/align/stitch.rs @@ -2407,7 +2407,6 @@ pub(crate) fn stitch_seeds_with_jdb_debug( Ok(transcripts) } - /// Shared core: preprocessing + recursive stitcher, returns working transcripts + context. #[allow(clippy::too_many_arguments)] fn stitch_seeds_core( @@ -2750,8 +2749,6 @@ fn stitch_seeds_core( )) } - - #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index 6c930dd..227b706 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -144,7 +144,10 @@ fn align_reads(params: &Parameters) -> anyhow::Result<()> { let quant_ctx: Option> = if params.quant_gene_counts() { let gtf_path = params.sjdb_gtf_file.as_ref().unwrap(); - info!("quantMode GeneCounts: building gene annotation from {}", gtf_path.display()); + info!( + "quantMode GeneCounts: building gene annotation from {}", + gtf_path.display() + ); let ctx = crate::quant::QuantContext::build(gtf_path, &index.genome)?; Some(std::sync::Arc::new(ctx)) } else { @@ -223,8 +226,22 @@ fn run_single_pass( // Route to single-end or paired-end mode match params.read_files_in.len() { - 1 => align_reads_single_end(params, index, &mut writer, &stats, &sj_stats, quant.as_ref()), - 2 => align_reads_paired_end(params, index, &mut writer, &stats, &sj_stats, quant.as_ref()), + 1 => align_reads_single_end( + params, + index, + &mut writer, + &stats, + &sj_stats, + quant.as_ref(), + ), + 2 => align_reads_paired_end( + params, + index, + &mut writer, + &stats, + &sj_stats, + quant.as_ref(), + ), n => anyhow::bail!("Invalid number of read files: {} (expected 1 or 2)", n), }?; } @@ -241,8 +258,22 @@ fn run_single_pass( // Route to single-end or paired-end mode (same functions as SAM, generic!) match params.read_files_in.len() { - 1 => align_reads_single_end(params, index, &mut writer, &stats, &sj_stats, quant.as_ref()), - 2 => align_reads_paired_end(params, index, &mut writer, &stats, &sj_stats, quant.as_ref()), + 1 => align_reads_single_end( + params, + index, + &mut writer, + &stats, + &sj_stats, + quant.as_ref(), + ), + 2 => align_reads_paired_end( + params, + index, + &mut writer, + &stats, + &sj_stats, + quant.as_ref(), + ), n => anyhow::bail!("Invalid number of read files: {} (expected 1 or 2)", n), }?; @@ -345,8 +376,22 @@ fn run_pass1( // Align reads (single-end or paired-end); no quant counting in pass 1 match params.read_files_in.len() { - 1 => align_reads_single_end(¶ms_pass1, index, &mut null_writer, &stats, &sj_stats, None)?, - 2 => align_reads_paired_end(¶ms_pass1, index, &mut null_writer, &stats, &sj_stats, None)?, + 1 => align_reads_single_end( + ¶ms_pass1, + index, + &mut null_writer, + &stats, + &sj_stats, + None, + )?, + 2 => align_reads_paired_end( + ¶ms_pass1, + index, + &mut null_writer, + &stats, + &sj_stats, + None, + )?, n => anyhow::bail!("Invalid number of read files: {} (expected 1 or 2)", n), } @@ -572,7 +617,8 @@ fn align_reads_single_end( // Gene-level quantification (lock-free atomic counts) if let Some(ref q) = quant { - q.counts.count_se_read(&transcripts, n_for_mapq, &q.gene_ann); + q.counts + .count_se_read(&transcripts, n_for_mapq, &q.gene_ann); } // Record junction statistics diff --git a/src/params.rs b/src/params.rs index f9f9455..10e9d5d 100644 --- a/src/params.rs +++ b/src/params.rs @@ -242,6 +242,10 @@ pub struct Parameters { #[arg(long = "runThreadN", default_value_t = 1)] pub run_thread_n: usize, + /// Random number generator seed for tie-breaking among equal-scoring alignments + #[arg(long = "runRNGseed", default_value_t = 777)] + pub run_rng_seed: u64, + // ── Genome ────────────────────────────────────────────────────────── /// Path to genome index directory #[arg(long = "genomeDir", default_value = "./GenomeDir")] @@ -743,6 +747,7 @@ mod tests { let p = parse(&["--readFilesIn", "reads.fq"]); assert_eq!(p.run_mode, RunMode::AlignReads); assert_eq!(p.run_thread_n, 1); + assert_eq!(p.run_rng_seed, 777); assert_eq!(p.genome_dir, PathBuf::from("./GenomeDir")); assert_eq!(p.genome_sa_index_nbases, 14); assert_eq!(p.genome_chr_bin_nbits, 18); @@ -997,6 +1002,12 @@ mod tests { assert_eq!(p.win_bin_window_dist(), 81_920); // 2^14 * 5 } + #[test] + fn run_rng_seed_override() { + let p = parse(&["--readFilesIn", "r.fq", "--runRNGseed", "42"]); + assert_eq!(p.run_rng_seed, 42); + } + #[test] fn sj_stitch_mismatch() { let p = parse(&[ diff --git a/src/quant/mod.rs b/src/quant/mod.rs index 41c2d7a..f9884f9 100644 --- a/src/quant/mod.rs +++ b/src/quant/mod.rs @@ -234,13 +234,8 @@ impl GeneCounts { } /// Write `ReadsPerGene.out.tab` in STAR's format. - pub fn write_output( - &self, - path: &Path, - gene_ann: &GeneAnnotation, - ) -> Result<(), Error> { - let mut file = std::fs::File::create(path) - .map_err(|e| Error::io(e, path))?; + pub fn write_output(&self, path: &Path, gene_ann: &GeneAnnotation) -> Result<(), Error> { + let mut file = std::fs::File::create(path).map_err(|e| Error::io(e, path))?; macro_rules! wl { ($($arg:tt)*) => { @@ -350,10 +345,7 @@ impl QuantContext { let exons = crate::junction::gtf::parse_gtf(gtf_path)?; let gene_ann = GeneAnnotation::from_gtf_exons(&exons, genome); let n = gene_ann.n_genes(); - log::info!( - "quantMode GeneCounts: {} genes loaded from GTF", - n - ); + log::info!("quantMode GeneCounts: {} genes loaded from GTF", n); let counts = GeneCounts::new(n); Ok(QuantContext { gene_ann, counts }) } @@ -381,7 +373,13 @@ mod tests { } } - fn make_gtf_exon(seqname: &str, start: u64, end: u64, strand: char, gene_id: &str) -> GtfRecord { + fn make_gtf_exon( + seqname: &str, + start: u64, + end: u64, + strand: char, + gene_id: &str, + ) -> GtfRecord { let mut attrs = std::collections::HashMap::new(); attrs.insert("gene_id".to_string(), gene_id.to_string()); attrs.insert("transcript_id".to_string(), "T1".to_string()); @@ -401,7 +399,12 @@ mod tests { genome_start: gs, genome_end: ge, is_reverse, - exons: vec![Exon { genome_start: gs, genome_end: ge, read_start: 0, read_end: (ge - gs) as usize }], + exons: vec![Exon { + genome_start: gs, + genome_end: ge, + read_start: 0, + read_end: (ge - gs) as usize, + }], cigar: vec![], score: 100, n_mismatch: 0, @@ -417,7 +420,7 @@ mod tests { fn test_gene_annotation_basic() { let genome = make_genome(); let exons = vec![ - make_gtf_exon("chr1", 101, 200, '+', "G1"), // 0-based: [100, 200) + make_gtf_exon("chr1", 101, 200, '+', "G1"), // 0-based: [100, 200) make_gtf_exon("chr1", 301, 400, '+', "G1"), make_gtf_exon("chr1", 501, 600, '-', "G2"), ]; @@ -426,7 +429,7 @@ mod tests { assert_eq!(ann.gene_ids[0], "G1"); assert_eq!(ann.gene_ids[1], "G2"); assert!(!ann.gene_is_reverse[0]); // G1 is + - assert!(ann.gene_is_reverse[1]); // G2 is - + assert!(ann.gene_is_reverse[1]); // G2 is - assert_eq!(ann.chr_exons[0].len(), 3); }