Skip to content

Commit fcda494

Browse files
committed
feat: add --batch-size CLI flag for manual genome-per-batch override (.verify-implement-disk-budget)
Add --batch-size <N> flag that partitions genomes into batches of at most N genomes each (the final batch may hold fewer), overriding --batch-bytes and --max-disk. Adds partition_into_batches_by_count() and run_batch_alignment_by_count() to batch_align.rs. Includes a unit test for count-based partitioning. This was a missing requirement from the implement-disk-budget task (FLIP verification found it absent).
1 parent fab957a commit fcda494

File tree

2 files changed

+216
-6
lines changed

2 files changed

+216
-6
lines changed

src/batch_align.rs

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,21 @@ pub fn partition_into_batches_by_bp(genomes: Vec<GenomeInfo>, max_bp: u64) -> Ve
419419
batches
420420
}
421421

422+
/// Partition genomes into batches of at most `max_count` genomes each.
423+
/// This is the `--batch-size` code path: the user directly controls how many
424+
/// genomes go into each batch.
425+
pub fn partition_into_batches_by_count(genomes: Vec<GenomeInfo>, max_count: usize) -> Vec<GenomeBatch> {
426+
let mut batches: Vec<GenomeBatch> = Vec::new();
427+
for chunk in genomes.chunks(max_count) {
428+
let mut batch = GenomeBatch::new();
429+
for g in chunk {
430+
batch.add(g.clone());
431+
}
432+
batches.push(batch);
433+
}
434+
batches
435+
}
436+
422437
/// Write a subset of genomes to a temporary FASTA file
423438
pub fn write_batch_fasta(batch: &GenomeBatch, output_path: &Path) -> Result<()> {
424439
let mut output = File::create(output_path)
@@ -1085,6 +1100,133 @@ pub fn run_batch_alignment_generic(
10851100
Ok(merged_paf)
10861101
}
10871102

1103+
/// Run batch alignment partitioned by genome count (`--batch-size`).
1104+
///
1105+
/// Each batch contains at most `max_count` genomes. The alignment loop
1106+
/// follows the same prepare/align/cleanup protocol as the byte-based path.
1107+
pub fn run_batch_alignment_by_count(
1108+
fasta_files: &[String],
1109+
max_count: usize,
1110+
aligner: &dyn BatchAligner,
1111+
config: &BatchAlignConfig,
1112+
tempdir: Option<&str>,
1113+
) -> Result<tempfile::NamedTempFile> {
1114+
let temp_base = if let Some(dir) = tempdir {
1115+
PathBuf::from(dir)
1116+
} else if let Ok(tmpdir) = std::env::var("TMPDIR") {
1117+
PathBuf::from(tmpdir)
1118+
} else {
1119+
PathBuf::from("/tmp")
1120+
};
1121+
1122+
let batch_dir = temp_base.join(format!("sweepga_batch_{}", std::process::id()));
1123+
std::fs::create_dir_all(&batch_dir)?;
1124+
1125+
if !config.quiet {
1126+
eprintln!("[batch] Scanning input files for genome sizes...");
1127+
}
1128+
let genomes = parse_genome_sizes(fasta_files)?;
1129+
let total_bp: u64 = genomes.iter().map(|g| g.total_bp).sum();
1130+
1131+
if genomes.is_empty() {
1132+
anyhow::bail!("No genomes found in input files");
1133+
}
1134+
1135+
let batches = partition_into_batches_by_count(genomes, max_count);
1136+
1137+
if batches.len() == 1 {
1138+
if !config.quiet {
1139+
eprintln!(
1140+
"[batch] All genomes ({}) fit in single batch, no batching needed",
1141+
format_bytes(total_bp),
1142+
);
1143+
}
1144+
let _ = std::fs::remove_dir_all(&batch_dir);
1145+
return aligner.align_single(fasta_files, tempdir);
1146+
}
1147+
1148+
report_batch_plan(&batches, total_bp, config.quiet);
1149+
1150+
let mut batch_files: Vec<PathBuf> = Vec::new();
1151+
for (i, batch) in batches.iter().enumerate() {
1152+
let batch_subdir = batch_dir.join(format!("batch_{}", i));
1153+
std::fs::create_dir_all(&batch_subdir)?;
1154+
let batch_path = batch_subdir.join("genomes.fa");
1155+
if !config.quiet {
1156+
eprintln!(
1157+
"[batch] Writing batch {} ({} genomes)...",
1158+
i + 1,
1159+
batch.genomes.len()
1160+
);
1161+
}
1162+
write_batch_fasta(batch, &batch_path)?;
1163+
batch_files.push(batch_path);
1164+
}
1165+
1166+
aligner.prepare_all(&batch_files, config.quiet)?;
1167+
1168+
let merged_paf = tempfile::NamedTempFile::with_suffix(".paf")?;
1169+
let mut merged_output = File::create(merged_paf.path())?;
1170+
1171+
let mut total_alignments = 0;
1172+
let num_batches = batch_files.len();
1173+
1174+
for i in 0..num_batches {
1175+
aligner.prepare_target(i, config.quiet)?;
1176+
1177+
for j in 0..num_batches {
1178+
if !config.quiet {
1179+
if i == j {
1180+
eprintln!("[batch] Aligning batch {} to itself...", i + 1);
1181+
} else {
1182+
eprintln!("[batch] Aligning batch {} vs batch {}...", j + 1, i + 1);
1183+
}
1184+
}
1185+
1186+
let paf_bytes = aligner.align(&batch_files[j], &batch_files[i])?;
1187+
let line_count = paf_bytes.iter().filter(|&&b| b == b'\n').count();
1188+
total_alignments += line_count;
1189+
merged_output.write_all(&paf_bytes)?;
1190+
1191+
if !config.quiet {
1192+
eprintln!("[batch] {} alignments", line_count);
1193+
}
1194+
}
1195+
1196+
aligner.cleanup_target(i, config.quiet)?;
1197+
}
1198+
1199+
aligner.cleanup_all()?;
1200+
1201+
if !config.quiet {
1202+
eprintln!(
1203+
"[batch] Completed batch alignment: {} total alignments",
1204+
total_alignments
1205+
);
1206+
}
1207+
1208+
drop(merged_output);
1209+
let genome_prefixes: Vec<String> = batches
1210+
.iter()
1211+
.flat_map(|b| b.genomes.iter().map(|g| g.prefix.clone()))
1212+
.collect();
1213+
let verification = verify_batch_completeness(
1214+
merged_paf.path(),
1215+
&genome_prefixes,
1216+
!config.keep_self,
1217+
)?;
1218+
if !config.quiet {
1219+
eprintln!(
1220+
"[batch] Verification: {}/{} genome pairs present",
1221+
verification.found_pairs, verification.expected_pairs
1222+
);
1223+
}
1224+
1225+
let _ = std::fs::remove_dir_all(&batch_dir);
1226+
1227+
Ok(merged_paf)
1228+
}
1229+
10881230
/// Result of batch completeness verification.
10891231
#[derive(Debug, Clone)]
10901232
pub struct BatchVerification {
@@ -1294,6 +1436,36 @@ mod tests {
12941436
assert_eq!(batches[0].genomes.len(), 3);
12951437
}
12961438

1439+
#[test]
fn test_partition_by_count() {
    // Five equally sized genomes (A..E), all attributed to the same source file.
    let genomes: Vec<GenomeInfo> = ["A#1#", "B#1#", "C#1#", "D#1#", "E#1#"]
        .iter()
        .map(|prefix| GenomeInfo {
            prefix: prefix.to_string(),
            total_bp: 100_000_000,
            source_file: PathBuf::from("a.fa"),
        })
        .collect();

    // Partition a fresh copy of the fixture and report the size of each batch.
    let batch_sizes = |max_count: usize| -> Vec<usize> {
        partition_into_batches_by_count(genomes.clone(), max_count)
            .iter()
            .map(|batch| batch.genomes.len())
            .collect()
    };

    // 2 genomes per batch -> 3 batches (AB, CD, E)
    assert_eq!(batch_sizes(2), vec![2, 2, 1]);

    // 5 genomes per batch -> 1 batch
    assert_eq!(batch_sizes(5), vec![5]);

    // 1 genome per batch -> 5 batches
    assert_eq!(batch_sizes(1), vec![1; 5]);
}
1468+
12971469
#[test]
12981470
fn test_batch_size_from_budget() {
12991471
// 8 yeast genomes, ~12M bp each, 8 threads, no zstd

src/main.rs

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ struct Args {
287287
#[clap(long = "batch-bytes", value_parser = parse_metric_number, help_heading = "Alignment options")]
288288
batch_bytes: Option<u64>,
289289

290+
/// Explicit number of genomes per batch (manual override). Overrides --batch-bytes and --max-disk.
291+
#[clap(long = "batch-size", help_heading = "Alignment options")]
292+
batch_size: Option<usize>,
293+
290294
/// Maximum disk space for temporary files during alignment (e.g., "100G", "500M").
291295
/// Computes batch size automatically to stay within budget. Overrides --batch-bytes.
292296
#[clap(long = "max-disk", value_parser = parse_metric_number, help_heading = "Alignment options")]
@@ -1209,6 +1213,7 @@ fn align_multiple_fastas(
12091213
map_pct_identity: Option<String>,
12101214
all_pairs: bool,
12111215
batch_bytes: Option<u64>,
1216+
batch_size: Option<usize>,
12121217
max_disk: Option<u64>,
12131218
threads: usize,
12141219
keep_self: bool,
@@ -1234,13 +1239,22 @@ fn align_multiple_fastas(
12341239
}
12351240
}
12361241

1237-
// Resolve effective batch bytes from --max-disk / --batch-bytes
1238-
let effective_batch_bytes = batch_align::resolve_batch_bytes(
1239-
max_disk, batch_bytes, fasta_files, threads, zstd_compress, quiet,
1240-
)?;
1242+
// --batch-size overrides --batch-bytes and --max-disk
1243+
if batch_size.is_some() && (batch_bytes.is_some() || max_disk.is_some()) && !quiet {
1244+
eprintln!("[batch] WARNING: --batch-size overrides --batch-bytes and --max-disk");
1245+
}
1246+
1247+
// Resolve effective batch bytes from --max-disk / --batch-bytes (used when --batch-size is not set)
1248+
let effective_batch_bytes = if batch_size.is_none() {
1249+
batch_align::resolve_batch_bytes(
1250+
max_disk, batch_bytes, fasta_files, threads, zstd_compress, quiet,
1251+
)?
1252+
} else {
1253+
None
1254+
};
12411255

1242-
// Check for batch mode
1243-
if let Some(max_index_bytes) = effective_batch_bytes {
1256+
// Check for batch mode (either --batch-size or byte-based batching)
1257+
if batch_size.is_some() || effective_batch_bytes.is_some() {
12441258
let batch_config = batch_align::BatchAlignConfig {
12451259
frequency,
12461260
threads,
@@ -1268,6 +1282,27 @@ fn align_multiple_fastas(
12681282
)),
12691283
};
12701284

1285+
// --batch-size: partition by genome count
1286+
if let Some(size) = batch_size {
1287+
if !quiet {
1288+
timing.log(
1289+
"batch",
1290+
&format!(
1291+
"Batch mode enabled: {} genomes per batch ({})",
1292+
size, aligner_name,
1293+
),
1294+
);
1295+
}
1296+
return batch_align::run_batch_alignment_by_count(
1297+
fasta_files,
1298+
size,
1299+
aligner.as_ref(),
1300+
&batch_config,
1301+
tempdir,
1302+
);
1303+
}
1304+
1305+
let max_index_bytes = effective_batch_bytes.unwrap();
12711306
if !quiet {
12721307
timing.log(
12731308
"batch",
@@ -2997,6 +3032,7 @@ fn main() -> Result<()> {
29973032
args.map_pct_identity.clone(),
29983033
args.all_pairs,
29993034
args.batch_bytes,
3035+
args.batch_size,
30003036
args.max_disk,
30013037
args.threads,
30023038
args.keep_self,
@@ -3089,6 +3125,7 @@ fn main() -> Result<()> {
30893125
args.map_pct_identity.clone(),
30903126
args.all_pairs,
30913127
args.batch_bytes,
3128+
args.batch_size,
30923129
args.max_disk,
30933130
args.threads,
30943131
args.keep_self,
@@ -3179,6 +3216,7 @@ fn main() -> Result<()> {
31793216
args.map_pct_identity.clone(),
31803217
args.all_pairs,
31813218
args.batch_bytes,
3219+
args.batch_size,
31823220
args.max_disk,
31833221
args.threads,
31843222
args.keep_self,

0 commit comments

Comments
 (0)