Skip to content

Commit 2471b78

Browse files
committed
feat: derive token cache expiry from JWT claim
Currently the code will cache a JWT for 5 minutes. This creates a timing bug and forces unnecessary caching. Since tokens are cached for 5 minutes, it is possible that a token with 1 second of expiration left is cached. The cache code pulls the token out of the cache without any validation, the expired token will be used for the next 4 minutes and 59 seconds. This work sets the cache expiration time based on the token expiration time. This results in each token being cached once and evicted from the cache at expiration. In cases where the JWT needs to be valid beyond this check (e.g., for downstream requests), users can optionally choose to expire the cache before the JWT expiration time with `TokenExpiryBuffer` option. This option also ensures JWTs that are within the expiration window are treated as expired.
1 parent 0b83f5c commit 2471b78

9 files changed

Lines changed: 142 additions & 63 deletions

File tree

api_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net/http"
77
"net/http/httptest"
8+
"reflect"
89
"strings"
910
"testing"
1011
"time"
@@ -708,3 +709,75 @@ func TestWrapMCPEndpointWithValidToken(t *testing.T) {
708709
t.Errorf("status = %d, want 200", rec.Code)
709710
}
710711
}
712+
713+
// TestTokenExpiryBuffer tests that the cache TTL is derived from the JWT exp claim.
714+
func TestTokenExpiryBuffer(t *testing.T) {
715+
tests := []struct {
716+
name string
717+
expirationFromNow time.Duration
718+
tokenExpiryBuffer time.Duration
719+
want *User
720+
wantErr string
721+
}{
722+
{
723+
name: "success",
724+
expirationFromNow: time.Minute,
725+
want: &User{Subject: "testuser"},
726+
},
727+
{
728+
name: "expired token",
729+
expirationFromNow: -time.Minute,
730+
wantErr: "authentication failed: failed to parse and validate token: token has invalid claims: token is expired",
731+
},
732+
{
733+
name: "token expires in buffer",
734+
tokenExpiryBuffer: 5 * time.Minute,
735+
expirationFromNow: time.Minute,
736+
wantErr: "authentication failed: token expired or expiring too soon",
737+
},
738+
}
739+
740+
for _, tt := range tests {
741+
t.Run(tt.name, func(t *testing.T) {
742+
cfg := &Config{
743+
Mode: "native",
744+
Provider: "hmac",
745+
Audience: "api://test",
746+
JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
747+
ServerURL: "https://test-server.com",
748+
Issuer: "https://test.example.com",
749+
TokenExpiryBuffer: tt.tokenExpiryBuffer,
750+
}
751+
srv, err := NewServer(cfg)
752+
if err != nil {
753+
t.Fatalf("NewServer: %v", err)
754+
}
755+
756+
// Create a valid HMAC token
757+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
758+
"sub": "testuser",
759+
"aud": "api://test",
760+
"iss": "https://test.example.com",
761+
"exp": time.Now().Add(tt.expirationFromNow).Unix(),
762+
})
763+
tokenString, _ := token.SignedString(cfg.JWTSecret)
764+
765+
user, err := srv.ValidateTokenCached(context.Background(), tokenString)
766+
if tt.wantErr != "" {
767+
if err == nil {
768+
t.Fatalf("expected error containing %q, got nil", tt.wantErr)
769+
}
770+
if !strings.Contains(err.Error(), tt.wantErr) {
771+
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
772+
}
773+
return
774+
}
775+
if err != nil {
776+
t.Fatalf("unexpected error: %s", err)
777+
}
778+
if !reflect.DeepEqual(user, tt.want) {
779+
t.Errorf("expected user %v got user %v", tt.want, user)
780+
}
781+
})
782+
}
783+
}

config.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"strconv"
66
"strings"
7+
"time"
78

89
"github.com/tuannvm/oauth-mcp-proxy/provider"
910
)
@@ -41,6 +42,11 @@ type Config struct {
4142
// The issuer URL to use for issuer validation.
4243
// This should only be set if the issuer in the token differs from the standard issuer URL.
4344
ValidatorIssuer string
45+
46+
// TokenExpiryBuffer is subtracted from the JWT's exp claim to determine
47+
// effective cache expiry. Tokens with less than this duration remaining
48+
// are treated as expired and rejected. Defaults to 0 (no buffer).
49+
TokenExpiryBuffer time.Duration
4450
}
4551

