|
| 1 | +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +package client |
| 5 | + |
| 6 | +import ( |
| 7 | + "context" |
| 8 | + "errors" |
| 9 | + "fmt" |
| 10 | + "log/slog" |
| 11 | + "sync" |
| 12 | + "time" |
| 13 | + |
| 14 | + "go.opentelemetry.io/otel" |
| 15 | + "go.opentelemetry.io/otel/attribute" |
| 16 | + "go.opentelemetry.io/otel/codes" |
| 17 | + "go.opentelemetry.io/otel/trace" |
| 18 | + "golang.org/x/sync/singleflight" |
| 19 | + |
| 20 | + "github.com/stacklok/toolhive/pkg/vmcp" |
| 21 | + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" |
| 22 | +) |
| 23 | + |
const (
	// authRetryInstrumentationName is the OpenTelemetry instrumentation scope for auth retries.
	authRetryInstrumentationName = "github.com/stacklok/toolhive/pkg/vmcp/client"

	// maxAuthRetries is the maximum number of retry attempts after an auth failure.
	maxAuthRetries = 3

	// authCircuitBreakerThreshold is the number of consecutive auth failures before
	// the circuit breaker opens and disables further retries for a backend. Note that
	// one "failure" here is a fully exhausted retry cycle, not a single failed attempt
	// (the breaker is only bumped after all maxAuthRetries attempts fail).
	authCircuitBreakerThreshold = 5

	// initialRetryBackoff is the base duration for exponential backoff between retries.
	// The backoff doubles per attempt: attempt 1 waits 100ms, attempt 2 200ms, attempt 3 400ms.
	initialRetryBackoff = 100 * time.Millisecond
)
| 39 | + |
// authCircuitBreaker tracks consecutive auth failures per backend and opens the circuit
// after too many failures to prevent excessive latency from repeated auth retries.
type authCircuitBreaker struct {
	mu         sync.Mutex
	failStreak int  // consecutive recordFailure calls since the last success
	tripped    bool // true once the circuit has opened
}

// canRetry reports whether the circuit is still closed, i.e. auth retries are allowed.
func (cb *authCircuitBreaker) canRetry() bool {
	cb.mu.Lock()
	allowed := !cb.tripped
	cb.mu.Unlock()
	return allowed
}

// recordSuccess clears the failure streak and closes the circuit.
func (cb *authCircuitBreaker) recordSuccess() {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.failStreak = 0
	cb.tripped = false
}

