Skip to content

Commit cf5bc17

Browse files
committed
refactor: replace unsafe global environment variable manipulation with thread-local mocks and pure sanitization functions for improved test safety.
1 parent 179db32 commit cf5bc17

5 files changed

Lines changed: 64 additions & 84 deletions

File tree

src/guard/env.rs

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ pub const DENYLIST: &[&str] = &[
2828
];
2929

3030
pub fn sanitize_env() -> Vec<(String, String)> {
31-
env::vars()
31+
sanitize_vars(env::vars())
32+
}
33+
34+
pub fn sanitize_vars(vars: impl IntoIterator<Item = (String, String)>) -> Vec<(String, String)> {
35+
vars.into_iter()
3236
.filter(|(k, _)| !DENYLIST.iter().any(|d| d.eq_ignore_ascii_case(k)))
3337
.collect()
3438
}
@@ -44,46 +48,37 @@ mod tests {
4448

4549
#[test]
4650
fn test_sanitize_env_menghapus_ld_preload() {
47-
unsafe {
48-
std::env::set_var("LD_PRELOAD", "bad.so");
49-
}
50-
let sanitized = sanitize_env();
51+
let mock_env = vec![
52+
("LD_PRELOAD".to_string(), "bad.so".to_string()),
53+
("NORMAL_VAR".to_string(), "123".to_string()),
54+
];
55+
let sanitized = sanitize_vars(mock_env);
5156
let contains = sanitized.iter().any(|(k, _)| k == "LD_PRELOAD");
5257
assert!(!contains);
53-
unsafe {
54-
std::env::remove_var("LD_PRELOAD");
55-
}
5658
}
5759

5860
#[test]
5961
fn test_sanitize_env_menghapus_semua_denylist_entries() {
60-
for key in DENYLIST {
61-
unsafe {
62-
std::env::set_var(key, "malicious_payload");
63-
}
64-
}
62+
let mock_env: Vec<(String, String)> = DENYLIST
63+
.iter()
64+
.map(|key| (key.to_string(), "malicious_payload".to_string()))
65+
.collect();
6566

66-
let sanitized = sanitize_env();
67+
let sanitized = sanitize_vars(mock_env);
6768

6869
for (k, _) in sanitized {
6970
assert!(!DENYLIST.iter().any(|d| d.eq_ignore_ascii_case(&k)));
7071
}
71-
72-
for key in DENYLIST {
73-
unsafe {
74-
std::env::remove_var(key);
75-
}
76-
}
7772
}
7873

7974
#[test]
8075
fn test_sanitize_env_mempertahankan_path_and_normal_vars() {
81-
unsafe {
82-
std::env::set_var("PATH", "/usr/bin:/bin");
83-
std::env::set_var("NORMAL_VAR", "123");
84-
}
76+
let mock_env = vec![
77+
("PATH".to_string(), "/usr/bin:/bin".to_string()),
78+
("NORMAL_VAR".to_string(), "123".to_string()),
79+
];
8580

86-
let sanitized = sanitize_env();
81+
let sanitized = sanitize_vars(mock_env);
8782
let has_path = sanitized.iter().any(|(k, _)| k.to_uppercase() == "PATH");
8883
let has_normal = sanitized
8984
.iter()

src/hooks/dispatcher.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,6 @@ mod tests {
146146
"workingDirectory": "/tmp"
147147
});
148148

149-
unsafe {
150-
std::env::set_var("OMNI_CONTINUE", "1");
151-
std::env::set_var("OMNI_FRESH", "0");
152-
}
153149
let out = process_payload(&input.to_string(), store, session);
154150

155151
assert!(out.is_some());

src/hooks/session_start.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,9 @@ mod tests {
192192
let db_path = dir.path().join("omni.db");
193193
// Set transcript dir to clean temp dir so find_pending() doesn't interfere
194194
let transcript_dir = dir.path().join("transcripts");
195-
unsafe {
196-
std::env::set_var("OMNI_TRANSCRIPT_DIR", transcript_dir.to_str().unwrap());
197-
}
195+
crate::store::transcript::MOCK_TRANSCRIPT_DIR.with(|d| {
196+
*d.borrow_mut() = Some(transcript_dir);
197+
});
198198
(
199199
Arc::new(Store::open_path(&db_path).expect("must succeed")),
200200
dir,

src/store/transcript.rs

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,17 @@ pub fn cleanup_old(days: u32) {
337337

338338
// ─── Helpers ────────────────────────────────────────────
339339

340+
#[cfg(test)]
341+
thread_local! {
342+
pub static MOCK_TRANSCRIPT_DIR: std::cell::RefCell<Option<PathBuf>> = const { std::cell::RefCell::new(None) };
343+
}
344+
340345
fn transcripts_dir() -> PathBuf {
346+
#[cfg(test)]
347+
if let Some(mock) = MOCK_TRANSCRIPT_DIR.with(|d| d.borrow().clone()) {
348+
return mock;
349+
}
350+
341351
if let Ok(custom) = std::env::var("OMNI_TRANSCRIPT_DIR") {
342352
return PathBuf::from(custom);
343353
}
@@ -368,19 +378,12 @@ fn truncate_payload(payload: &str, max_bytes: usize) -> String {
368378
#[cfg(test)]
369379
mod tests {
370380
use super::*;
371-
use std::sync::Mutex;
372381
use tempfile::tempdir;
373382

374-
// Serialize all transcript tests — they share the OMNI_TRANSCRIPT_DIR env var
375-
static TEST_LOCK: Mutex<()> = Mutex::new(());
376-
377-
fn setup_test_dir() -> (tempfile::TempDir, std::sync::MutexGuard<'static, ()>) {
378-
let guard = TEST_LOCK.lock().unwrap();
383+
fn setup_test_dir() -> tempfile::TempDir {
379384
let dir = tempdir().unwrap();
380-
unsafe {
381-
std::env::set_var("OMNI_TRANSCRIPT_DIR", dir.path().to_str().unwrap());
382-
}
383-
(dir, guard)
385+
MOCK_TRANSCRIPT_DIR.with(|d| *d.borrow_mut() = Some(dir.path().to_path_buf()));
386+
dir
384387
}
385388

386389
#[test]
@@ -414,7 +417,7 @@ mod tests {
414417

415418
#[test]
416419
fn test_save_and_load_roundtrip() {
417-
let (_dir, _lock) = setup_test_dir();
420+
let _dir = setup_test_dir();
418421

419422
let mut t = Transcript::new("roundtrip_1", "/project");
420423
let entry = TranscriptEntry::new_input("git status", Some("git"));
@@ -428,7 +431,7 @@ mod tests {
428431

429432
#[test]
430433
fn test_atomic_write_no_corrupt_file() {
431-
let (_dir, _lock) = setup_test_dir();
434+
let _dir = setup_test_dir();
432435

433436
let t = Transcript::new("atomic_1", "/project");
434437
t.save().unwrap();
@@ -444,7 +447,7 @@ mod tests {
444447

445448
#[test]
446449
fn test_mark_last_completed() {
447-
let (_dir, _lock) = setup_test_dir();
450+
let _dir = setup_test_dir();
448451

449452
let mut t = Transcript::new("complete_1", "/project");
450453
t.append_entry(TranscriptEntry::new_input("input1", None))
@@ -461,7 +464,7 @@ mod tests {
461464

462465
#[test]
463466
fn test_mark_last_failed() {
464-
let (_dir, _lock) = setup_test_dir();
467+
let _dir = setup_test_dir();
465468

466469
let mut t = Transcript::new("fail_1", "/project");
467470
t.append_entry(TranscriptEntry::new_input("bad input", None))
@@ -497,13 +500,13 @@ mod tests {
497500

498501
#[test]
499502
fn test_find_pending_with_no_transcripts() {
500-
let (_dir, _lock) = setup_test_dir();
503+
let _dir = setup_test_dir();
501504
assert!(find_pending().is_none());
502505
}
503506

504507
#[test]
505508
fn test_find_pending_finds_interrupted_session() {
506-
let (_dir, _lock) = setup_test_dir();
509+
let _dir = setup_test_dir();
507510

508511
let mut t = Transcript::new("interrupted_1", "/project");
509512
t.append_entry(TranscriptEntry::new_input("pending work", None))
@@ -516,7 +519,7 @@ mod tests {
516519

517520
#[test]
518521
fn test_find_pending_skips_completed_sessions() {
519-
let (_dir, _lock) = setup_test_dir();
522+
let _dir = setup_test_dir();
520523

521524
let mut t = Transcript::new("done_1", "/project");
522525
t.append_entry(TranscriptEntry::new_input("done work", None))
@@ -528,7 +531,7 @@ mod tests {
528531

529532
#[test]
530533
fn test_load_or_new_creates_if_missing() {
531-
let (_dir, _lock) = setup_test_dir();
534+
let _dir = setup_test_dir();
532535

533536
let t = Transcript::load_or_new("new_session", "/project");
534537
assert_eq!(t.session_id, "new_session");
@@ -537,7 +540,7 @@ mod tests {
537540

538541
#[test]
539542
fn test_load_or_new_loads_if_exists() {
540-
let (_dir, _lock) = setup_test_dir();
543+
let _dir = setup_test_dir();
541544

542545
let mut original = Transcript::new("existing_1", "/project");
543546
original
@@ -551,7 +554,7 @@ mod tests {
551554

552555
#[test]
553556
fn test_cleanup_old_removes_stale_transcripts() {
554-
let (_dir, _lock) = setup_test_dir();
557+
let _dir = setup_test_dir();
555558

556559
// Create a transcript with very old updated_at
557560
let mut t = Transcript::new("old_1", "/project");
@@ -570,7 +573,7 @@ mod tests {
570573

571574
#[test]
572575
fn test_list_recent_returns_all() {
573-
let (_dir, _lock) = setup_test_dir();
576+
let _dir = setup_test_dir();
574577

575578
let t1 = Transcript::new("list_a", "/project");
576579
t1.save().unwrap();
@@ -600,7 +603,7 @@ mod tests {
600603

601604
#[test]
602605
fn test_snapshot_state_persists() {
603-
let (_dir, _lock) = setup_test_dir();
606+
let _dir = setup_test_dir();
604607

605608
let mut t = Transcript::new("state_1", "/project");
606609
let mut state = SessionState::new();

tests/security_tests.rs

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,24 @@ fn run_pipeline(input: &str) -> String {
1111

1212
#[test]
1313
fn test_env_sanitization_denylist() {
14-
use omni::guard::env::{DENYLIST, sanitize_env};
14+
use omni::guard::env::{DENYLIST, sanitize_vars};
1515

16-
// Set some dangerous env vars (unsafe in Rust 2024)
16+
// Set some dangerous env vars in a mock environment
17+
let mut mock_env: Vec<(String, String)> = Vec::new();
1718
for var in DENYLIST.iter().take(3) {
18-
unsafe {
19-
std::env::set_var(var, "INJECTED_VALUE");
20-
}
19+
mock_env.push((var.to_string(), "INJECTED_VALUE".to_string()));
2120
}
2221

23-
let sanitized = sanitize_env();
22+
let sanitized = sanitize_vars(mock_env);
2423

2524
// Verify denylist vars are NOT in sanitized output
2625
for var in DENYLIST {
2726
assert!(
2827
!sanitized.iter().any(|(k, _)| k.eq_ignore_ascii_case(var)),
29-
"Denylist variable {} should be removed by sanitize_env",
28+
"Denylist variable {} should be removed by sanitize_vars",
3029
var
3130
);
3231
}
33-
34-
// Cleanup
35-
for var in DENYLIST.iter().take(3) {
36-
unsafe {
37-
std::env::remove_var(var);
38-
}
39-
}
4032
}
4133

4234
#[test]
@@ -115,16 +107,17 @@ fn test_pipeline_deterministic() {
115107

116108
#[test]
117109
fn test_env_sanitization_removes_dangerous_vars() {
118-
use omni::guard::env::{DENYLIST, sanitize_env};
119-
120-
// Set beberapa dangerous vars
121-
unsafe {
122-
std::env::set_var("LD_PRELOAD", "malicious.so");
123-
std::env::set_var("BASH_ENV", "evil_script.sh");
124-
std::env::set_var("NODE_OPTIONS", "--require=evil");
125-
}
110+
use omni::guard::env::{DENYLIST, sanitize_vars};
111+
112+
// Set beberapa dangerous vars in a mock env
113+
let mock_env = vec![
114+
("LD_PRELOAD".to_string(), "malicious.so".to_string()),
115+
("BASH_ENV".to_string(), "evil_script.sh".to_string()),
116+
("NODE_OPTIONS".to_string(), "--require=evil".to_string()),
117+
("PATH".to_string(), "/usr/bin:/bin".to_string()),
118+
];
126119

127-
let sanitized = sanitize_env();
120+
let sanitized = sanitize_vars(mock_env);
128121

129122
// Verify semua DENYLIST entries hilang
130123
for key in DENYLIST {
@@ -140,13 +133,6 @@ fn test_env_sanitization_removes_dangerous_vars() {
140133
sanitized.iter().any(|(k, _)| k.to_uppercase() == "PATH"),
141134
"PATH should still be in sanitized env"
142135
);
143-
144-
// Cleanup
145-
unsafe {
146-
std::env::remove_var("LD_PRELOAD");
147-
std::env::remove_var("BASH_ENV");
148-
std::env::remove_var("NODE_OPTIONS");
149-
}
150136
}
151137

152138
#[test]

0 commit comments

Comments
 (0)