Skip to content

Commit b259737

Browse files
authored
feat: add batch size guardrails to prevent GPU memory exhaustion (#135)
* feat: add batch size guardrails to prevent GPU memory exhaustion Implements two-stage batch size validation: 1. Hard maximum (512): Validates at CLI parse time to reject absurdly large values immediately, preventing GPU memory exhaustion and timeouts. 2. Per-file adjustment: At runtime, automatically adjusts batch size down to the estimated segment count for short audio files, preventing unnecessary memory allocation and padding. Changes: - Add MAX_BATCH_SIZE constant (512) in src/constants.rs - Update parse_batch_size() validator to enforce maximum - Add per-file batch size adjustment in process_file() - Adjustment happens before create_batch_context() to save GPU memory - Log adjustments at DEBUG level to avoid spam with large file sets - Add comprehensive tests for new validation behavior Fixes issue where users could specify batch sizes like 2560 that caused GPU hangs and process termination. * fix: address code review findings Fixes identified by code review (Claude + Gemini): 1. Critical: Prevent effective_batch_size from becoming 0 - Empty or corrupt files with duration_hint=0 could set batch_size to 0 - This would trigger process_batch for every chunk (batch size 1) - Now keeps original batch_size when estimated_segments is 0 2. Low: Add trim() for whitespace handling in parse_batch_size - Config files/env vars may include leading/trailing whitespace - Now accepts inputs like " 32 " or " 64 " - Added test case to verify whitespace handling All 230 tests pass. * fix: address Claude code review feedback Fixed items 1-3 from Claude's review: 1. Channel capacity now uses effective_batch_size instead of batch_size - Ensures memory optimization is consistent for short files - Channel buffer size now matches adjusted batch allocation 2. Added comment clarifying cast_possible_truncation scope - Documents the truncation happens in closure (u64 -> usize) - Notes it's safe in practice (would need 408 years of audio) 3. 
Tests now use MAX_BATCH_SIZE constant in assertions - If MAX_BATCH_SIZE changes, tests will catch mismatches - More robust than hardcoded "512" strings Items 4-6 deferred: - Item 4: .trim() consistency across validators (needs architectural decision) - Item 5: Extract adjustment logic (quality improvement, not blocking) - Item 6: Pre-existing code, not introduced by this PR
1 parent e9b1697 commit b259737

File tree

3 files changed

+115
-35
lines changed

3 files changed

+115
-35
lines changed

src/cli/validators.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
//!
33
//! Shared validation functions for CLI argument parsing.
44
5+
use crate::constants::MAX_BATCH_SIZE;
6+
57
/// Parse and validate confidence value (0.0-1.0).
68
pub fn parse_confidence(s: &str) -> Result<f32, String> {
79
let value: f32 = s
@@ -49,16 +51,25 @@ pub fn parse_longitude(s: &str) -> Result<f64, String> {
4951
parse_bounded_float(s, -180.0, 180.0, "longitude")
5052
}
5153

52-
/// Parse and validate batch size (must be at least 1).
54+
/// Parse and validate batch size (must be between 1 and `MAX_BATCH_SIZE`).
5355
pub fn parse_batch_size(s: &str) -> Result<usize, String> {
5456
let value: usize = s
57+
.trim()
5558
.parse()
5659
.map_err(|_| format!("'{s}' is not a valid number"))?;
5760

5861
if value < 1 {
5962
return Err(format!("batch_size must be at least 1, got {value}"));
6063
}
6164

65+
if value > MAX_BATCH_SIZE {
66+
return Err(format!(
67+
"batch_size must be between 1 and {MAX_BATCH_SIZE}, got {value}\n\n\
68+
This limit prevents GPU memory exhaustion.\n\
69+
If processing fails with batch_size={MAX_BATCH_SIZE}, try reducing it further or use --cpu."
70+
));
71+
}
72+
6273
Ok(value)
6374
}
6475

@@ -124,4 +135,40 @@ mod tests {
124135
assert!(parse_batch_size("-1").is_err());
125136
assert!(parse_batch_size("abc").is_err());
126137
}
138+
139+
#[test]
fn test_parse_batch_size_at_maximum() {
    // The maximum value itself must be accepted.
    assert_eq!(parse_batch_size("512").ok(), Some(MAX_BATCH_SIZE));
}

#[test]
fn test_parse_batch_size_above_maximum() {
    // One past the cap is rejected with the range message and the rationale.
    let err = parse_batch_size("513").expect_err("513 must exceed the maximum");
    assert!(err.contains(&format!(
        "batch_size must be between 1 and {MAX_BATCH_SIZE}"
    )));
    assert!(err.contains("GPU memory exhaustion"));
}

#[test]
fn test_parse_batch_size_way_above_maximum() {
    // A value far beyond the cap (the reported real-world failure) is rejected.
    let err = parse_batch_size("2560").expect_err("2560 must exceed the maximum");
    assert!(err.contains(&format!(
        "batch_size must be between 1 and {MAX_BATCH_SIZE}"
    )));
    assert!(err.contains("GPU memory exhaustion"));
}

#[test]
fn test_parse_batch_size_with_whitespace() {
    // Leading/trailing whitespace (common in config files) is tolerated.
    for (input, expected) in [(" 32", 32), ("32 ", 32), (" 32 ", 32), (" 64 ", 64)] {
        assert_eq!(parse_batch_size(input).ok(), Some(expected));
    }
}
127174
}

src/constants.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ pub const DEFAULT_OVERLAP: f32 = 0.0;
2020
/// See `determine_default_batch_size()` in `lib.rs` for dynamic batch size selection.
2121
pub const DEFAULT_BATCH_SIZE: usize = 8;
2222

23+
/// Hard upper bound on the user-supplied batch size.
///
/// Rejecting oversized batch sizes at CLI parse time guards against GPU
/// memory exhaustion and system hangs caused by absurdly large values. The
/// cap is conservative enough to work on most consumer GPUs while still
/// allowing efficient processing of large files.
///
/// Independently of this cap, a batch size larger than a file's estimated
/// segment count is reduced at runtime to avoid unnecessary memory
/// allocation and padding.
pub const MAX_BATCH_SIZE: usize = 512;
33+
2334
/// Batch size defaults by execution provider and model type.
2435
pub mod batch_size {
2536
/// CPU batch size for all models.

src/pipeline/processor.rs

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -435,14 +435,63 @@ pub fn process_file(
435435
None
436436
};
437437

438-
// Create batch context for GPU memory efficiency (if batch_size > 1)
438+
// Calculate segment parameters (needed for batch size adjustment and progress bar)
439+
#[allow(
440+
clippy::cast_possible_truncation,
441+
clippy::cast_sign_loss,
442+
clippy::cast_precision_loss
443+
)]
444+
let segment_samples = (segment_duration * target_rate as f32) as usize;
445+
#[allow(
446+
clippy::cast_possible_truncation,
447+
clippy::cast_sign_loss,
448+
clippy::cast_precision_loss
449+
)]
450+
let overlap_samples = (overlap * target_rate as f32) as usize;
451+
452+
// Estimate segment count for batch size adjustment and progress bar
453+
let estimated_segments = estimate_segment_count(duration_hint, segment_duration, overlap);
454+
455+
// Adjust batch size if it exceeds the estimated segment count
456+
// This prevents unnecessary memory allocation and padding for short files
457+
// Cast is safe in practice: would need ~408 years of audio to overflow on 32-bit
458+
#[allow(clippy::cast_possible_truncation)]
459+
let effective_batch_size = estimated_segments.map_or(batch_size, |est_segments| {
460+
let est_segments_usize = est_segments as usize;
461+
// Handle empty or corrupt files - never set batch size to 0
462+
if est_segments_usize == 0 {
463+
batch_size
464+
} else if batch_size > est_segments_usize {
465+
debug!(
466+
"Batch size {} exceeds segment count ({} segments), using {} for this file",
467+
batch_size, est_segments_usize, est_segments_usize
468+
);
469+
est_segments_usize
470+
} else {
471+
batch_size
472+
}
473+
});
474+
475+
// Log audio info
476+
if let Some(duration) = duration_hint {
477+
info!(
478+
"Processing ~{} of audio ({:.1}s)",
479+
progress::format_duration(duration),
480+
duration
481+
);
482+
} else {
483+
info!("Processing audio (duration unknown)");
484+
}
485+
486+
// Create batch context for GPU memory efficiency (if effective_batch_size > 1)
439487
// Context is created once and reused for all batches in this file
440-
let mut batch_context = if batch_size > 1 {
441-
match classifier.create_batch_context(batch_size) {
488+
// IMPORTANT: This uses effective_batch_size to avoid over-allocating memory
489+
let mut batch_context = if effective_batch_size > 1 {
490+
match classifier.create_batch_context(effective_batch_size) {
442491
Ok(ctx) => {
443492
debug!(
444493
"Created BatchInferenceContext for up to {} segments ({} bytes input buffer)",
445-
batch_size,
494+
effective_batch_size,
446495
ctx.input_buffer_bytes()
447496
);
448497
Some(ctx)
@@ -460,34 +509,6 @@ pub fn process_file(
460509
None
461510
};
462511

463-
// Log audio info
464-
if let Some(duration) = duration_hint {
465-
info!(
466-
"Processing ~{} of audio ({:.1}s)",
467-
progress::format_duration(duration),
468-
duration
469-
);
470-
} else {
471-
info!("Processing audio (duration unknown)");
472-
}
473-
474-
// Calculate segment parameters
475-
#[allow(
476-
clippy::cast_possible_truncation,
477-
clippy::cast_sign_loss,
478-
clippy::cast_precision_loss
479-
)]
480-
let segment_samples = (segment_duration * target_rate as f32) as usize;
481-
#[allow(
482-
clippy::cast_possible_truncation,
483-
clippy::cast_sign_loss,
484-
clippy::cast_precision_loss
485-
)]
486-
let overlap_samples = (overlap * target_rate as f32) as usize;
487-
488-
// Estimate segment count for progress bar
489-
let estimated_segments = estimate_segment_count(duration_hint, segment_duration, overlap);
490-
491512
// Create progress bar
492513
let file_name = input_path
493514
.file_name()
@@ -522,7 +543,8 @@ pub fn process_file(
522543
let progress_guard = progress::ProgressGuard::new(segment_progress, "Inference complete");
523544

524545
// Create channel with capacity for 2 batches (backpressure)
525-
let channel_capacity = batch_size.saturating_mul(2).max(4);
546+
// Use effective_batch_size to match adjusted memory allocation
547+
let channel_capacity = effective_batch_size.saturating_mul(2).max(4);
526548
let (tx, rx) = sync_channel::<ChunkResult>(channel_capacity);
527549

528550
// Spawn decode thread
@@ -546,7 +568,7 @@ pub fn process_file(
546568
classifier,
547569
input_path,
548570
min_confidence,
549-
batch_size,
571+
effective_batch_size,
550572
progress_guard.get(),
551573
&mut batch_context,
552574
reporter,

0 commit comments

Comments
 (0)