// recordFailure bumps the failure streak and opens the circuit once the streak
// reaches threshold. The warning is logged only on the closed→open transition.
func (cb *authCircuitBreaker) recordFailure(threshold int, backendID string) {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.failStreak++
	if cb.tripped || cb.failStreak < threshold {
		return
	}
	cb.tripped = true
	slog.Warn("auth circuit breaker opened: too many consecutive auth failures, disabling retries",
		"backend", backendID, "consecutive_failures", cb.failStreak)
}
| 74 | + |
// retryingBackendClient wraps a BackendClient and automatically retries operations that
// fail due to authentication errors (401/403). It uses:
//   - Exponential backoff with a maximum of [maxAuthRetries] attempts
//   - A per-backend circuit breaker to stop retrying after [authCircuitBreakerThreshold] consecutive failures
//   - singleflight to deduplicate concurrent backoff waits for the same backend
//   - OpenTelemetry spans to surface auth-retry latency in distributed traces
//
// Raw credentials are never logged.
type retryingBackendClient struct {
	// inner performs the actual backend calls; this wrapper only adds retry policy.
	inner vmcp.BackendClient
	// registry resolves outgoing auth for backends. NOTE(review): not referenced in
	// this file's visible code — presumably used by callers/other methods; confirm.
	registry vmcpauth.OutgoingAuthRegistry

	// sf deduplicates concurrent backoff waits for the same backend at the same attempt number.
	sf singleflight.Group

	// breakers maps backendID -> *authCircuitBreaker. LoadOrStore is used for concurrent safety.
	breakers sync.Map

	// tracer emits "auth.retry" spans; scoped to authRetryInstrumentationName.
	tracer trace.Tracer
	// maxRetries, cbThreshold, and initialBackoff are the retry-policy knobs; the
	// constructor fills them from the package-level constants.
	maxRetries     int
	cbThreshold    int
	initialBackoff time.Duration

	// backoffFn is the sleep function used inside singleflight. nil uses time.After.
	// Tests inject a counted hook to assert coalescing without real wall-clock delays.
	backoffFn func(ctx context.Context, d time.Duration) error
}
| 102 | + |
| 103 | +// newRetryingBackendClient wraps inner with auth-failure retry logic. |
| 104 | +func newRetryingBackendClient(inner vmcp.BackendClient, registry vmcpauth.OutgoingAuthRegistry) *retryingBackendClient { |
| 105 | + return &retryingBackendClient{ |
| 106 | + inner: inner, |
| 107 | + registry: registry, |
| 108 | + tracer: otel.Tracer(authRetryInstrumentationName), |
| 109 | + maxRetries: maxAuthRetries, |
| 110 | + cbThreshold: authCircuitBreakerThreshold, |
| 111 | + initialBackoff: initialRetryBackoff, |
| 112 | + } |
| 113 | +} |
| 114 | + |
| 115 | +// getBreaker returns (or lazily creates) the auth circuit breaker for a backend. |
| 116 | +func (r *retryingBackendClient) getBreaker(backendID string) *authCircuitBreaker { |
| 117 | + v, _ := r.breakers.LoadOrStore(backendID, &authCircuitBreaker{}) |
| 118 | + return v.(*authCircuitBreaker) //nolint:forcetypeassert |
| 119 | +} |
| 120 | + |
| 121 | +// withAuthRetry executes op, and if it returns ErrAuthenticationFailed, retries up to |
| 122 | +// r.maxRetries times with exponential backoff, using singleflight to deduplicate concurrent |
| 123 | +// backoff waits per backend. Auth-retry overhead is surfaced as an OpenTelemetry span. |
| 124 | +func (r *retryingBackendClient) withAuthRetry( |
| 125 | + ctx context.Context, |
| 126 | + backendID string, |
| 127 | + op func(context.Context) error, |
| 128 | +) error { |
| 129 | + breaker := r.getBreaker(backendID) |
| 130 | + |
| 131 | + err := op(ctx) |
| 132 | + if err == nil { |
| 133 | + breaker.recordSuccess() |
| 134 | + return nil |
| 135 | + } |
| 136 | + if !errors.Is(err, vmcp.ErrAuthenticationFailed) { |
| 137 | + return err |
| 138 | + } |
| 139 | + if !breaker.canRetry() { |
| 140 | + slog.Debug("auth circuit breaker open, skipping auth retry", |
| 141 | + "backend", backendID) |
| 142 | + return err |
| 143 | + } |
| 144 | + |
| 145 | + // Start a span to surface auth-retry latency in distributed traces. |
| 146 | + ctx, span := r.tracer.Start(ctx, "auth.retry", |
| 147 | + trace.WithAttributes( |
| 148 | + attribute.String("target.workload_id", backendID), |
| 149 | + attribute.Int("max_retries", r.maxRetries), |
| 150 | + ), |
| 151 | + trace.WithSpanKind(trace.SpanKindInternal), |
| 152 | + ) |
| 153 | + defer span.End() |
| 154 | + |
| 155 | + var lastErr error |
| 156 | + backoff := r.initialBackoff |
| 157 | + for attempt := 1; attempt <= r.maxRetries; attempt++ { |
| 158 | + // Use singleflight to deduplicate concurrent backoff waits for the same backend |
| 159 | + // and attempt number. The first goroutine sleeps; the others coalesce with it. |
| 160 | + // DoChan is used instead of Do so every caller can also select on its own |
| 161 | + // ctx.Done() — otherwise a coalesced caller with a short deadline would be |
| 162 | + // stuck for the full backoff duration of the leader's longer-lived context. |
| 163 | + sfKey := fmt.Sprintf("%s:attempt:%d", backendID, attempt) |
| 164 | + // The singleflight function uses a detached context so that a cancelled |
| 165 | + // leader goroutine does not propagate its error to all coalesced callers. |
| 166 | + // Per-caller cancellation is handled by the outer select on ctx.Done() below. |
| 167 | + detachedCtx := context.WithoutCancel(ctx) |
| 168 | + currentBackoff := backoff |
| 169 | + ch := r.sf.DoChan(sfKey, func() (any, error) { |
| 170 | + if r.backoffFn != nil { |
| 171 | + return nil, r.backoffFn(detachedCtx, currentBackoff) |
| 172 | + } |
| 173 | + select { |
| 174 | + case <-detachedCtx.Done(): |
| 175 | + return nil, detachedCtx.Err() |
| 176 | + case <-time.After(currentBackoff): |
| 177 | + return nil, nil |
| 178 | + } |
| 179 | + }) |
| 180 | + var sfErr error |
| 181 | + select { |
| 182 | + case <-ctx.Done(): |
| 183 | + sfErr = ctx.Err() |
| 184 | + case res := <-ch: |
| 185 | + sfErr = res.Err |
| 186 | + } |
| 187 | + if sfErr != nil { |
| 188 | + span.RecordError(sfErr) |
| 189 | + return sfErr |
| 190 | + } |
| 191 | + |
| 192 | + span.AddEvent("auth.retry.attempt", |
| 193 | + trace.WithAttributes(attribute.Int("attempt", attempt))) |
| 194 | + |
| 195 | + retryErr := op(ctx) |
| 196 | + if retryErr == nil { |
| 197 | + breaker.recordSuccess() |
| 198 | + span.SetStatus(codes.Ok, "auth retry succeeded") |
| 199 | + return nil |
| 200 | + } |
| 201 | + |
| 202 | + lastErr = retryErr |
| 203 | + if !errors.Is(retryErr, vmcp.ErrAuthenticationFailed) { |
| 204 | + // Non-auth error on retry — no point continuing auth retries. |
| 205 | + span.RecordError(retryErr) |
| 206 | + return retryErr |
| 207 | + } |
| 208 | + backoff *= 2 |
| 209 | + } |
| 210 | + |
| 211 | + // All retries exhausted with auth failures — update circuit breaker. |
| 212 | + breaker.recordFailure(r.cbThreshold, backendID) |
| 213 | + span.RecordError(lastErr) |
| 214 | + span.SetStatus(codes.Error, "auth retry exhausted") |
| 215 | + return lastErr |
| 216 | +} |
| 217 | + |
| 218 | +// retryResult is a generic helper that wraps withAuthRetry for operations that return a value, |
| 219 | +// eliminating the boilerplate of capturing a result variable in every BackendClient method. |
| 220 | +func retryResult[T any]( |
| 221 | + ctx context.Context, r *retryingBackendClient, backendID string, op func(context.Context) (T, error), |
| 222 | +) (T, error) { |
| 223 | + var result T |
| 224 | + err := r.withAuthRetry(ctx, backendID, func(ctx context.Context) error { |
| 225 | + var opErr error |
| 226 | + result, opErr = op(ctx) |
| 227 | + return opErr |
| 228 | + }) |
| 229 | + return result, err |
| 230 | +} |
| 231 | + |
| 232 | +// CallTool implements vmcp.BackendClient. |
| 233 | +func (r *retryingBackendClient) CallTool( |
| 234 | + ctx context.Context, |
| 235 | + target *vmcp.BackendTarget, |
| 236 | + toolName string, |
| 237 | + arguments map[string]any, |
| 238 | + meta map[string]any, |
| 239 | +) (*vmcp.ToolCallResult, error) { |
| 240 | + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.ToolCallResult, error) { |
| 241 | + return r.inner.CallTool(ctx, target, toolName, arguments, meta) |
| 242 | + }) |
| 243 | +} |
| 244 | + |
| 245 | +// ReadResource implements vmcp.BackendClient. |
| 246 | +func (r *retryingBackendClient) ReadResource( |
| 247 | + ctx context.Context, |
| 248 | + target *vmcp.BackendTarget, |
| 249 | + uri string, |
| 250 | +) (*vmcp.ResourceReadResult, error) { |
| 251 | + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.ResourceReadResult, error) { |
| 252 | + return r.inner.ReadResource(ctx, target, uri) |
| 253 | + }) |
| 254 | +} |
| 255 | + |
| 256 | +// GetPrompt implements vmcp.BackendClient. |
| 257 | +func (r *retryingBackendClient) GetPrompt( |
| 258 | + ctx context.Context, |
| 259 | + target *vmcp.BackendTarget, |
| 260 | + name string, |
| 261 | + arguments map[string]any, |
| 262 | +) (*vmcp.PromptGetResult, error) { |
| 263 | + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.PromptGetResult, error) { |
| 264 | + return r.inner.GetPrompt(ctx, target, name, arguments) |
| 265 | + }) |
| 266 | +} |
| 267 | + |
| 268 | +// ListCapabilities implements vmcp.BackendClient. |
| 269 | +func (r *retryingBackendClient) ListCapabilities( |
| 270 | + ctx context.Context, |
| 271 | + target *vmcp.BackendTarget, |
| 272 | +) (*vmcp.CapabilityList, error) { |
| 273 | + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.CapabilityList, error) { |
| 274 | + return r.inner.ListCapabilities(ctx, target) |
| 275 | + }) |
| 276 | +} |