5 changes: 5 additions & 0 deletions sdk/python/acf/sdk_integration/__init__.py
@@ -0,0 +1,5 @@
"""SDK integration helpers for sidecar risk context contracts."""

from .risk_context import RiskContext, aggregate_risk

__all__ = ["RiskContext", "aggregate_risk"]
133 changes: 133 additions & 0 deletions sdk/python/acf/sdk_integration/risk_context.py
@@ -0,0 +1,133 @@
"""Risk context contract between aggregate and policy stages.

This module defines a fixed-size, O(1) aggregation surface for policy input.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


_SIGNAL_OBFUSCATION = "obfuscation"
_SIGNAL_LEXICAL = "lexical"
_SIGNAL_SEMANTIC = "semantic"
_SIGNAL_PROVENANCE = "provenance"

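# The only signal keys the aggregator reads; anything else is ignored by _normalize_signals().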
ALLOWED_SIGNALS = {
_SIGNAL_OBFUSCATION,
_SIGNAL_LEXICAL,
_SIGNAL_SEMANTIC,
_SIGNAL_PROVENANCE,
}

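# Per-signal weights for the O(1) weighted sum; they total 1.0 so the pre-penalty score stays in [0, 1].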
WEIGHTS = {
_SIGNAL_OBFUSCATION: 0.3,
_SIGNAL_LEXICAL: 0.3,
_SIGNAL_SEMANTIC: 0.2,
_SIGNAL_PROVENANCE: 0.2,
}

# TODO(v2): evolve hook multipliers into policy-configured profiles.
HOOK_MULTIPLIERS = {
"on_prompt": 1.0,
"on_context": 1.0,
"on_tool_call": 1.1,
"on_memory": 1.0,
}


def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return value


def _normalize_signals(signals: dict[str, float]) -> dict[str, float]:
"""Return a fixed-size signal map, ignoring non-allowed keys."""

# Explicit fixed-key extraction keeps aggregation O(1), even if callers
# pass additional keys.
return {
_SIGNAL_OBFUSCATION: _clamp01(
float(signals.get(_SIGNAL_OBFUSCATION, 0.0))
),
_SIGNAL_LEXICAL: _clamp01(float(signals.get(_SIGNAL_LEXICAL, 0.0))),
_SIGNAL_SEMANTIC: _clamp01(float(signals.get(_SIGNAL_SEMANTIC, 0.0))),
_SIGNAL_PROVENANCE: _clamp01(
float(signals.get(_SIGNAL_PROVENANCE, 0.0))
),
}


@dataclass(frozen=True)
class RiskContext:
"""Normalized policy input object produced by the aggregator."""

score: float
signals: dict[str, float]
provenance: dict[str, Any]
metadata: dict[str, Any]

def to_dict(self) -> dict[str, Any]:
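        """Return a plain dict payload shaped for policy-engine input."""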
return {
"score": self.score,
"signals": self.signals,
"provenance": self.provenance,
"metadata": self.metadata,
}


def aggregate_risk(
*,
signals: dict[str, float],
provenance: dict[str, Any],
metadata: dict[str, Any],
) -> RiskContext:
"""Build a fixed-shape `RiskContext` with O(1) weighted scoring.

    Expected signal keys: obfuscation, lexical, semantic, provenance.
    Missing keys default to 0.0. All values are clamped to [0.0, 1.0].
    Untrusted provenance and an invalid nonce each add a fixed penalty,
    and the result is scaled by the hook multiplier before a final clamp.
    """

normalized_signals = _normalize_signals(signals)

normalized_provenance = {
"execution_id": str(provenance.get("execution_id", "")),
"trusted": bool(provenance.get("trusted", False)),
"nonce_valid": bool(provenance.get("nonce_valid", False)),
}

normalized_metadata = {
"hook": str(metadata.get("hook", "")),
"timestamp": int(metadata.get("timestamp", 0)),
}

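    # Fixed-arity weighted sum over the four allowed signals keeps scoring O(1).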
score = _clamp01(
(
WEIGHTS[_SIGNAL_OBFUSCATION]
* normalized_signals[_SIGNAL_OBFUSCATION]
)
+ (WEIGHTS[_SIGNAL_LEXICAL] * normalized_signals[_SIGNAL_LEXICAL])
+ (WEIGHTS[_SIGNAL_SEMANTIC] * normalized_signals[_SIGNAL_SEMANTIC])
+ (
WEIGHTS[_SIGNAL_PROVENANCE]
* normalized_signals[_SIGNAL_PROVENANCE]
)
)

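    # Provenance penalties: an untrusted source and an invalid nonce each add a fixed bump.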
trust_penalty = 0.2 if not normalized_provenance["trusted"] else 0.0
nonce_penalty = 0.1 if not normalized_provenance["nonce_valid"] else 0.0
score = _clamp01(score + trust_penalty + nonce_penalty)

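    # Scale by the hook-specific multiplier; unknown hooks fall back to a neutral 1.0.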
hook = normalized_metadata["hook"]
hook_multiplier = HOOK_MULTIPLIERS.get(hook, 1.0)
score = _clamp01(score * hook_multiplier)

return RiskContext(
score=score,
signals=normalized_signals,
provenance=normalized_provenance,
metadata=normalized_metadata,
)
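
A minimal usage sketch for the contract above (the `event` payload shape and the `on_tool_call` hook wiring are assumptions for illustration, not part of this change):

from acf.sdk_integration import aggregate_risk


def on_tool_call(event: dict) -> dict:
    # Hypothetical sidecar hook: build the fixed-shape context and wrap it the
    # way a policy engine such as OPA expects ({"input": ...}), mirroring
    # test_risk_context_structure below.
    ctx = aggregate_risk(
        signals=event.get("signals", {}),
        provenance=event.get("provenance", {}),
        metadata={"hook": "on_tool_call", "timestamp": event.get("timestamp", 0)},
    )
    return {"input": ctx.to_dict()}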
177 changes: 177 additions & 0 deletions sdk/python/tests/test_risk_context.py
@@ -0,0 +1,177 @@
"""Tests for acf.sdk_integration.risk_context."""

from acf.sdk_integration.risk_context import (
RiskContext,
WEIGHTS,
aggregate_risk,
)


def test_risk_context_structure():
ctx = aggregate_risk(
signals={
"obfuscation": 0.8,
"lexical": 0.7,
"semantic": 0.5,
"provenance": 1.0,
},
provenance={
"execution_id": "exec-123",
"trusted": True,
"nonce_valid": True,
},
metadata={
"hook": "on_prompt",
"timestamp": 1710000000,
},
)

assert isinstance(ctx, RiskContext)

payload = ctx.to_dict()
assert set(payload.keys()) == {
"score",
"signals",
"provenance",
"metadata",
}
assert set(payload["signals"].keys()) == {
"obfuscation",
"lexical",
"semantic",
"provenance",
}
assert set(payload["provenance"].keys()) == {
"execution_id",
"trusted",
"nonce_valid",
}
assert set(payload["metadata"].keys()) == {"hook", "timestamp"}

opa_input = {"input": payload}
assert "input" in opa_input
assert opa_input["input"]["metadata"]["hook"] == "on_prompt"


def test_score_bounds():
ctx = aggregate_risk(
signals={
"obfuscation": 99.0,
"lexical": 50.0,
"semantic": 22.0,
"provenance": 9.0,
},
provenance={
"execution_id": "exec-bounds",
"trusted": False,
"nonce_valid": False,
},
metadata={
"hook": "on_context",
"timestamp": 1710000001,
},
)

assert 0.0 <= ctx.score <= 1.0


def test_weighted_scoring():
ctx = aggregate_risk(
signals={
"obfuscation": 0.5,
"lexical": 0.5,
"semantic": 0.5,
"provenance": 0.5,
},
provenance={
"execution_id": "exec-weighted",
"trusted": True,
"nonce_valid": True,
},
metadata={
"hook": "on_prompt",
"timestamp": 1710000100,
},
)

expected = (
(WEIGHTS["obfuscation"] * 0.5)
+ (WEIGHTS["lexical"] * 0.5)
+ (WEIGHTS["semantic"] * 0.5)
+ (WEIGHTS["provenance"] * 0.5)
)
assert round(ctx.score, 2) == round(expected, 2)


def test_deterministic_output():
kwargs = {
"signals": {
"obfuscation": 0.21,
"lexical": 0.49,
"semantic": 0.13,
"provenance": 0.77,
},
"provenance": {
"execution_id": "exec-deterministic",
"trusted": True,
"nonce_valid": True,
},
"metadata": {
"hook": "on_tool_call",
"timestamp": 1710000002,
},
}

a = aggregate_risk(**kwargs)
b = aggregate_risk(**kwargs)

assert a.to_dict() == b.to_dict()


def test_signal_normalization():
ctx = aggregate_risk(
signals={
"obfuscation": -10.0,
"lexical": 0.25,
"semantic": 100.0,
"provenance": -1.0,
},
provenance={
"execution_id": "exec-normalize",
"trusted": 1,
"nonce_valid": 0,
},
metadata={
"hook": "on_memory",
"timestamp": "1710000003",
},
)

assert ctx.signals["obfuscation"] == 0.0
assert ctx.signals["lexical"] == 0.25
assert ctx.signals["semantic"] == 1.0
assert ctx.signals["provenance"] == 0.0

assert ctx.provenance["trusted"] is True
assert ctx.provenance["nonce_valid"] is False
assert ctx.metadata["timestamp"] == 1710000003


def test_missing_signals_defaults():
ctx = aggregate_risk(
signals={},
provenance={
"execution_id": "exec-defaults",
"trusted": True,
"nonce_valid": True,
},
metadata={
"hook": "on_prompt",
"timestamp": 1710000200,
},
)

assert ctx.signals["obfuscation"] == 0.0
assert ctx.signals["lexical"] == 0.0
assert ctx.signals["semantic"] == 0.0
assert ctx.signals["provenance"] == 0.0