Skip to content

Commit 37f13cf

Browse files
committed
feat(vmcp): add auth-retry with circuit breaker to BackendClient
`NewHTTPBackendClient` now wraps the raw HTTP client in a `retryingBackendClient` decorator that: - Intercepts `ErrAuthenticationFailed` (401/403) returned by any `BackendClient` method and retries up to `maxAuthRetries` (3) times with exponential backoff (100 ms base, doubling each attempt). - Maintains a per-backend `authCircuitBreaker` that opens after `authCircuitBreakerThreshold` (5) consecutive fully-exhausted retry sequences, preventing runaway latency from permanently broken credentials. - Uses `singleflight` to coalesce concurrent backoff waits for the same backend, so N goroutines racing on a 401 sleep only once per attempt. - Wraps each retry sequence in an OpenTelemetry `auth.retry` span so the overhead is visible in distributed traces. - Never logs raw credentials. `IsAuthenticationError` is refactored from chained `if` blocks into a package-level `authErrorPatterns` slice (fixes gocyclo limit and adds the mcp-go `"unauthorized (401)"` format that was previously missed). The integration test helper gains `WithHTTPMiddleware` so tests can inject transient HTTP errors without modifying the MCP server logic. New tests: - 9 unit tests (`auth_retry_test.go`) covering success, single-retry, max-retries, non-auth passthrough, circuit breaker open/reset, and context cancellation. - 1 integration test (`auth_retry_integration_test.go`) against a real mcp-go streamable-HTTP server. - 1 E2E Ginkgo suite (`virtualmcp_auth_retry_test.go`) deploying a Python 401 backend alongside a healthy yardstick backend and asserting that the failing backend is marked `BackendStatusUnavailable` while the stable backend remains ready. Closes: #3869
1 parent 0de510d commit 37f13cf

10 files changed

Lines changed: 1162 additions & 37 deletions

File tree

pkg/vmcp/client/auth_retry.go

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package client
5+
6+
import (
7+
"context"
8+
"errors"
9+
"fmt"
10+
"log/slog"
11+
"sync"
12+
"time"
13+
14+
"go.opentelemetry.io/otel"
15+
"go.opentelemetry.io/otel/attribute"
16+
"go.opentelemetry.io/otel/codes"
17+
"go.opentelemetry.io/otel/trace"
18+
"golang.org/x/sync/singleflight"
19+
20+
"github.com/stacklok/toolhive/pkg/vmcp"
21+
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
22+
)
23+
24+
const (
	// authRetryInstrumentationName is the OpenTelemetry instrumentation scope
	// passed to otel.Tracer for all auth-retry spans emitted by this package.
	authRetryInstrumentationName = "github.com/stacklok/toolhive/pkg/vmcp/client"

	// maxAuthRetries is the maximum number of retry attempts after an auth failure.
	maxAuthRetries = 3

	// authCircuitBreakerThreshold is the number of consecutive auth failures before
	// the circuit breaker opens and disables further retries for a backend.
	authCircuitBreakerThreshold = 5

	// initialRetryBackoff is the base duration for exponential backoff between retries.
	// The backoff doubles each attempt: Attempt 1: 100ms, Attempt 2: 200ms, Attempt 3: 400ms.
	initialRetryBackoff = 100 * time.Millisecond
)
39+
40+
// authCircuitBreaker tracks consecutive auth failures per backend and opens the circuit
// after too many failures to prevent excessive latency from repeated auth retries.
// The zero value is a usable, closed breaker.
type authCircuitBreaker struct {
	mu               sync.Mutex
	consecutiveFails int
	open             bool
}

// canRetry reports whether the circuit is still closed, i.e. whether auth
// retries are currently permitted for this backend.
func (cb *authCircuitBreaker) canRetry() bool {
	cb.mu.Lock()
	closed := !cb.open
	cb.mu.Unlock()
	return closed
}

// recordSuccess closes the circuit and zeroes the consecutive-failure counter.
func (cb *authCircuitBreaker) recordSuccess() {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.open = false
	cb.consecutiveFails = 0
}

