Skip to content

Commit 25b04c5

Browse files
committed
fix: harden remote extends and resolve CI breakage
1 parent 294195d commit 25b04c5

34 files changed

+682
-329
lines changed

crates/clawdstrike/src/engine.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,10 +482,8 @@ impl HushEngineBuilder {
482482
Ok(v) => (v, None),
483483
Err(e) => (Vec::new(), Some(e.to_string())),
484484
};
485-
let custom_guards = build_custom_guards_from_policy(
486-
&self.policy,
487-
self.custom_guard_registry.as_ref(),
488-
)?;
485+
let custom_guards =
486+
build_custom_guards_from_policy(&self.policy, self.custom_guard_registry.as_ref())?;
489487

490488
Ok(HushEngine {
491489
policy: self.policy,

crates/clawdstrike/src/guards/custom.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,3 @@ impl CustomGuardRegistry {
5050
factory.build(config)
5151
}
5252
}
53-

crates/clawdstrike/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ pub use engine::{GuardReport, HushEngine};
5959
pub use error::{Error, Result};
6060
pub use guards::{
6161
CustomGuardFactory, CustomGuardRegistry, EgressAllowlistGuard, ForbiddenPathGuard, Guard,
62-
GuardContext, GuardResult, JailbreakConfig, JailbreakGuard, McpToolGuard,
63-
PatchIntegrityGuard, PromptInjectionGuard, SecretLeakGuard, Severity,
62+
GuardContext, GuardResult, JailbreakConfig, JailbreakGuard, McpToolGuard, PatchIntegrityGuard,
63+
PromptInjectionGuard, SecretLeakGuard, Severity,
6464
};
6565
pub use hygiene::{
6666
detect_prompt_injection, detect_prompt_injection_with_limit, wrap_user_content, DedupeStatus,

crates/clawdstrike/src/policy.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ fn default_json_object() -> serde_json::Value {
2828
serde_json::Value::Object(serde_json::Map::new())
2929
}
3030

31-
/// Policy-driven custom guard configuration.
31+
/// Policy-driven custom guard configuration (`policy.custom_guards[]`).
3232
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
3333
#[serde(deny_unknown_fields)]
34-
pub struct CustomGuardSpec {
34+
pub struct PolicyCustomGuardSpec {
3535
/// Installed guard id (resolved via `CustomGuardRegistry`).
3636
pub id: String,
3737
/// Enable/disable this custom guard.
@@ -167,7 +167,7 @@ pub struct Policy {
167167
pub guards: GuardConfigs,
168168
/// Policy-driven custom guards (resolved by runtimes via a registry).
169169
#[serde(default)]
170-
pub custom_guards: Vec<CustomGuardSpec>,
170+
pub custom_guards: Vec<PolicyCustomGuardSpec>,
171171
/// Global settings
172172
#[serde(default)]
173173
pub settings: PolicySettings,
@@ -860,15 +860,18 @@ impl Policy {
860860
}
861861
}
862862

863-
fn merge_custom_guards(base: &[CustomGuardSpec], child: &[CustomGuardSpec]) -> Vec<CustomGuardSpec> {
863+
fn merge_custom_guards(
864+
base: &[PolicyCustomGuardSpec],
865+
child: &[PolicyCustomGuardSpec],
866+
) -> Vec<PolicyCustomGuardSpec> {
864867
if child.is_empty() {
865868
return base.to_vec();
866869
}
867870
if base.is_empty() {
868871
return child.to_vec();
869872
}
870873

871-
let mut out: Vec<CustomGuardSpec> = base.to_vec();
874+
let mut out: Vec<PolicyCustomGuardSpec> = base.to_vec();
872875
let mut index: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
873876
for (i, cg) in out.iter().enumerate() {
874877
index.insert(cg.id.clone(), i);

crates/clawdstrike/tests/policy_extends.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
33
#![allow(clippy::expect_used, clippy::unwrap_used)]
44

5-
use clawdstrike::Policy;
65
use clawdstrike::policy::{PolicyLocation, PolicyResolver, ResolvedPolicySource};
6+
use clawdstrike::Policy;
77
use std::fs;
88
use tempfile::TempDir;
99

@@ -250,7 +250,11 @@ fn test_policy_extends_resolver_detects_cycles_across_non_file_keys() {
250250
}
251251

252252
impl PolicyResolver for MapResolver {
253-
fn resolve(&self, reference: &str, _from: &PolicyLocation) -> clawdstrike::Result<ResolvedPolicySource> {
253+
fn resolve(
254+
&self,
255+
reference: &str,
256+
_from: &PolicyLocation,
257+
) -> clawdstrike::Result<ResolvedPolicySource> {
254258
let yaml = self.policies.get(reference).cloned().ok_or_else(|| {
255259
clawdstrike::Error::ConfigError(format!("Unknown policy ref: {}", reference))
256260
})?;

crates/clawdstrike/tests/threat_intel_guards.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ async fn virustotal_file_hash_denies_and_caches() {
6363
std::env::set_var("VT_BASE_URL_TEST", format!("{}/api/v3", base));
6464

6565
let yaml = r#"
66-
version: "1.0.0"
66+
version: "1.1.0"
6767
name: "ti"
6868
guards:
6969
custom:
@@ -118,7 +118,7 @@ async fn safe_browsing_denies_on_match() {
118118
std::env::set_var("GSB_BASE_URL_TEST", format!("{}/v4", base));
119119

120120
let yaml = r#"
121-
version: "1.0.0"
121+
version: "1.1.0"
122122
name: "ti"
123123
guards:
124124
egress_allowlist:
@@ -175,7 +175,7 @@ async fn snyk_denies_on_upgradable_vulns() {
175175
std::env::set_var("SNYK_BASE_URL_TEST", format!("{}/api/v1", base));
176176

177177
let yaml = r#"
178-
version: "1.0.0"
178+
version: "1.1.0"
179179
name: "ti"
180180
guards:
181181
custom:

crates/hush-cli/src/hush_run.rs

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,8 @@ fn load_policy(
452452
policy: &str,
453453
remote_extends: &remote_extends::RemoteExtendsConfig,
454454
) -> anyhow::Result<LoadedPolicy> {
455-
let loaded = policy_diff::load_policy_from_arg(policy, true, remote_extends).map_err(|e| {
456-
anyhow::anyhow!(
457-
"Failed to load policy {}: {}",
458-
e.source,
459-
e.message
460-
)
461-
})?;
455+
let loaded = policy_diff::load_policy_from_arg(policy, true, remote_extends)
456+
.map_err(|e| anyhow::anyhow!("Failed to load policy {}: {}", e.source, e.message))?;
462457

463458
Ok(loaded)
464459
}
@@ -475,8 +470,7 @@ fn load_or_create_signer(path: &Path, stderr: &mut dyn Write) -> anyhow::Result<
475470
let pub_path = PathBuf::from(format!("{}.pub", path.display()));
476471
let pub_hex = std::fs::read_to_string(&pub_path)
477472
.with_context(|| format!("read public key {}", pub_path.display()))?;
478-
let public_key =
479-
PublicKey::from_hex(pub_hex.trim()).context("parse public key hex")?;
473+
let public_key = PublicKey::from_hex(pub_hex.trim()).context("parse public key hex")?;
480474
return Ok(Box::new(hush_core::TpmSealedSeedSigner::new(
481475
public_key, blob,
482476
)));
@@ -515,9 +509,13 @@ fn load_or_create_signer(path: &Path, stderr: &mut dyn Write) -> anyhow::Result<
515509
#[derive(Clone, Debug)]
516510
enum SandboxWrapper {
517511
None,
518-
SandboxExec { profile_path: PathBuf },
512+
SandboxExec {
513+
profile_path: PathBuf,
514+
},
519515
#[cfg(target_os = "linux")]
520-
Bwrap { args: Vec<String> },
516+
Bwrap {
517+
args: Vec<String>,
518+
},
521519
}
522520

523521
fn maybe_prepare_sandbox(
@@ -693,8 +691,8 @@ async fn start_connect_proxy(
693691
let outcome = outcome.clone();
694692

695693
tokio::spawn(async move {
696-
let _ = handle_connect_proxy_client(socket, engine, context, event_tx, outcome)
697-
.await;
694+
let _ =
695+
handle_connect_proxy_client(socket, engine, context, event_tx, outcome).await;
698696
});
699697
}
700698
});
@@ -713,8 +711,7 @@ async fn handle_connect_proxy_client(
713711
.await
714712
.context("read proxy request header")?;
715713

716-
let header_str =
717-
std::str::from_utf8(&header).context("proxy request header must be UTF-8")?;
714+
let header_str = std::str::from_utf8(&header).context("proxy request header must be UTF-8")?;
718715
let mut lines = header_str.split("\r\n");
719716
let request_line = lines
720717
.next()
@@ -772,9 +769,7 @@ async fn handle_connect_proxy_client(
772769
if !result.allowed {
773770
// If we already sent 200 (IP + SNI path), we can only close the tunnel.
774771
if sni_buf.is_empty() {
775-
client
776-
.write_all(b"HTTP/1.1 403 Forbidden\r\n\r\n")
777-
.await?;
772+
client.write_all(b"HTTP/1.1 403 Forbidden\r\n\r\n").await?;
778773
}
779774
return Ok(());
780775
}

0 commit comments

Comments
 (0)