4652
// Validate validates the configuration
@@ -265,6 +271,13 @@ func (b *ConfigBuilder) WithValidatorIssuer(validatorIssuer string) *ConfigBuild
265271
return b
266272
}
267273

274+
// WithTokenExpiryBuffer sets the buffer subtracted from JWT exp for cache expiry.
275+
// Tokens with less than d remaining until expiry are treated as expired and rejected.
276+
func (b *ConfigBuilder) WithTokenExpiryBuffer(d time.Duration) *ConfigBuilder {
277+
b.config.TokenExpiryBuffer = d
278+
return b
279+
}
280+
268281
// WithServerURL sets the full server URL directly
269282
func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder {
270283
b.config.ServerURL = url

context_propagation_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func TestContextPropagation(t *testing.T) {
4141

4242
// Test 1: Normal context works
4343
ctx := context.Background()
44-
user, err := server.validator.ValidateToken(ctx, tokenString)
44+
user, _, err := server.validator.ValidateToken(ctx, tokenString)
4545
if err != nil {
4646
t.Fatalf("ValidateToken with normal context failed: %v", err)
4747
}
@@ -83,7 +83,7 @@ func TestContextPropagation(t *testing.T) {
8383

8484
// For HMAC validator (local-only), this still succeeds
8585
// because HMAC doesn't do I/O and doesn't check context cancellation
86-
user, err := server.validator.ValidateToken(ctx, tokenString)
86+
user, _, err := server.validator.ValidateToken(ctx, tokenString)
8787

8888
// HMAC validation is local-only, so it succeeds even with cancelled context
8989
if err != nil {
@@ -126,7 +126,7 @@ func TestContextPropagation(t *testing.T) {
126126
defer cancel()
127127

128128
// Validate with timeout context
129-
user, err := server.validator.ValidateToken(ctx, tokenString)
129+
user, _, err := server.validator.ValidateToken(ctx, tokenString)
130130
if err != nil {
131131
t.Fatalf("ValidateToken with timeout context failed: %v", err)
132132
}

docs/SECURITY.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ oauth.WithOAuth(mux, &oauth.Config{
200200

201201
### Cache Behavior
202202

203-
- **Cache TTL:** 5 minutes (hardcoded in v0.1.0)
203+
- **Cache TTL:** No longer than the expiration on the JWT, expiry buffer
204+
configurable
204205
- **Cache scope:** Per Server instance
205206
- **Cache key:** SHA-256 hash of token
206207

integration_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestIntegration(t *testing.T) {
4747
tokenString, _ := token.SignedString(cfg.JWTSecret)
4848

4949
// Validate token using provider package directly
50-
user, err := validator.ValidateToken(context.Background(), tokenString)
50+
user, _, err := validator.ValidateToken(context.Background(), tokenString)
5151
if err != nil {
5252
t.Fatalf("ValidateToken failed: %v", err)
5353
}
@@ -90,7 +90,7 @@ func TestIntegration(t *testing.T) {
9090

9191
tokenString, _ := token.SignedString(rootCfg.JWTSecret)
9292

93-
user, err := validator.ValidateToken(context.Background(), tokenString)
93+
user, _, err := validator.ValidateToken(context.Background(), tokenString)
9494
if err != nil {
9595
t.Fatalf("ValidateToken after conversion failed: %v", err)
9696
}
@@ -276,7 +276,7 @@ func TestValidatorIntegration(t *testing.T) {
276276

277277
tokenString, _ := token.SignedString(cfg.JWTSecret)
278278

279-
user, err := v.ValidateToken(context.Background(), tokenString)
279+
user, _, err := v.ValidateToken(context.Background(), tokenString)
280280
if err != nil {
281281
t.Fatalf("ValidateToken failed: %v", err)
282282
}

middleware.go

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ package oauth
22

33
import (
44
"context"
5-
"crypto/sha256"
65
"fmt"
76
"log"
87
"net/http"
98
"strings"
10-
"time"
119

1210
"github.com/mark3labs/mcp-go/mcp"
1311
"github.com/mark3labs/mcp-go/server"
@@ -38,36 +36,13 @@ func (s *Server) Middleware() func(server.ToolHandlerFunc) server.ToolHandlerFun
3836
return nil, fmt.Errorf("authentication required: missing OAuth token")
3937
}
4038

41-
// Create token hash for caching
42-
tokenHash := fmt.Sprintf("%x", sha256.Sum256([]byte(tokenString)))
43-
44-
// Check cache first
45-
if cached, exists := s.cache.getCachedToken(tokenHash); exists {
46-
s.logger.Info("Using cached authentication for tool: %s (user: %s)", req.Params.Name, cached.User.Username)
47-
ctx = context.WithValue(ctx, userContextKey, cached.User)
48-
return next(ctx, req)
49-
}
50-
51-
// Log token hash for debugging (prevents sensitive data exposure)
52-
tokenHashFull := fmt.Sprintf("%x", sha256.Sum256([]byte(tokenString)))
53-
tokenHashPreview := tokenHashFull[:16] + "..."
54-
s.logger.Info("Validating token for tool %s (hash: %s)", req.Params.Name, tokenHashPreview)
55-
56-
// Validate token using configured provider (with request context for timeout/cancellation)
57-
user, err := s.validator.ValidateToken(ctx, tokenString)
39+
user, err := s.ValidateTokenCached(ctx, tokenString)
5840
if err != nil {
5941
s.logger.Error("Token validation failed for tool %s: %v", req.Params.Name, err)
6042
return nil, fmt.Errorf("authentication failed: %w", err)
6143
}
62-
63-
// Cache the validation result (expire in 5 minutes)
64-
expiresAt := time.Now().Add(5 * time.Minute)
65-
s.cache.setCachedToken(tokenHash, user, expiresAt)
66-
67-
// Add user to context for downstream handlers
6844
ctx = context.WithValue(ctx, userContextKey, user)
69-
s.logger.Info("Authenticated user %s for tool: %s (cached for 5 minutes)", user.Username, req.Params.Name)
70-
45+
s.logger.Info("Authenticated user %s for tool: %s", user.Username, req.Params.Name)
7146
return next(ctx, req)
7247
}
7348
}

oauth.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ func (s *Server) RegisterHandlers(mux *http.ServeMux) {
104104
// This is the core validation method that SDK adapters can use.
105105
//
106106
// The method:
107-
// 1. Checks token cache (5-minute TTL)
107+
// 1. Checks token cache (keyed by SHA-256 hash)
108108
// 2. Validates token using configured provider if not cached
109-
// 3. Caches validation result for future requests
109+
// 3. Caches validation result until JWT exp (minus TokenExpiryBuffer)
110110
// 4. Returns authenticated User or error
111111
//
112112
// This method is used internally by both WrapHandler and adapter middleware.
@@ -120,16 +120,22 @@ func (s *Server) ValidateTokenCached(ctx context.Context, token string) (*User,
120120

121121
s.logger.Info("Validating token (hash: %s...)", tokenHash[:16])
122122

123-
user, err := s.validator.ValidateToken(ctx, token)
123+
user, expiry, err := s.validator.ValidateToken(ctx, token)
124124
if err != nil {
125125
s.logger.Error("Token validation failed: %v", err)
126126
return nil, fmt.Errorf("authentication failed: %w", err)
127127
}
128128

129-
expiresAt := time.Now().Add(5 * time.Minute)
130-
s.cache.setCachedToken(tokenHash, user, expiresAt)
129+
if s.config != nil {
130+
expiry = expiry.Add(-s.config.TokenExpiryBuffer)
131+
}
132+
if time.Now().After(expiry) {
133+
return nil, fmt.Errorf("authentication failed: token expired or expiring too soon")
134+
}
135+
136+
s.cache.setCachedToken(tokenHash, user, expiry)
131137

132-
s.logger.Info("Authenticated user %s (cached for 5 minutes)", user.Username)
138+
s.logger.Info("Authenticated user %s (cache expires at %s)", user.Username, expiry.Format(time.RFC3339))
133139
return user, nil
134140
}
135141

provider/provider.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type Config struct {
4141

4242
// TokenValidator interface for OAuth token validation
4343
type TokenValidator interface {
44-
ValidateToken(ctx context.Context, token string) (*User, error)
44+
ValidateToken(ctx context.Context, token string) (*User, time.Time, error)
4545
Initialize(cfg *Config) error
4646
}
4747

@@ -80,7 +80,7 @@ func (v *HMACValidator) Initialize(cfg *Config) error {
8080
}
8181

8282
// ValidateToken validates JWT token using HMAC-SHA256
83-
func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) {
83+
func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (*User, time.Time, error) {
8484
// Note: ctx parameter accepted for interface compliance, but HMAC validation is local-only (no I/O)
8585
// Remove Bearer prefix if present
8686
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
@@ -94,26 +94,26 @@ func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (
9494
return []byte(v.secret), nil
9595
})
9696
if err != nil {
97-
return nil, fmt.Errorf("failed to parse and validate token: %w", err)
97+
return nil, time.Time{}, fmt.Errorf("failed to parse and validate token: %w", err)
9898
}
9999

100100
if !token.Valid {
101-
return nil, fmt.Errorf("invalid token")
101+
return nil, time.Time{}, fmt.Errorf("invalid token")
102102
}
103103

104104
claims, ok := token.Claims.(jwt.MapClaims)
105105
if !ok {
106-
return nil, fmt.Errorf("invalid token claims")
106+
return nil, time.Time{}, fmt.Errorf("invalid token claims")
107107
}
108108

109109
// Validate required claims including audience
110110
if err := validateTokenClaims(claims); err != nil {
111-
return nil, fmt.Errorf("token validation failed: %w", err)
111+
return nil, time.Time{}, fmt.Errorf("token validation failed: %w", err)
112112
}
113113

114114
// Validate audience claim for security
115115
if err := v.validateAudience(claims); err != nil {
116-
return nil, fmt.Errorf("audience validation failed: %w", err)
116+
return nil, time.Time{}, fmt.Errorf("audience validation failed: %w", err)
117117
}
118118

119119
// Extract user information
@@ -124,10 +124,17 @@ func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (
124124
}
125125

126126
if user.Subject == "" {
127-
return nil, fmt.Errorf("missing subject in token")
127+
return nil, time.Time{}, fmt.Errorf("missing subject in token")
128128
}
129129

130-
return user, nil
130+
// Extract expiry from already-parsed claims defaulting to 5 minutes in the
131+
// future.
132+
expiry := time.Now().Add(5 * time.Minute)
133+
if expVal, ok := claims["exp"].(float64); ok {
134+
expiry = time.Unix(int64(expVal), 0)
135+
}
136+
137+
return user, expiry, nil
131138
}
132139

133140
// validateAudience validates the audience claim matches the expected value
@@ -223,7 +230,7 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
223230
}
224231

225232
// ValidateToken validates JWT token using OIDC/JWKS
226-
func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) {
233+
func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (*User, time.Time, error) {
227234
// Remove Bearer prefix if present
228235
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
229236

@@ -234,7 +241,7 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
234241
// go-oidc handles RSA signature validation, JWKS fetching, and key rotation
235242
idToken, err := v.verifier.Verify(ctx, tokenString)
236243
if err != nil {
237-
return nil, fmt.Errorf("token verification failed: %w", err)
244+
return nil, time.Time{}, fmt.Errorf("token verification failed: %w", err)
238245
}
239246

240247
// Extract claims from verified token
@@ -253,28 +260,28 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
253260
}
254261

255262
if err := idToken.Claims(&claims); err != nil {
256-
return nil, fmt.Errorf("failed to extract claims: %w", err)
263+
return nil, time.Time{}, fmt.Errorf("failed to extract claims: %w", err)
257264
}
258265

259266
// Extract raw claims for audience validation
260267
var rawClaims jwt.MapClaims
261268
if err := idToken.Claims(&rawClaims); err != nil {
262-
return nil, fmt.Errorf("failed to extract raw claims: %w", err)
269+
return nil, time.Time{}, fmt.Errorf("failed to extract raw claims: %w", err)
263270
}
264271

265272
// Run extra validation functions
266273
for i, fn := range v.TokenValidators {
267274
err := fn(rawClaims)
268275
if err != nil {
269-
return nil, fmt.Errorf("validation function %d failed with error: %w", i, err)
276+
return nil, time.Time{}, fmt.Errorf("validation function %d failed with error: %w", i, err)
270277
}
271278
}
272279

273280
return &User{
274281
Subject: claims.Subject,
275282
Username: claims.PreferredUsername,
276283
Email: claims.Email,
277-
}, nil
284+
}, idToken.Expiry, nil
278285
}
279286

280287
// validateAudience validates the audience claim matches the expected value for OIDC tokens

0 commit comments

Comments
 (0)