Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 76 additions & 7 deletions src/agentic_ci/backends/openshell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import os
import shlex
import subprocess
import tempfile
import threading
from typing import TYPE_CHECKING

from agentic_ci import log
Expand All @@ -14,6 +16,51 @@
if TYPE_CHECKING:
from agentic_ci.harness import Harness

# GCP access tokens minted by the OpenShell gateway live for 3600s. The
# gateway's refresh worker is supposed to rotate them ahead of expiry, but
# around the hourly boundary a transient mint failure (retried only every 60s)
# can let the token lapse, producing a burst of 401s that exhausts the agent's
# retry budget and kills the run mid-way (see NVIDIA/OpenShell PR #1763).
#
# Force a rotation well inside the token lifetime so a freshly minted token is
# always present, tolerating a couple of failed rotations without draining the
# token's remaining life.
_TOKEN_KEEPALIVE_INTERVAL = 1200 # rotate every 20 min

# Phase-offset the first rotation by 10 min so the 20-min cadence lands at
# 10/30/50/70/... min, never coinciding with the ~hourly expiry boundary that
# the gateway refresh worker and the agent's client token cache already act on.
# Rotating on top of that natural re-fetch correlated with extra transient
# errors; offsetting avoids the collision.
_TOKEN_KEEPALIVE_OFFSET = 600 # 10 min


def _token_keepalive(stop: threading.Event) -> None:
"""Force-rotate the gateway's GCP access token on a phase-offset 20-min
cadence until *stop* is set. Failures are logged but never raised."""
if stop.wait(_TOKEN_KEEPALIVE_OFFSET):
return
while True:
try:
provider.rotate_token()
except subprocess.CalledProcessError as exc:
print(
f" [token-keepalive] rotate failed (rc={exc.returncode}): "
f"{exc.stderr.strip() if exc.stderr else ''}",
flush=True,
)
if stop.wait(_TOKEN_KEEPALIVE_INTERVAL):
return


# Claude Code's API retry budget (the "Retry N/10" counter in the stream).
# The default is 10, but a Vertex token-rotation lapse can produce a burst of
# retryable "unknown" errors that, on stock 60-min token intervals, exhausted
# all 10 retries and killed the run. The 20-min token keepalive shortens those
# windows; this widens the budget so even an unlucky long lapse recovers.
# Belt-and-suspenders with the keepalive above. Overridable via env var.
_DEFAULT_MAX_RETRIES = "20"

_OPENSHELL_HOST = "host.openshell.internal"


Expand Down Expand Up @@ -111,16 +158,34 @@ def run(
"--",
*agent_args,
]
proc = sandbox.exec_cmd_streaming(cmd)

rc, stream_complete = self._process_stream(proc, streaming)
self._wait_for_otel_flush(otel_port)
stop_keepalive = threading.Event()
keepalive: threading.Thread | None = None

log.section("Downloading workdir")
sandbox.download(sandbox_workdir, self.workdir)
# The token-lapse race only affects the OpenShell gateway's minted
# Vertex credential; the API-key auth path is unaffected.
if self.harness.auth_mode == "vertex":
log.section("Starting GCP token keepalive")
keepalive = threading.Thread(
target=_token_keepalive, args=(stop_keepalive,), daemon=True
)
keepalive.start()

rc = self._resolve_exit_code(rc, stream_complete)
return rc
try:
proc = sandbox.exec_cmd_streaming(cmd)

rc, stream_complete = self._process_stream(proc, streaming)
self._wait_for_otel_flush(otel_port)

log.section("Downloading workdir")
sandbox.download(sandbox_workdir, self.workdir)

rc = self._resolve_exit_code(rc, stream_complete)
return rc
finally:
stop_keepalive.set()
if keepalive:
keepalive.join(timeout=5)

