Skip to content

Commit 8303a1a

Browse files
kovtcharov and Ovtcharov authored
fix(security,agents): harden SSRF DNS-rebind, fail loudly, cover untested modules (#1299)
A DNS-rebind attacker could slip past the web client's SSRF guard: the pre-flight IP check and the actual TCP connect used **separate** DNS lookups, so a host could answer the check with a public IP and the connect with a private/internal one. This PR pins the resolved IP and validates the exact address it dials through a single authority, closing the window for both http and https (HTTPS cert-name verification still binds to the real hostname; the SNI-vhosting trade-off is documented, not silently accepted). It also removes three silent-fallback violations that hid real failures, and adds tests to three high-risk modules that had none.

Reviewer-relevant threads:

- **SSRF hardening** (`web/client.py`) — the security fix above; worth a close read of `PinnedIPAdapter`.
- **Fail-loudly fixes** — corrupt memory-settings now logs instead of reverting to defaults silently; Telegram background startup re-raises on PID-write failure (a supervisor can no longer be fooled into thinking a dead process started); a raising system-prompt fragment now logs instead of vanishing from the prompt.
- **New coverage** — DockerAgent (subprocess/path-allowlist), the home-dir discovery classifiers, and Jira JQL templating; the API non-streaming completion happy-path is now tested with a mocked backend instead of `@pytest.mark.skip`.
## Test plan

- [ ] `pytest tests/unit/agents/test_discovery.py tests/unit/agents/test_docker_agent.py tests/unit/agents/test_jql_templates.py tests/unit/test_web_client_ip_pinning.py -q` — 100 pass
- [ ] `pytest tests/unit/test_web_client_edge_cases.py tests/test_rag.py -q` — no regression (96 pass)
- [ ] `pytest tests/test_api.py -q` — completion happy-path now runs (no longer skipped)
- [ ] `python util/lint.py --black --isort --flake8` — clean on changed files
- [ ] Agent eval (running separately) confirms no regression from the prompt-fragment logging change — the change is logging-only on the exception branch, so the composed prompt is byte-identical on the happy path

---------

Co-authored-by: Ovtcharov <kovtchar@amd.com>
1 parent 0721f9f commit 8303a1a

11 files changed

Lines changed: 1224 additions & 114 deletions

File tree

src/gaia/agents/base/agent.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,14 @@ def _get_mixin_prompts(self) -> list[str]:
485485
fragment = getattr(self, attr_name)()
486486
if fragment:
487487
prompts.append(fragment)
488-
except Exception:
489-
pass
488+
except Exception as e:
489+
# A raising fragment is dropped from the composed prompt; surface it
490+
# so a silently degraded system prompt is diagnosable.
491+
logger.warning(
492+
"system-prompt fragment %s() raised, skipping it: %s",
493+
attr_name,
494+
e,
495+
)
490496

491497
return prompts
492498

src/gaia/agents/base/memory.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ def _load_memory_settings() -> Dict:
6262
try:
6363
if _MEMORY_SETTINGS_PATH.exists():
6464
return json.loads(_MEMORY_SETTINGS_PATH.read_text(encoding="utf-8"))
65-
except Exception:
66-
pass
65+
except (OSError, json.JSONDecodeError) as e:
66+
logger.warning(
67+
"failed to load memory settings from %s: %s", _MEMORY_SETTINGS_PATH, e
68+
)
6769
return {}
6870

6971

src/gaia/mcp/mcp_bridge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def _parse_cd(self, value: str):
5757
name = p.split("=", 1)[1].strip().strip('"')
5858
elif pl.startswith("filename="):
5959
filename = p.split("=", 1)[1].strip().strip('"')
60-
except Exception:
61-
pass
60+
except (AttributeError, IndexError, ValueError) as e:
61+
logger.debug("Failed to parse Content-Disposition %r: %s", value, e)
6262
return name, filename
6363

6464
def on_part_begin(self):

src/gaia/messaging/telegram.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,12 @@ def start(self, token: str, background: bool = False) -> None:
184184
with open(pid_path, "w", encoding="utf-8") as f:
185185
f.write(str(os.getpid()))
186186
except OSError as e:
187-
log.exception("Failed to write PID file for telegram adapter: %s", e)
187+
# The PID file is load-bearing in background mode: without it a
188+
# supervisor cannot find or kill the process. Fail loudly.
189+
raise RuntimeError(
190+
f"Failed to write telegram PID file at {pid_path}: {e}. "
191+
f"Ensure ~/.gaia is writable, or start without --background."
192+
) from e
188193

189194
log_path = os.path.join(pid_dir, "telegram.log")
190195
try:

src/gaia/rag/sdk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,8 @@ def _load_embedder(self):
450450
# Force fresh load - must unload first
451451
try:
452452
self.llm_client.unload_model()
453-
except Exception:
454-
pass # Ignore if nothing to unload
453+
except Exception as e:
454+
self.log.warning("unload_model failed (continuing): %s", e)
455455

456456
try:
457457
self.llm_client.load_model(

src/gaia/web/client.py

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,42 @@
4040
ALLOWED_SCHEMES = {"http", "https"}
4141
BLOCKED_PORTS = {22, 23, 25, 445, 3306, 5432, 6379, 27017}
4242

43+
44+
def _is_blocked_ip(ip: "ipaddress._BaseAddress") -> bool:
45+
"""Return True if ``ip`` points at a private/internal range we must not fetch."""
46+
return (
47+
ip.is_private
48+
or ip.is_loopback
49+
or ip.is_link_local
50+
or ip.is_reserved
51+
or ip.is_multicast
52+
)
53+
54+
55+
def _assert_ip_allowed(ip_str: str, hostname: str) -> None:
56+
"""Raise ValueError if ``ip_str`` is a private/reserved address.
57+
58+
The single authority for "is this IP safe to connect to". Both
59+
``WebClient.validate_url`` (pre-flight DNS check) and
60+
``PinnedIPAdapter`` (the IP it actually connects to) route through this,
61+
so a DNS rebind that slips a private IP past the pre-flight lookup is
62+
still caught at connect time on the *exact* address being dialed.
63+
"""
64+
try:
65+
ip = ipaddress.ip_address(ip_str)
66+
except ValueError:
67+
# Not parseable as an IP — treat as unsafe rather than letting an
68+
# unvalidated value reach the socket layer (fail loudly).
69+
raise ValueError(
70+
f"Blocked: {hostname} resolved to unparseable address {ip_str!r}."
71+
)
72+
if _is_blocked_ip(ip):
73+
raise ValueError(
74+
f"Blocked: {hostname} resolves to private/reserved IP {ip}. "
75+
"Cannot fetch internal network addresses."
76+
)
77+
78+
4379
# Tags to remove during text extraction
4480
REMOVE_TAGS = [
4581
"script",
@@ -72,17 +108,41 @@ class PinnedIPAdapter(HTTPAdapter):
72108
DNS-rebind attacks between ``WebClient.validate_url`` and the actual
73109
TCP connect.
74110
111+
Crucially, the pinned IP is itself validated (``_assert_ip_allowed``)
112+
before it is cached or connected to. ``validate_url`` runs a *separate*
113+
pre-flight ``getaddrinfo``; an attacker controlling DNS could answer that
114+
lookup with a public IP and answer the adapter's lookup with a private
115+
one. Validating the exact address the adapter is about to dial closes
116+
that residual rebind window for BOTH http and https.
117+
75118
For HTTPS, the original hostname is encoded in the URL's userinfo
76119
section (``originalhostname@pinnedip:port``) so that urllib3 creates
77120
separate connection-pool keys per original hostname. This avoids a
78121
race where two threads requesting different hostnames that resolve to
79122
the same IP would overwrite each other's ``assert_hostname`` on a
80123
shared pool.
124+
125+
Residual HTTPS limitation (documented, not silently ignored):
126+
Because ``requests`` derives the urllib3 pool host — and therefore the
127+
TLS SNI ``server_hostname`` — from the request URL's hostname (which we
128+
rewrote to the pinned IP), the ClientHello SNI is sent as the IP, not the
129+
original hostname. ``assert_hostname`` still forces certificate-name
130+
verification against the real hostname (so verification is NOT disabled
131+
and we never trust a cert for the bare IP), but servers that rely on SNI
132+
for virtual hosting (most CDNs / shared hosts) may return the wrong
133+
certificate or reject the handshake, surfacing as a TLS error rather than
134+
a silent downgrade. This affects whether legitimate HTTPS *succeeds* — it
135+
does not weaken the SSRF block, which fires on the validated IP before any
136+
bytes are sent. Fixing SNI cleanly requires a custom urllib3
137+
``PoolManager``/``HTTPSConnection`` that decouples ``server_hostname``
138+
from the connect address; that is intentionally out of scope here in
139+
favour of a correct, narrower guarantee.
81140
"""
82141

83142
def __init__(self, *args, **kwargs):
84143
super().__init__(*args, **kwargs)
85144
self._pinned_cache: Dict[Tuple[str, int], str] = {}
145+
self._warned_https_sni = False
86146

87147
def _resolve_first_ip(self, host: str, port: int) -> str:
88148
key = (host, port)
@@ -94,6 +154,12 @@ def _resolve_first_ip(self, host: str, port: int) -> str:
94154
raise OSError(f"getaddrinfo returned no addresses for {host}:{port}")
95155

96156
ip = infos[0][4][0] # sockaddr[0] of the first result
157+
# Validate the EXACT IP we are about to pin & connect to. validate_url
158+
# did a pre-flight getaddrinfo, but that was a separate lookup — a DNS
159+
# rebind could hand validate_url a public IP and hand us a private one.
160+
# Re-check here so the address actually dialed is always safe; cache
161+
# only after it passes so a poisoned answer is never reused.
162+
_assert_ip_allowed(ip, host)
97163
self._pinned_cache[key] = ip
98164
return ip
99165

@@ -134,7 +200,22 @@ def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response
134200
pinned_ip = self._resolve_first_ip(host, port)
135201

136202
if parsed.scheme == "https":
137-
# Encode original hostname in userinfo for unique pool keys
203+
# Encode original hostname in userinfo for unique pool keys.
204+
# See class docstring: SNI is sent as the pinned IP, so
205+
# SNI-vhosted servers may fail the handshake. Cert-name
206+
# verification still binds to the real hostname.
207+
if not getattr(self, "_warned_https_sni", False):
208+
log.warning(
209+
"PinnedIPAdapter: HTTPS request to %s is pinned to %s; "
210+
"TLS SNI will be sent as the IP. Servers using "
211+
"SNI-based virtual hosting may return the wrong "
212+
"certificate or reject the handshake. Certificate-name "
213+
"verification still validates against %s.",
214+
host,
215+
pinned_ip,
216+
host,
217+
)
218+
self._warned_https_sni = True
138219
new_netloc = f"{host}@{pinned_ip}:{port}"
139220
else:
140221
new_netloc = f"{pinned_ip}:{port}"
@@ -270,13 +351,7 @@ def _validate_host_ip(self, hostname: str) -> None:
270351
except ValueError:
271352
continue
272353

273-
if (
274-
ip.is_private
275-
or ip.is_loopback
276-
or ip.is_link_local
277-
or ip.is_reserved
278-
or ip.is_multicast
279-
):
354+
if _is_blocked_ip(ip):
280355
raise ValueError(
281356
f"Blocked: {hostname} resolves to private/reserved IP {ip}. "
282357
"Cannot fetch internal network addresses."

tests/test_api.py

Lines changed: 89 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,95 @@ def setup(self):
5454
pytest.skip(f"API dependencies not available: {IMPORT_ERROR}")
5555
self.client = TestClient(app)
5656

57+
# -------------------------------------------------------------------------
58+
# Non-streaming happy path (mocked agent backend — no Lemonade required)
59+
# -------------------------------------------------------------------------
60+
61+
def test_basic_completion_with_mocked_agent(self, mocker):
62+
"""Non-streaming POST returns a schema-valid OpenAI completion.
63+
64+
The agent/Lemonade backend is mocked: registry.get_agent yields a stub
65+
whose process_query returns a canned result dict, so the handler's
66+
non-streaming branch runs end-to-end without a live LLM server.
67+
"""
68+
# Stub agent: NOT an ApiAgent, so the handler uses the len//4 token
69+
# estimate path (deterministic, no tokenizer needed).
70+
fake_agent = mocker.MagicMock()
71+
fake_agent.process_query.return_value = {
72+
"status": "success",
73+
"result": "def hello():\n return 'hello world'",
74+
}
75+
76+
from gaia.api.openai_server import registry as server_registry
77+
78+
mocker.patch.object(server_registry, "get_agent", return_value=fake_agent)
79+
80+
payload = {
81+
"model": "gaia-code",
82+
"messages": [
83+
{"role": "user", "content": "Write a hello world function in Python"}
84+
],
85+
"stream": False,
86+
}
87+
response = self.client.post("/v1/chat/completions", json=payload)
88+
89+
assert response.status_code == 200, response.text
90+
data = response.json()
91+
92+
# Top-level OpenAI-compatible structure.
93+
assert data["object"] == "chat.completion"
94+
assert data["id"].startswith("chatcmpl-")
95+
assert isinstance(data["created"], int)
96+
assert data["model"] == "gaia-code"
97+
98+
# The agent was invoked with the extracted user message.
99+
fake_agent.process_query.assert_called_once()
100+
call_args, _ = fake_agent.process_query.call_args
101+
assert call_args[0] == "Write a hello world function in Python"
102+
103+
# Choices.
104+
assert len(data["choices"]) == 1
105+
choice = data["choices"][0]
106+
assert choice["index"] == 0
107+
assert choice["message"]["role"] == "assistant"
108+
assert choice["message"]["content"] == (
109+
"def hello():\n return 'hello world'"
110+
)
111+
assert choice["finish_reason"] == "stop"
112+
113+
# Usage accounting.
114+
usage = data["usage"]
115+
assert usage["prompt_tokens"] > 0
116+
assert usage["completion_tokens"] > 0
117+
assert usage["total_tokens"] == (
118+
usage["prompt_tokens"] + usage["completion_tokens"]
119+
)
120+
121+
def test_completion_uses_last_user_message(self, mocker):
122+
"""The handler passes the LAST user message (not system/assistant) to the agent."""
123+
fake_agent = mocker.MagicMock()
124+
fake_agent.process_query.return_value = {"result": "ok"}
125+
126+
from gaia.api.openai_server import registry as server_registry
127+
128+
mocker.patch.object(server_registry, "get_agent", return_value=fake_agent)
129+
130+
payload = {
131+
"model": "gaia-code",
132+
"messages": [
133+
{"role": "system", "content": "You are helpful"},
134+
{"role": "user", "content": "first question"},
135+
{"role": "assistant", "content": "an earlier answer"},
136+
{"role": "user", "content": "second question"},
137+
],
138+
"stream": False,
139+
}
140+
response = self.client.post("/v1/chat/completions", json=payload)
141+
142+
assert response.status_code == 200, response.text
143+
call_args, _ = fake_agent.process_query.call_args
144+
assert call_args[0] == "second question"
145+
57146
# -------------------------------------------------------------------------
58147
# Model Validation Tests
59148
# -------------------------------------------------------------------------
@@ -308,47 +397,6 @@ def test_422_error_has_detail_field(self):
308397
class TestChatCompletionsNonStreaming:
309398
"""Test POST /v1/chat/completions without streaming"""
310399

311-
@pytest.mark.skip(reason="Skipped: API server returns 500 - see issue for fix")
312-
def test_basic_completion_with_code_agent(self, api_server, api_client):
313-
"""Test that gaia-code returns valid OpenAI-compatible completion"""
314-
payload = {
315-
"model": "gaia-code",
316-
"messages": [
317-
{"role": "user", "content": "Write a hello world function in Python"}
318-
],
319-
"stream": False,
320-
}
321-
response = api_client.post(f"{api_server}/v1/chat/completions", json=payload)
322-
323-
assert response.status_code == 200
324-
data = response.json()
325-
326-
# Verify OpenAI-compatible structure
327-
assert data["object"] == "chat.completion"
328-
assert "id" in data
329-
assert data["id"].startswith("chatcmpl-")
330-
assert "created" in data
331-
assert isinstance(data["created"], int)
332-
assert data["model"] == "gaia-code"
333-
334-
# Verify choices
335-
assert "choices" in data
336-
assert len(data["choices"]) == 1
337-
choice = data["choices"][0]
338-
assert choice["index"] == 0
339-
assert choice["message"]["role"] == "assistant"
340-
assert isinstance(choice["message"]["content"], str)
341-
assert len(choice["message"]["content"]) > 0
342-
assert choice["finish_reason"] in ["stop", "length"]
343-
344-
# Verify token usage
345-
assert "usage" in data
346-
assert data["usage"]["prompt_tokens"] > 0
347-
assert data["usage"]["completion_tokens"] > 0
348-
assert data["usage"]["total_tokens"] == (
349-
data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"]
350-
)
351-
352400
def test_invalid_model_returns_404(self, api_server, api_client):
353401
"""Test that invalid model returns 404 error"""
354402
payload = {

0 commit comments

Comments
 (0)