-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathfirewall.py
More file actions
160 lines (136 loc) · 5.52 KB
/
firewall.py
File metadata and controls
160 lines (136 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
Firewall — the main developer-facing class.
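Provides the four v1 hook call sites:
on_prompt(text) -> Decision | SanitiseResult
on_context(chunks) -> list[ChunkResult]
on_tool_call(name, params) -> Decision | SanitiseResult
on_memory(key, value, op) -> Decision | SanitiseResult
Each method builds a RiskContext JSON payload and delegates it to Transport.
It returns the decoded Decision, or a SanitiseResult when the sidecar
sanitises the content. on_context instead returns one ChunkResult per chunk.
FirewallError is raised on failure.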
"""
from __future__ import annotations
import binascii
import json
import os
from typing import Any
from .models import (
ChunkResult,
Decision,
FirewallError,
SanitiseResult,
)
from .transport import Transport, DEFAULT_SOCKET_PATH
class Firewall:
    """Entry point for the ACF SDK.

    Args:
        socket_path: Path to the sidecar IPC address. If not given (or empty),
            the ACF_SOCKET_PATH environment variable is used. If that is unset,
            the default is ``/tmp/acf.sock`` on Linux/macOS or
            ``\\\\.\\pipe\\acf`` on Windows.
        hmac_key: Raw bytes of the HMAC key. If None, read ACF_HMAC_KEY
            from the environment (hex-encoded, surrounding whitespace
            ignored) and decode it.

    Raises:
        FirewallError: If no HMAC key can be resolved, ACF_HMAC_KEY is not
            valid hex, or an explicitly supplied key is empty.
    """

    def __init__(
        self,
        socket_path: str | None = None,
        hmac_key: bytes | None = None,
    ) -> None:
        # Resolution order: explicit argument, then environment, then the
        # transport's platform default. ``or`` also skips empty strings.
        resolved_path = (
            socket_path
            or os.environ.get("ACF_SOCKET_PATH")
            or DEFAULT_SOCKET_PATH
        )
        if hmac_key is None:
            # Keys loaded from files or secret stores commonly carry a trailing
            # newline, which unhexlify would reject, so strip it first.
            raw = os.environ.get("ACF_HMAC_KEY", "").strip()
            if not raw:
                raise FirewallError(
                    "No HMAC key provided. Pass hmac_key= or set ACF_HMAC_KEY "
                    "(hex-encoded, min 32 bytes)."
                )
            try:
                hmac_key = binascii.unhexlify(raw)
            except (ValueError, binascii.Error) as exc:
                raise FirewallError(f"ACF_HMAC_KEY is not valid hex: {exc}") from exc
        elif not hmac_key:
            # An empty key provides no authentication at all; refuse it rather
            # than silently signing requests with it.
            raise FirewallError("hmac_key must not be empty.")
        # NOTE(review): the error text above advertises a 32-byte minimum, but
        # key length is not checked here; presumably Transport or the sidecar
        # enforces it. Confirm before relying on it.
        self._transport = Transport(socket_path=resolved_path, key=hmac_key)

    # ── v1 hook call sites ────────────────────────────────────────────────────

    def on_prompt(self, text: str) -> Decision | SanitiseResult:
        """Evaluate a user prompt before it enters the model context.

        Returns Decision.ALLOW, Decision.BLOCK, or a SanitiseResult.
        """
        payload = self._build_payload("on_prompt", text, provenance="user")
        return self._send(payload)

    def on_context(self, chunks: list[str]) -> list[ChunkResult]:
        """Evaluate RAG chunks before injection into the model context.

        Each chunk is evaluated independently (one sidecar round-trip per
        chunk). Returns one ChunkResult per chunk, in input order.
        Only SANITISE chunks carry sanitised_text; ALLOW and BLOCK chunks
        have sanitised_text=None.
        """
        results = []
        for chunk in chunks:
            payload = self._build_payload("on_context", chunk, provenance="rag")
            decision = self._send(payload)
            if isinstance(decision, SanitiseResult):
                results.append(ChunkResult(
                    original=chunk,
                    decision=Decision.SANITISE,
                    sanitised_text=decision.sanitised_text,
                ))
            else:
                results.append(ChunkResult(
                    original=chunk,
                    decision=decision,
                    sanitised_text=None,
                ))
        return results

    def on_tool_call(self, name: str, params: dict[str, Any]) -> Decision | SanitiseResult:
        """Evaluate a tool call before the tool executes.

        Returns Decision.ALLOW, Decision.BLOCK, or a SanitiseResult.
        """
        payload = self._build_payload(
            "on_tool_call",
            {"name": name, "params": params},
            provenance="agent",
        )
        return self._send(payload)

    def on_memory(self, key: str, value: str, op: str = "write") -> Decision | SanitiseResult:
        """Evaluate a memory read or write before it is committed.

        op: "write" (default) or "read". The value is forwarded unchecked.
        Returns Decision.ALLOW, Decision.BLOCK, or a SanitiseResult.
        """
        payload = self._build_payload(
            "on_memory",
            {"key": key, "value": value, "op": op},
            provenance="agent",
        )
        return self._send(payload)

    # ── internals ────────────────────────────────────────────────────────────

    def _build_payload(
        self,
        hook_type: str,
        content: Any,
        *,
        provenance: str = "sdk",
        session_id: str = "",
    ) -> bytes:
        """Serialise a RiskContext for *content* as compact UTF-8 JSON bytes.

        Keys are sorted and separators are compact, so identical inputs always
        produce identical bytes.
        """
        ctx = {
            "score": 0.0,
            "signals": [],
            "provenance": provenance,
            "session_id": session_id,
            "hook_type": hook_type,
            "payload": content,
            "state": None,
        }
        return json.dumps(ctx, separators=(",", ":"), sort_keys=True).encode("utf-8")

    def _send(self, payload: bytes) -> Decision | SanitiseResult:
        """Send *payload* through the transport and decode the verdict.

        Returns a SanitiseResult when the decision is SANITISE, otherwise the
        bare Decision.
        """
        resp = self._transport.send(payload)
        decision = Decision.from_byte(resp["decision"])
        if decision == Decision.SANITISE:
            raw = resp["sanitised_payload"]
            # Compare against None, not truthiness. If the sanitiser stripped
            # everything, an empty payload must surface as "" so it stays
            # distinct from the None carried by non-sanitised results.
            text = raw.decode("utf-8", errors="replace") if raw is not None else None
            return SanitiseResult(
                decision=decision,
                sanitised_payload=raw,
                sanitised_text=text,
            )
        return decision