def _write_env_script(self, model, otel_port=None, otel_rate_file=None):
"""Write env vars to a script inside the sandbox, sourced before the agent runs.
Expand All @@ -147,6 +212,10 @@ def _write_env_script(self, model, otel_port=None, otel_rate_file=None):
else:
lines.append("export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1")

if self.harness.auth_mode == "vertex":
max_retries = os.environ.get("CLAUDE_CODE_MAX_RETRIES", _DEFAULT_MAX_RETRIES)
lines.append(f"export CLAUDE_CODE_MAX_RETRIES={shlex.quote(max_retries)}")

Comment on lines +215 to +218

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Validate and bound CLAUDE_CODE_MAX_RETRIES before exporting it.

This path exports arbitrary env text. Non-integer/negative values can break retry behavior, and very large values can unintentionally extend run time/cost.

Proposed fix
 if self.harness.auth_mode == "vertex":
-    max_retries = os.environ.get("CLAUDE_CODE_MAX_RETRIES", _DEFAULT_MAX_RETRIES)
-    lines.append(f"export CLAUDE_CODE_MAX_RETRIES={shlex.quote(max_retries)}")
+    raw_max_retries = os.environ.get("CLAUDE_CODE_MAX_RETRIES", _DEFAULT_MAX_RETRIES)
+    try:
+        max_retries = int(raw_max_retries)
+    except ValueError:
+        max_retries = int(_DEFAULT_MAX_RETRIES)
+    max_retries = max(1, min(max_retries, 50))
+    lines.append(f"export CLAUDE_CODE_MAX_RETRIES={max_retries}")

As per coding guidelines, "**: REVIEW PRIORITIES: 3. Bug-prone patterns and error handling gaps 4. Performance problems."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/agentic_ci/backends/openshell/__init__.py` around lines 215 - 218, The
CLAUDE_CODE_MAX_RETRIES environment variable is being exported without
validation, which can cause issues if it contains non-integer, negative, or
excessively large values that break retry behavior or extend run time/cost. When
auth_mode is "vertex", parse the retrieved max_retries value to validate it is a
valid non-negative integer within acceptable bounds before exporting it. If the
value fails validation (cannot be parsed as an integer, is negative, or exceeds
a reasonable maximum), fall back to using _DEFAULT_MAX_RETRIES instead of
exporting invalid data.

Source: Coding guidelines

for key, val in self._extra_env.items():
lines.append(f"export {key}={shlex.quote(val)}")

Expand Down
13 changes: 13 additions & 0 deletions src/agentic_ci/backends/openshell/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,19 @@ def _create_gcp_provider_sa(project, region):

# The refresh worker runs on a 60s interval. Request an immediate
# rotation so the initial access token is minted before the agent starts.
rotate_token()


def rotate_token():
"""Force-rotate the gateway's GCP access token.

The OpenShell gateway refresh worker mints tokens on a 60s interval,
but can let a token lapse around the hourly expiry boundary when a
transient mint failure is only retried after 60s while the old token
keeps aging. Calling this proactively keeps a fresh token in play.

Raises subprocess.CalledProcessError on failure.
"""
_run(
[
"openshell",
Expand Down
104 changes: 103 additions & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Tests for backend factory."""

import threading
from unittest import mock

import pytest

from agentic_ci.backends import create_backend
from agentic_ci.backends.local import LocalBackend
from agentic_ci.backends.openshell import OpenShellBackend
from agentic_ci.backends.openshell import OpenShellBackend, _token_keepalive
from agentic_ci.backends.podman import PodmanBackend
from agentic_ci.harness import ClaudeCodeHarness, create_harness

Expand Down Expand Up @@ -134,3 +135,104 @@ def test_env_script_includes_enabled_plugins_var(self, monkeypatch, tmp_path):
def test_env_script_omits_enabled_plugins_when_unset(self, monkeypatch, tmp_path):
script = self._capture_script(monkeypatch, tmp_path)
assert "AGENT_ENABLED_PLUGINS" not in script

def _capture_script_vertex(self, monkeypatch, tmp_path, **env_overrides):
"""Like _capture_script but with Vertex auth (no API key)."""
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.setenv("ANTHROPIC_VERTEX_PROJECT_ID", "test-project")
monkeypatch.delenv("AGENT_ENABLED_PLUGINS", raising=False)
monkeypatch.delenv("CLAUDE_CODE_MAX_RETRIES", raising=False)
for key, val in env_overrides.items():
if val is None:
monkeypatch.delenv(key, raising=False)
else:
monkeypatch.setenv(key, val)

harness = ClaudeCodeHarness()
backend = OpenShellBackend(workdir=str(tmp_path), harness=harness)

captured = []

def mock_upload(path):
with open(path) as f:
captured.append(f.read())