// recordFailure bumps the consecutive-failure counter and, once the threshold
// is reached, opens the circuit. The warning is logged only on the closed->open
// transition, so repeated failures while open stay quiet.
func (cb *authCircuitBreaker) recordFailure(threshold int, backendID string) {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.consecutiveFails++
	if cb.open || cb.consecutiveFails < threshold {
		return
	}
	cb.open = true
	slog.Warn("auth circuit breaker opened: too many consecutive auth failures, disabling retries",
		"backend", backendID, "consecutive_failures", cb.consecutiveFails)
}
74+
75+
// retryingBackendClient wraps a BackendClient and automatically retries operations that
// fail due to authentication errors (401/403). It uses:
//   - Exponential backoff with a maximum of [maxAuthRetries] attempts
//   - A per-backend circuit breaker to stop retrying after [authCircuitBreakerThreshold] consecutive failures
//   - singleflight to deduplicate concurrent backoff waits for the same backend
//   - OpenTelemetry spans to surface auth-retry latency in distributed traces
//
// Raw credentials are never logged.
type retryingBackendClient struct {
	// inner is the wrapped client that performs the actual backend calls.
	inner vmcp.BackendClient
	// registry resolves outgoing auth strategies for backends.
	registry vmcpauth.OutgoingAuthRegistry

	// sf deduplicates concurrent backoff waits for the same backend at the same attempt number.
	sf singleflight.Group

	// breakers maps backendID -> *authCircuitBreaker. LoadOrStore is used for concurrent safety.
	breakers sync.Map

	// tracer emits the "auth.retry" spans; obtained once at construction.
	tracer trace.Tracer
	// maxRetries, cbThreshold, and initialBackoff mirror the package constants
	// but live on the struct so tests can tune them.
	maxRetries     int
	cbThreshold    int
	initialBackoff time.Duration

	// backoffFn is the sleep function used inside singleflight. nil uses time.After.
	// Tests inject a counted hook to assert coalescing without real wall-clock delays.
	backoffFn func(ctx context.Context, d time.Duration) error
}
102+
103+
// newRetryingBackendClient wraps inner with auth-failure retry logic.
104+
func newRetryingBackendClient(inner vmcp.BackendClient, registry vmcpauth.OutgoingAuthRegistry) *retryingBackendClient {
105+
return &retryingBackendClient{
106+
inner: inner,
107+
registry: registry,
108+
tracer: otel.Tracer(authRetryInstrumentationName),
109+
maxRetries: maxAuthRetries,
110+
cbThreshold: authCircuitBreakerThreshold,
111+
initialBackoff: initialRetryBackoff,
112+
}
113+
}
114+
115+
// getBreaker returns (or lazily creates) the auth circuit breaker for a backend.
116+
func (r *retryingBackendClient) getBreaker(backendID string) *authCircuitBreaker {
117+
v, _ := r.breakers.LoadOrStore(backendID, &authCircuitBreaker{})
118+
return v.(*authCircuitBreaker) //nolint:forcetypeassert
119+
}
120+
121+
// withAuthRetry executes op, and if it returns ErrAuthenticationFailed, retries up to
// r.maxRetries times with exponential backoff, using singleflight to deduplicate concurrent
// backoff waits per backend. Auth-retry overhead is surfaced as an OpenTelemetry span.
//
// Flow: a first attempt runs unconditionally; only auth failures enter the retry
// loop, and only while the backend's circuit breaker is closed. A non-auth error
// on any retry aborts immediately. Only a fully exhausted retry sequence counts
// as a failure toward opening the breaker.
func (r *retryingBackendClient) withAuthRetry(
	ctx context.Context,
	backendID string,
	op func(context.Context) error,
) error {
	breaker := r.getBreaker(backendID)

	// First attempt — success also closes/resets the breaker.
	err := op(ctx)
	if err == nil {
		breaker.recordSuccess()
		return nil
	}
	// Only auth failures are retryable here; everything else passes through.
	if !errors.Is(err, vmcp.ErrAuthenticationFailed) {
		return err
	}
	if !breaker.canRetry() {
		slog.Debug("auth circuit breaker open, skipping auth retry",
			"backend", backendID)
		return err
	}

	// Start a span to surface auth-retry latency in distributed traces.
	// Note: raw credentials are never attached to the span or logs.
	ctx, span := r.tracer.Start(ctx, "auth.retry",
		trace.WithAttributes(
			attribute.String("target.workload_id", backendID),
			attribute.Int("max_retries", r.maxRetries),
		),
		trace.WithSpanKind(trace.SpanKindInternal),
	)
	defer span.End()

	var lastErr error
	for attempt := uint(1); attempt <= uint(r.maxRetries); attempt++ { //nolint:gosec // maxRetries is always positive
		// Use singleflight to deduplicate concurrent backoff waits for the same backend
		// and attempt number. The first goroutine sleeps; the others coalesce with it.
		// DoChan is used instead of Do so every caller can also select on its own
		// ctx.Done() — otherwise a coalesced caller with a short deadline would be
		// stuck for the full backoff duration of the leader's longer-lived context.
		sfKey := fmt.Sprintf("%s:attempt:%d", backendID, attempt)
		ch := r.sf.DoChan(sfKey, func() (any, error) {
			// Exponential backoff: initialBackoff * 2^(attempt-1).
			backoff := time.Duration(uint(1)<<(attempt-1)) * r.initialBackoff //nolint:gosec // attempt bounded by maxAuthRetries
			// backoffFn is a test seam; nil means a real timed wait below.
			if r.backoffFn != nil {
				return nil, r.backoffFn(ctx, backoff)
			}
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(backoff):
				return nil, nil
			}
		})
		// Wait for either the shared backoff to elapse or this caller's own
		// context to be cancelled, whichever comes first.
		var sfErr error
		select {
		case <-ctx.Done():
			sfErr = ctx.Err()
		case res := <-ch:
			sfErr = res.Err
		}
		if sfErr != nil {
			span.RecordError(sfErr)
			return sfErr
		}

		span.AddEvent("auth.retry.attempt",
			trace.WithAttributes(attribute.Int("attempt", int(attempt)))) //nolint:gosec // attempt bounded by maxAuthRetries

		retryErr := op(ctx)
		if retryErr == nil {
			breaker.recordSuccess()
			span.SetStatus(codes.Ok, "auth retry succeeded")
			return nil
		}

		lastErr = retryErr
		if !errors.Is(retryErr, vmcp.ErrAuthenticationFailed) {
			// Non-auth error on retry — no point continuing auth retries.
			span.RecordError(retryErr)
			return retryErr
		}
	}

	// All retries exhausted with auth failures — update circuit breaker.
	breaker.recordFailure(r.cbThreshold, backendID)
	span.RecordError(lastErr)
	span.SetStatus(codes.Error, "auth retry exhausted")
	return lastErr
}
211+
212+
// CallTool implements vmcp.BackendClient.
213+
func (r *retryingBackendClient) CallTool(
214+
ctx context.Context,
215+
target *vmcp.BackendTarget,
216+
toolName string,
217+
arguments map[string]any,
218+
meta map[string]any,
219+
) (*vmcp.ToolCallResult, error) {
220+
var result *vmcp.ToolCallResult
221+
err := r.withAuthRetry(ctx, target.WorkloadID, func(ctx context.Context) error {
222+
var opErr error
223+
result, opErr = r.inner.CallTool(ctx, target, toolName, arguments, meta)
224+
return opErr
225+
})
226+
return result, err
227+
}
228+
229+
// ReadResource implements vmcp.BackendClient.
230+
func (r *retryingBackendClient) ReadResource(
231+
ctx context.Context,
232+
target *vmcp.BackendTarget,
233+
uri string,
234+
) (*vmcp.ResourceReadResult, error) {
235+
var result *vmcp.ResourceReadResult
236+
err := r.withAuthRetry(ctx, target.WorkloadID, func(ctx context.Context) error {
237+
var opErr error
238+
result, opErr = r.inner.ReadResource(ctx, target, uri)
239+
return opErr
240+
})
241+
return result, err
242+
}
243+
244+
// GetPrompt implements vmcp.BackendClient.
245+
func (r *retryingBackendClient) GetPrompt(
246+
ctx context.Context,
247+
target *vmcp.BackendTarget,
248+
name string,
249+
arguments map[string]any,
250+
) (*vmcp.PromptGetResult, error) {
251+
var result *vmcp.PromptGetResult
252+
err := r.withAuthRetry(ctx, target.WorkloadID, func(ctx context.Context) error {
253+
var opErr error
254+
result, opErr = r.inner.GetPrompt(ctx, target, name, arguments)
255+
return opErr
256+
})
257+
return result, err
258+
}
259+
260+
// ListCapabilities implements vmcp.BackendClient.
261+
func (r *retryingBackendClient) ListCapabilities(
262+
ctx context.Context,
263+
target *vmcp.BackendTarget,
264+
) (*vmcp.CapabilityList, error) {
265+
var result *vmcp.CapabilityList
266+
err := r.withAuthRetry(ctx, target.WorkloadID, func(ctx context.Context) error {
267+
var opErr error
268+
result, opErr = r.inner.ListCapabilities(ctx, target)
269+
return opErr
270+
})
271+
return result, err
272+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package client_test
5+
6+
import (
7+
"context"
8+
"net"
9+
"net/http"
10+
"net/http/httptest"
11+
"sync/atomic"
12+
"testing"
13+
"time"
14+
15+
"github.com/mark3labs/mcp-go/mcp"
16+
"github.com/mark3labs/mcp-go/server"
17+
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
19+
20+
"github.com/stacklok/toolhive/pkg/vmcp"
21+
"github.com/stacklok/toolhive/pkg/vmcp/auth"
22+
"github.com/stacklok/toolhive/pkg/vmcp/auth/strategies"
23+
vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client"
24+
)
25+
26+
// TestAuthRetry_Transient401_ListCapabilities verifies the end-to-end retry path when a
27+
// backend MCP server returns HTTP 401 on the first request it receives.
28+
//
29+
// NewHTTPBackendClient wraps httpBackendClient with retryingBackendClient.
30+
// ListCapabilities creates a fresh MCP client per call (Start + Initialize + List*).
31+
// mcp-go returns ErrUnauthorized ("unauthorized (401)") for a 401 response, which
32+
// IsAuthenticationError now matches, so retryingBackendClient should retry and succeed.
33+
func TestAuthRetry_Transient401_ListCapabilities(t *testing.T) {
34+
t.Parallel()
35+
36+
var requestCount atomic.Int32
37+
backend, cleanup := startTransient401Server(t, &requestCount)
38+
defer cleanup()
39+
40+
registry := auth.NewDefaultOutgoingAuthRegistry()
41+
require.NoError(t, registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}))
42+
43+
backendClient, err := vmcpclient.NewHTTPBackendClient(registry)
44+
require.NoError(t, err)
45+
46+
target := &vmcp.BackendTarget{
47+
WorkloadID: "test-backend",
48+
WorkloadName: "Test Backend",
49+
BaseURL: backend,
50+
TransportType: "streamable-http",
51+
}
52+
53+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
54+
defer cancel()
55+
56+
// ListCapabilities should succeed despite the initial 401 — the retry wrapper
57+
// must recreate the MCP client and successfully complete the capability query.
58+
caps, err := backendClient.ListCapabilities(ctx, target)
59+
60+
require.NoError(t, err, "ListCapabilities should succeed after auth retry")
61+
require.NotNil(t, caps)
62+
assert.Len(t, caps.Tools, 1, "should discover the echo tool after retry")
63+
64+
// Confirm the retry was exercised: the backend received more than one batch of
65+
// requests (the 401 attempt + the successful retry).
66+
assert.Greater(t, int(requestCount.Load()), 1,
67+
"backend must have received >1 request, confirming retry was exercised")
68+
}
69+
70+
// startTransient401Server starts an httptest.Server backed by a real mcp-go MCP server.
71+
// It returns 401 for the first request, then passes through to the real handler.
72+
// The returned cleanup function must be deferred by the caller.
73+
func startTransient401Server(tb testing.TB, requestCount *atomic.Int32) (baseURL string, cleanup func()) {
74+
tb.Helper()
75+
76+
mcpSrv := server.NewMCPServer("test-backend", "1.0.0",
77+
server.WithToolCapabilities(true),
78+
)
79+
mcpSrv.AddTool(
80+
mcp.Tool{Name: "echo", Description: "Echo the input"},
81+
func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
82+
return &mcp.CallToolResult{
83+
Content: []mcp.Content{mcp.NewTextContent("ok")},
84+
}, nil
85+
},
86+
)
87+
88+
streamable := server.NewStreamableHTTPServer(mcpSrv, server.WithEndpointPath("/mcp"))
89+
90+
httpSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91+
if r.Method == http.MethodDelete {
92+
// Allow session-close DELETE to pass through without counting.
93+
streamable.ServeHTTP(w, r)
94+
return
95+
}
96+
n := requestCount.Add(1)
97+
if n <= 1 {
98+
w.WriteHeader(http.StatusUnauthorized)
99+
return
100+
}
101+
streamable.ServeHTTP(w, r)
102+
}))
103+
104+
// Bind to a free port on loopback.
105+
ln, err := net.Listen("tcp", "127.0.0.1:0")
106+
require.NoError(tb, err)
107+
httpSrv.Listener = ln
108+
httpSrv.Start()
109+
110+
tb.Logf("started transient-401 backend at %s/mcp (will fail first non-DELETE request)", httpSrv.URL)
111+
112+
return httpSrv.URL + "/mcp", httpSrv.Close
113+
}

0 commit comments

Comments
 (0)