diff --git a/src/testing/fault_injection.rs b/src/testing/fault_injection.rs new file mode 100644 index 000000000..f9f8d23bd --- /dev/null +++ b/src/testing/fault_injection.rs @@ -0,0 +1,432 @@ +//! Fault injection framework for testing retry, failover, and circuit breaker behavior. +//! +//! Provides [`FaultInjector`] which can be attached to [`StubLlm`](super::StubLlm) to +//! produce configurable error sequences, random failures, and delays. +//! +//! # Example +//! +//! ```rust,no_run +//! use ironclaw::testing::fault_injection::*; +//! +//! // Fail twice with transient errors, then succeed +//! let injector = FaultInjector::sequence([ +//! FaultAction::Fail(FaultType::RequestFailed), +//! FaultAction::Fail(FaultType::RateLimited { retry_after: None }), +//! FaultAction::Succeed, +//! ]); +//! ``` + +use std::sync::Mutex; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::time::Duration; + +use crate::llm::error::LlmError; + +/// The type of fault to inject. +#[derive(Debug, Clone)] +pub enum FaultType { + /// Transient request failure (retryable). + RequestFailed, + /// Rate limited with optional retry-after duration. + RateLimited { retry_after: Option }, + /// Authentication failure (non-retryable). + AuthFailed, + /// Invalid response from provider (retryable). + InvalidResponse, + /// I/O error (retryable). + IoError, + /// Context length exceeded (non-retryable). + ContextLengthExceeded, + /// Session expired (transient for circuit breaker, not retryable). + SessionExpired, +} + +impl FaultType { + /// Convert to the corresponding `LlmError`. + pub fn to_llm_error(&self, provider: &str) -> LlmError { + match self { + FaultType::RequestFailed => LlmError::RequestFailed { + provider: provider.to_string(), + reason: "injected fault: request failed".to_string(), + }, + FaultType::RateLimited { retry_after } => LlmError::RateLimited { + provider: provider.to_string(), + retry_after: *retry_after, + }, + FaultType::AuthFailed => LlmError::AuthFailed { + provider: provider.to_string(), + }, + FaultType::InvalidResponse => LlmError::InvalidResponse { + provider: provider.to_string(), + reason: "injected fault: invalid response".to_string(), + }, + FaultType::IoError => LlmError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "injected fault: connection reset", + )), + FaultType::ContextLengthExceeded => LlmError::ContextLengthExceeded { + used: 100_000, + limit: 50_000, + }, + FaultType::SessionExpired => LlmError::SessionExpired { + provider: provider.to_string(), + }, + } + } +} + +/// Action to take on a given call. +#[derive(Debug, Clone)] +pub enum FaultAction { + /// Return a successful response. + Succeed, + /// Return an error of the given type. + Fail(FaultType), + /// Sleep for the given duration, then succeed. + Delay(Duration), +} + +/// How the fault sequence is consumed. +#[derive(Debug, Clone)] +pub enum FaultMode { + /// Play the sequence once, then succeed for all subsequent calls. + SequenceOnce, + /// Loop the sequence forever. + SequenceLoop, + /// Fail randomly at the given rate (0.0 = never, 1.0 = always) with + /// the specified fault type. Uses a seeded RNG for reproducibility. + /// The seed is stored so that [`FaultInjector::reset()`] can re-initialize + /// the RNG for test reproducibility. + Random { + error_rate: f64, + fault: FaultType, + seed: u64, + }, +} + +/// A configurable fault injector for [`StubLlm`](super::StubLlm). +/// +/// Thread-safe: uses atomic call counter and mutex-protected RNG. +pub struct FaultInjector { + actions: Vec, + mode: FaultMode, + call_index: AtomicU32, + /// Seeded RNG for Random mode, behind Mutex for Sync. + rng_state: Mutex, +} + +impl std::fmt::Debug for FaultInjector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FaultInjector") + .field("call_index", &self.call_index.load(Ordering::Relaxed)) + .field("mode", &self.mode) + .finish() + } +} + +impl FaultInjector { + /// Create a fault injector that plays actions once, then succeeds. + pub fn sequence(actions: impl IntoIterator) -> Self { + Self { + actions: actions.into_iter().collect(), + mode: FaultMode::SequenceOnce, + call_index: AtomicU32::new(0), + rng_state: Mutex::new(0), + } + } + + /// Create a fault injector that loops the action sequence forever. + pub fn sequence_loop(actions: impl IntoIterator) -> Self { + Self { + actions: actions.into_iter().collect(), + mode: FaultMode::SequenceLoop, + call_index: AtomicU32::new(0), + rng_state: Mutex::new(0), + } + } + + /// Create a fault injector with random failures at the given rate. + /// + /// # Panics + /// + /// Panics if `error_rate` is not in `0.0..=1.0` or is NaN. + /// + /// The seed is guarded against zero, which is a fixed point for xorshift. + pub fn random(error_rate: f64, fault: FaultType, seed: u64) -> Self { + assert!( + !error_rate.is_nan() && (0.0..=1.0).contains(&error_rate), + "error_rate must be in 0.0..=1.0 and not NaN, got {error_rate}" + ); + let seed = if seed == 0 { 1 } else { seed }; + Self { + actions: Vec::new(), + mode: FaultMode::Random { + error_rate, + fault, + seed, + }, + call_index: AtomicU32::new(0), + rng_state: Mutex::new(seed), + } + } + + /// Get the action for the next call. + pub fn next_action(&self) -> FaultAction { + let index = self.call_index.fetch_add(1, Ordering::Relaxed) as usize; + + match &self.mode { + FaultMode::SequenceOnce => { + if index < self.actions.len() { + self.actions[index].clone() + } else { + FaultAction::Succeed + } + } + FaultMode::SequenceLoop => { + if self.actions.is_empty() { + FaultAction::Succeed + } else { + self.actions[index % self.actions.len()].clone() + } + } + FaultMode::Random { + error_rate, fault, .. + } => { + // Simple xorshift64 PRNG for reproducible randomness. + let random_val = { + let mut state = self.rng_state.lock().unwrap_or_else(|p| p.into_inner()); + *state ^= *state << 13; + *state ^= *state >> 7; + *state ^= *state << 17; + (*state as f64) / (u64::MAX as f64) + }; + if random_val <= *error_rate { + FaultAction::Fail(fault.clone()) + } else { + FaultAction::Succeed + } + } + } + } + + /// Get the total number of calls made. + pub fn call_count(&self) -> u32 { + self.call_index.load(Ordering::Relaxed) + } + + /// Reset the injector to its initial state. + /// + /// For `Random` mode, re-initializes the RNG from the stored seed, + /// which is useful for test reproducibility. + /// For all modes, resets the call counter to zero. + pub fn reset(&self) { + self.call_index.store(0, Ordering::Relaxed); + if let FaultMode::Random { seed, .. } = &self.mode { + let mut state = self.rng_state.lock().unwrap_or_else(|p| p.into_inner()); + *state = *seed; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sequence_once_plays_then_succeeds() { + let injector = FaultInjector::sequence([ + FaultAction::Fail(FaultType::RequestFailed), + FaultAction::Fail(FaultType::RateLimited { retry_after: None }), + FaultAction::Succeed, + ]); + + // First two calls should fail + assert!(matches!( + injector.next_action(), + FaultAction::Fail(FaultType::RequestFailed) + )); + assert!(matches!( + injector.next_action(), + FaultAction::Fail(FaultType::RateLimited { .. }) + )); + // Third call is explicit succeed + assert!(matches!(injector.next_action(), FaultAction::Succeed)); + // Beyond sequence: implicit succeed + assert!(matches!(injector.next_action(), FaultAction::Succeed)); + assert!(matches!(injector.next_action(), FaultAction::Succeed)); + assert_eq!(injector.call_count(), 5); + } + + #[test] + fn sequence_loop_repeats() { + let injector = FaultInjector::sequence_loop([ + FaultAction::Fail(FaultType::RequestFailed), + FaultAction::Succeed, + ]); + + assert!(matches!(injector.next_action(), FaultAction::Fail(_))); + assert!(matches!(injector.next_action(), FaultAction::Succeed)); + assert!(matches!(injector.next_action(), FaultAction::Fail(_))); + assert!(matches!(injector.next_action(), FaultAction::Succeed)); + } + + #[test] + fn random_mode_is_deterministic_with_seed() { + let injector1 = FaultInjector::random(0.5, FaultType::RequestFailed, 42); + let injector2 = FaultInjector::random(0.5, FaultType::RequestFailed, 42); + + let results1: Vec = (0..20) + .map(|_| matches!(injector1.next_action(), FaultAction::Fail(_))) + .collect(); + let results2: Vec = (0..20) + .map(|_| matches!(injector2.next_action(), FaultAction::Fail(_))) + .collect(); + + assert_eq!(results1, results2, "Same seed should produce same sequence"); + } + + #[test] + fn fault_type_produces_correct_llm_errors() { + let provider = "test-provider"; + + assert!(matches!( + FaultType::RequestFailed.to_llm_error(provider), + LlmError::RequestFailed { .. } + )); + assert!(matches!( + FaultType::RateLimited { + retry_after: Some(Duration::from_secs(5)) + } + .to_llm_error(provider), + LlmError::RateLimited { .. } + )); + assert!(matches!( + FaultType::AuthFailed.to_llm_error(provider), + LlmError::AuthFailed { .. } + )); + assert!(matches!( + FaultType::InvalidResponse.to_llm_error(provider), + LlmError::InvalidResponse { .. } + )); + assert!(matches!( + FaultType::IoError.to_llm_error(provider), + LlmError::Io(_) + )); + assert!(matches!( + FaultType::ContextLengthExceeded.to_llm_error(provider), + LlmError::ContextLengthExceeded { .. } + )); + assert!(matches!( + FaultType::SessionExpired.to_llm_error(provider), + LlmError::SessionExpired { .. } + )); + } + + #[test] + fn delay_action_exists() { + let injector = FaultInjector::sequence([FaultAction::Delay(Duration::from_millis(100))]); + assert!(matches!(injector.next_action(), FaultAction::Delay(_))); + } + + #[test] + fn random_seed_zero_does_not_always_fail() { + // seed=0 is a fixed point for xorshift; the constructor guards it to 1. + let injector = FaultInjector::random(0.5, FaultType::RequestFailed, 0); + let failures = (0..100) + .filter(|_| matches!(injector.next_action(), FaultAction::Fail(_))) + .count(); + assert!(failures < 100, "seed=0 must not produce stuck RNG"); + } + + #[test] + fn empty_sequence_always_succeeds() { + let injector = FaultInjector::sequence([]); + for _ in 0..10 { + assert!(matches!(injector.next_action(), FaultAction::Succeed)); + } + } + + #[test] + fn reset_restores_random_rng_from_stored_seed() { + let injector = FaultInjector::random(0.5, FaultType::RequestFailed, 42); + let run1: Vec = (0..20) + .map(|_| matches!(injector.next_action(), FaultAction::Fail(_))) + .collect(); + + injector.reset(); + assert_eq!(injector.call_count(), 0); + + let run2: Vec = (0..20) + .map(|_| matches!(injector.next_action(), FaultAction::Fail(_))) + .collect(); + + assert_eq!(run1, run2, "reset() should reproduce the same sequence"); + } + + #[test] + #[should_panic(expected = "error_rate must be in 0.0..=1.0")] + fn random_rejects_error_rate_above_one() { + FaultInjector::random(1.5, FaultType::RequestFailed, 42); + } + + #[test] + #[should_panic(expected = "error_rate must be in 0.0..=1.0")] + fn random_rejects_negative_error_rate() { + FaultInjector::random(-0.1, FaultType::RequestFailed, 42); + } + + #[test] + #[should_panic(expected = "error_rate must be in 0.0..=1.0 and not NaN")] + fn random_rejects_nan_error_rate() { + FaultInjector::random(f64::NAN, FaultType::RequestFailed, 42); + } + + #[test] + fn error_rate_one_always_fails() { + let injector = FaultInjector::random(1.0, FaultType::RequestFailed, 42); + for _ in 0..100 { + assert!( + matches!(injector.next_action(), FaultAction::Fail(_)), + "error_rate=1.0 must always produce failures" + ); + } + } + + #[test] + fn error_rate_zero_never_fails() { + let injector = FaultInjector::random(0.0, FaultType::RequestFailed, 42); + for _ in 0..100 { + assert!( + matches!(injector.next_action(), FaultAction::Succeed), + "error_rate=0.0 must never produce failures" + ); + } + } + + #[tokio::test] + async fn delay_action_pauses_execution() { + tokio::time::pause(); + let injector = FaultInjector::sequence([ + FaultAction::Delay(Duration::from_secs(10)), + FaultAction::Succeed, + ]); + + // First action is a delay + let action = injector.next_action(); + assert!(matches!(action, FaultAction::Delay(d) if d == Duration::from_secs(10))); + + // Simulate what StubLlm does: sleep then succeed + if let FaultAction::Delay(d) = action { + let start = tokio::time::Instant::now(); + tokio::time::sleep(d).await; + let elapsed = start.elapsed(); + assert!( + elapsed >= Duration::from_secs(10), + "delay should have paused for at least 10s, got {elapsed:?}" + ); + } + + // Next action succeeds + assert!(matches!(injector.next_action(), FaultAction::Succeed)); + } +} diff --git a/src/testing/mod.rs b/src/testing/mod.rs index ff522e3ad..ba260eae3 100644 --- a/src/testing/mod.rs +++ b/src/testing/mod.rs @@ -19,9 +19,11 @@ //! ``` pub mod credentials; +pub mod fault_injection; use std::sync::Arc; use std::sync::Mutex; + use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use async_trait::async_trait; @@ -84,6 +86,9 @@ pub struct StubLlm { call_count: AtomicU32, should_fail: AtomicBool, error_kind: StubErrorKind, + /// Optional fault injector for fine-grained failure control. + /// When set, takes precedence over the `should_fail` / `error_kind` fields. + fault_injector: Option>, } impl StubLlm { @@ -95,6 +100,7 @@ impl StubLlm { call_count: AtomicU32::new(0), should_fail: AtomicBool::new(false), error_kind: StubErrorKind::Transient, + fault_injector: None, } } @@ -106,6 +112,7 @@ impl StubLlm { call_count: AtomicU32::new(0), should_fail: AtomicBool::new(true), error_kind: StubErrorKind::Transient, + fault_injector: None, } } @@ -117,6 +124,7 @@ impl StubLlm { call_count: AtomicU32::new(0), should_fail: AtomicBool::new(true), error_kind: StubErrorKind::NonTransient, + fault_injector: None, } } @@ -131,11 +139,39 @@ impl StubLlm { self.call_count.load(Ordering::Relaxed) } + /// Attach a fault injector for fine-grained failure control. + /// + /// When set, the injector's `next_action()` is consulted on every call, + /// taking precedence over the `should_fail` / `error_kind` fields. + pub fn with_fault_injector(mut self, injector: Arc) -> Self { + self.fault_injector = Some(injector); + self + } + /// Toggle whether calls should fail at runtime. pub fn set_failing(&self, fail: bool) { self.should_fail.store(fail, Ordering::Relaxed); } + /// Check the fault injector or should_fail flag, returning an error if + /// the call should fail, or None if it should succeed. + async fn check_faults(&self) -> Option { + if let Some(ref injector) = self.fault_injector { + match injector.next_action() { + fault_injection::FaultAction::Fail(fault) => { + return Some(fault.to_llm_error(&self.model_name)); + } + fault_injection::FaultAction::Delay(duration) => { + tokio::time::sleep(duration).await; + } + fault_injection::FaultAction::Succeed => {} + } + } else if self.should_fail.load(Ordering::Relaxed) { + return Some(self.make_error()); + } + None + } + fn make_error(&self) -> LlmError { match self.error_kind { StubErrorKind::Transient => LlmError::RequestFailed { @@ -168,8 +204,8 @@ impl LlmProvider for StubLlm { async fn complete(&self, _request: CompletionRequest) -> Result { self.call_count.fetch_add(1, Ordering::Relaxed); - if self.should_fail.load(Ordering::Relaxed) { - return Err(self.make_error()); + if let Some(err) = self.check_faults().await { + return Err(err); } Ok(CompletionResponse { content: self.response.clone(), @@ -186,8 +222,8 @@ impl LlmProvider for StubLlm { _request: ToolCompletionRequest, ) -> Result { self.call_count.fetch_add(1, Ordering::Relaxed); - if self.should_fail.load(Ordering::Relaxed) { - return Err(self.make_error()); + if let Some(err) = self.check_faults().await { + return Err(err); } Ok(ToolCompletionResponse { content: Some(self.response.clone()), @@ -1508,4 +1544,29 @@ mod tests { .await .expect("update actuals"); } + + #[tokio::test] + async fn stub_llm_fault_injector_sequence() { + use crate::llm::LlmProvider; + use crate::testing::fault_injection::{FaultAction, FaultInjector, FaultType}; + + let injector = Arc::new(FaultInjector::sequence([ + FaultAction::Fail(FaultType::RateLimited { retry_after: None }), + FaultAction::Succeed, + ])); + + let stub = StubLlm::new("hello").with_fault_injector(injector); + + let req = crate::llm::CompletionRequest::new(vec![crate::llm::ChatMessage::user("hi")]); + + // First call should fail with RateLimited + let result = stub.complete(req.clone()).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), LlmError::RateLimited { .. })); + + // Second call should succeed + let result = stub.complete(req).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().content, "hello"); + } }