with (
mock.patch("agentic_ci.backends.openshell.sandbox.upload", side_effect=mock_upload),
mock.patch("agentic_ci.backends.openshell.sandbox.exec_cmd"),
):
backend._write_env_script("claude-opus-4-6")

assert len(captured) == 1
return captured[0]

def test_env_script_sets_max_retries_for_vertex(self, monkeypatch, tmp_path):
script = self._capture_script_vertex(monkeypatch, tmp_path)
assert "CLAUDE_CODE_MAX_RETRIES=20" in script

def test_env_script_max_retries_override(self, monkeypatch, tmp_path):
script = self._capture_script_vertex(monkeypatch, tmp_path, CLAUDE_CODE_MAX_RETRIES="30")
assert "CLAUDE_CODE_MAX_RETRIES=30" in script

def test_env_script_omits_max_retries_for_api_key(self, monkeypatch, tmp_path):
script = self._capture_script(monkeypatch, tmp_path)
assert "CLAUDE_CODE_MAX_RETRIES" not in script


class TestTokenKeepalive:
"""Tests for _token_keepalive and its integration in OpenShellBackend.run()."""

def test_keepalive_calls_rotate_token(self, monkeypatch):
"""After the phase offset, rotate_token is called on each interval tick."""
monkeypatch.setattr("agentic_ci.backends.openshell._TOKEN_KEEPALIVE_OFFSET", 0)
monkeypatch.setattr("agentic_ci.backends.openshell._TOKEN_KEEPALIVE_INTERVAL", 0)
call_count = 0
stop = threading.Event()

def mock_rotate():
nonlocal call_count
call_count += 1
if call_count >= 3:
stop.set()

monkeypatch.setattr("agentic_ci.backends.openshell.provider.rotate_token", mock_rotate)
_token_keepalive(stop)
assert call_count >= 3

def test_keepalive_stops_on_event_during_offset(self, monkeypatch):
"""Setting stop during the initial offset exits without rotating."""
monkeypatch.setattr("agentic_ci.backends.openshell._TOKEN_KEEPALIVE_OFFSET", 10)
rotate_called = False

def mock_rotate():
nonlocal rotate_called
rotate_called = True

monkeypatch.setattr("agentic_ci.backends.openshell.provider.rotate_token", mock_rotate)
stop = threading.Event()
stop.set()
_token_keepalive(stop)
assert not rotate_called

def test_keepalive_logs_rotate_failure(self, monkeypatch, capsys):
"""CalledProcessError from rotate_token is logged, not raised."""
import subprocess

monkeypatch.setattr("agentic_ci.backends.openshell._TOKEN_KEEPALIVE_OFFSET", 0)
monkeypatch.setattr("agentic_ci.backends.openshell._TOKEN_KEEPALIVE_INTERVAL", 0)
stop = threading.Event()
call_count = 0

def mock_rotate():
nonlocal call_count
call_count += 1
if call_count == 1:
raise subprocess.CalledProcessError(1, "openshell", stderr="auth error")
stop.set()

monkeypatch.setattr("agentic_ci.backends.openshell.provider.rotate_token", mock_rotate)
_token_keepalive(stop)
captured = capsys.readouterr()
assert "[token-keepalive] rotate failed" in captured.out
assert "auth error" in captured.out
assert call_count >= 2
35 changes: 35 additions & 0 deletions tests/test_openshell_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Tests for OpenShell provider token rotation."""

import subprocess
from unittest import mock

import pytest

from agentic_ci.backends.openshell.provider import PROVIDER_NAME, rotate_token


class TestRotateToken:
def test_rotate_token_calls_openshell(self):
with mock.patch("agentic_ci.backends.openshell.provider._run") as mock_run:
rotate_token()

mock_run.assert_called_once_with(
[
"openshell",
"provider",
"refresh",
"rotate",
"--credential-key",
"GCP_SA_ACCESS_TOKEN",
PROVIDER_NAME,
],
check=True,
)

def test_rotate_token_propagates_failure(self):
with mock.patch(
"agentic_ci.backends.openshell.provider._run",
side_effect=subprocess.CalledProcessError(1, "openshell"),
):
with pytest.raises(subprocess.CalledProcessError):
rotate_token()
Loading