Skip to content

Commit 7ee83f1

Browse files
committed
feat: derive token cache expiry from JWT claim
Currently the code will cache a JWT for 5 minutes. This creates a timing bug and forces unnecessary caching. Since tokens are cached for 5 minutes, it is possible that a token with 1 second of expiration left is cached. The cache code pulls the token out of the cache without any validation, the expired token will be used for the next 4 minutes and 59 seconds. This work sets the cache expiration time based on the token expiration time. This results in each token being cached once and evicted from the cache at expiration. In cases where the JWT needs to be valid beyond this check (e.g., for downstream requests), users can optionally choose to expire the cache before the JWT expiration time with `TokenExpiryBuffer` option. This option also ensures JWTs that are within the expiration window are treated as expired. Signed-off-by: Tyler Drombosky <bowsky@gmail.com>
1 parent 0b83f5c commit 7ee83f1

9 files changed

Lines changed: 230 additions & 64 deletions

File tree

api_test.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net/http"
77
"net/http/httptest"
8+
"reflect"
89
"strings"
910
"testing"
1011
"time"
@@ -708,3 +709,156 @@ func TestWrapMCPEndpointWithValidToken(t *testing.T) {
708709
t.Errorf("status = %d, want 200", rec.Code)
709710
}
710711
}
712+
713+
// TestValidateTokenCached tests that only non-expired tokens are cached.
714+
func TestValidateTokenCached(t *testing.T) {
715+
t.Parallel()
716+
717+
tests := []struct {
718+
name string
719+
expirationFromNow time.Duration
720+
tokenExpiryBuffer time.Duration
721+
want *User
722+
wantErr string
723+
}{
724+
{
725+
name: "success",
726+
expirationFromNow: time.Minute,
727+
want: &User{Subject: "testuser"},
728+
},
729+
{
730+
name: "expired token",
731+
expirationFromNow: -time.Minute,
732+
wantErr: "authentication failed: failed to parse and validate token: token has invalid claims: token is expired",
733+
},
734+
{
735+
name: "token expires in buffer",
736+
tokenExpiryBuffer: 5 * time.Minute,
737+
expirationFromNow: time.Minute,
738+
wantErr: "authentication failed: token expired or expiring too soon",
739+
},
740+
}
741+
742+
for _, tt := range tests {
743+
t.Run(tt.name, func(t *testing.T) {
744+
t.Parallel()
745+
746+
cfg := &Config{
747+
Mode: "native",
748+
Provider: "hmac",
749+
Audience: "api://test",
750+
JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
751+
ServerURL: "https://test-server.com",
752+
Issuer: "https://test.example.com",
753+
TokenExpiryBuffer: tt.tokenExpiryBuffer,
754+
}
755+
srv, err := NewServer(cfg)
756+
if err != nil {
757+
t.Fatalf("NewServer: %v", err)
758+
}
759+
760+
// Create a valid HMAC token
761+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
762+
"sub": "testuser",
763+
"aud": "api://test",
764+
"iss": "https://test.example.com",
765+
"exp": time.Now().Add(tt.expirationFromNow).Unix(),
766+
})
767+
tokenString, err := token.SignedString(cfg.JWTSecret)
768+
if err != nil {
769+
t.Fatalf("SignedString: %v", err)
770+
}
771+
772+
user, err := srv.ValidateTokenCached(context.Background(), tokenString)
773+
if tt.wantErr != "" {
774+
if err == nil {
775+
t.Fatalf("expected error %q, got nil", tt.wantErr)
776+
}
777+
if err.Error() != tt.wantErr {
778+
t.Errorf("expected error %q, got %q", tt.wantErr, err.Error())
779+
}
780+
return
781+
}
782+
if err != nil {
783+
t.Fatalf("unexpected error: %s", err)
784+
}
785+
if !reflect.DeepEqual(user, tt.want) {
786+
t.Errorf("expected user %v got user %v", tt.want, user)
787+
}
788+
})
789+
}
790+
}
791+
792+
// TestValidateTokenCached tests that the cache expires correctly.
793+
func TestValidateTokenCached_Expires(t *testing.T) {
794+
t.Parallel()
795+
796+
tests := []struct {
797+
name string
798+
expirationFromNow time.Duration
799+
tokenExpiryBuffer time.Duration
800+
wantErr string
801+
}{
802+
{
803+
name: "default expiry buffer",
804+
expirationFromNow: time.Second,
805+
wantErr: "authentication failed: failed to parse and validate token: token has invalid claims: token is expired",
806+
},
807+
{
808+
name: "custom expiry buffer",
809+
expirationFromNow: 5 * time.Second,
810+
tokenExpiryBuffer: 4 * time.Second,
811+
wantErr: "authentication failed: token expired or expiring too soon",
812+
},
813+
}
814+
815+
for _, tt := range tests {
816+
t.Run(tt.name, func(t *testing.T) {
817+
t.Parallel()
818+
819+
cfg := &Config{
820+
Mode: "native",
821+
Provider: "hmac",
822+
Audience: "api://test",
823+
JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
824+
ServerURL: "https://test-server.com",
825+
Issuer: "https://test.example.com",
826+
TokenExpiryBuffer: tt.tokenExpiryBuffer,
827+
}
828+
srv, err := NewServer(cfg)
829+
if err != nil {
830+
t.Fatalf("NewServer: %v", err)
831+
}
832+
833+
// Create a valid HMAC token
834+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
835+
"sub": "testuser",
836+
"aud": "api://test",
837+
"iss": "https://test.example.com",
838+
"exp": time.Now().Add(tt.expirationFromNow).Unix(),
839+
})
840+
tokenString, err := token.SignedString(cfg.JWTSecret)
841+
if err != nil {
842+
t.Fatalf("SignedString: %v", err)
843+
}
844+
845+
// Token is successfully verified and cached
846+
_, err = srv.ValidateTokenCached(context.Background(), tokenString)
847+
if err != nil {
848+
t.Fatalf("unexpected error %s", err)
849+
}
850+
851+
// Wait twice as long as the token should take to expire.
852+
time.Sleep(2 * (tt.expirationFromNow - tt.tokenExpiryBuffer))
853+
854+
_, err = srv.ValidateTokenCached(context.Background(), tokenString)
855+
if err == nil {
856+
t.Fatal("expected error, got nil")
857+
}
858+
859+
if err.Error() != tt.wantErr {
860+
t.Errorf("expected error %q, got %q", tt.wantErr, err.Error())
861+
}
862+
})
863+
}
864+
}

config.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"strconv"
66
"strings"
7+
"time"
78

89
"github.com/tuannvm/oauth-mcp-proxy/provider"
910
)
@@ -41,6 +42,11 @@ type Config struct {
4142
// The issuer URL to use for issuer validation.
4243
// This should only be set if the issuer in the token differs from the standard issuer URL.
4344
ValidatorIssuer string
45+
46+
// TokenExpiryBuffer is subtracted from the JWT's exp claim to determine
47+
// effective cache expiry. Tokens with less than this duration remaining
48+
// are treated as expired and rejected. Defaults to 0 (no buffer).
49+
TokenExpiryBuffer time.Duration
4450
}
4551

4652
// Validate validates the configuration
@@ -95,6 +101,10 @@ func (c *Config) Validate() error {
95101
return fmt.Errorf("proxy mode requires RedirectURIs or FixedRedirectURI")
96102
}
97103
}
104+
// Validate TokenExpiryBuffer is positive.
105+
if c.TokenExpiryBuffer < 0 {
106+
return fmt.Errorf("TokenExpiryBuffer must be >= 0")
107+
}
98108

99109
return nil
100110
}
@@ -265,6 +275,13 @@ func (b *ConfigBuilder) WithValidatorIssuer(validatorIssuer string) *ConfigBuild
265275
return b
266276
}
267277

278+
// WithTokenExpiryBuffer sets the buffer subtracted from JWT exp for cache expiry.
279+
// Tokens with less than d remaining until expiry are treated as expired and rejected.
280+
func (b *ConfigBuilder) WithTokenExpiryBuffer(d time.Duration) *ConfigBuilder {
281+
b.config.TokenExpiryBuffer = d
282+
return b
283+
}
284+
268285
// WithServerURL sets the full server URL directly
269286
func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder {
270287
b.config.ServerURL = url

context_propagation_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func TestContextPropagation(t *testing.T) {
4141

4242
// Test 1: Normal context works
4343
ctx := context.Background()
44-
user, err := server.validator.ValidateToken(ctx, tokenString)
44+
user, _, err := server.validator.ValidateToken(ctx, tokenString)
4545
if err != nil {
4646
t.Fatalf("ValidateToken with normal context failed: %v", err)
4747
}
@@ -83,7 +83,7 @@ func TestContextPropagation(t *testing.T) {
8383

8484
// For HMAC validator (local-only), this still succeeds
8585
// because HMAC doesn't do I/O and doesn't check context cancellation
86-
user, err := server.validator.ValidateToken(ctx, tokenString)
86+
user, _, err := server.validator.ValidateToken(ctx, tokenString)
8787

8888
// HMAC validation is local-only, so it succeeds even with cancelled context
8989
if err != nil {
@@ -126,7 +126,7 @@ func TestContextPropagation(t *testing.T) {
126126
defer cancel()
127127

128128
// Validate with timeout context
129-
user, err := server.validator.ValidateToken(ctx, tokenString)
129+
user, _, err := server.validator.ValidateToken(ctx, tokenString)
130130
if err != nil {
131131
t.Fatalf("ValidateToken with timeout context failed: %v", err)
132132
}

docs/SECURITY.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ oauth.WithOAuth(mux, &oauth.Config{
200200

201201
### Cache Behavior
202202

203-
- **Cache TTL:** 5 minutes (hardcoded in v0.1.0)
203+
- **Cache TTL:** Until the JWT `exp` claim, minus the configured expiry buffer.
204+
Tokens inside the buffer are rejected, and tokens without `exp` fall back
205+
to a 5-minute cache lifetime.
204206
- **Cache scope:** Per Server instance
205207
- **Cache key:** SHA-256 hash of token
206208

integration_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestIntegration(t *testing.T) {
4747
tokenString, _ := token.SignedString(cfg.JWTSecret)
4848

4949
// Validate token using provider package directly
50-
user, err := validator.ValidateToken(context.Background(), tokenString)
50+
user, _, err := validator.ValidateToken(context.Background(), tokenString)
5151
if err != nil {
5252
t.Fatalf("ValidateToken failed: %v", err)
5353
}
@@ -90,7 +90,7 @@ func TestIntegration(t *testing.T) {
9090

9191
tokenString, _ := token.SignedString(rootCfg.JWTSecret)
9292

93-
user, err := validator.ValidateToken(context.Background(), tokenString)
93+
user, _, err := validator.ValidateToken(context.Background(), tokenString)
9494
if err != nil {
9595
t.Fatalf("ValidateToken after conversion failed: %v", err)
9696
}
@@ -276,7 +276,7 @@ func TestValidatorIntegration(t *testing.T) {
276276

277277
tokenString, _ := token.SignedString(cfg.JWTSecret)
278278

279-
user, err := v.ValidateToken(context.Background(), tokenString)
279+
user, _, err := v.ValidateToken(context.Background(), tokenString)
280280
if err != nil {
281281
t.Fatalf("ValidateToken failed: %v", err)
282282
}

middleware.go

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ package oauth
22

33
import (
44
"context"
5-
"crypto/sha256"
65
"fmt"
76
"log"
87
"net/http"
98
"strings"
10-
"time"
119

1210
"github.com/mark3labs/mcp-go/mcp"
1311
"github.com/mark3labs/mcp-go/server"
@@ -38,36 +36,13 @@ func (s *Server) Middleware() func(server.ToolHandlerFunc) server.ToolHandlerFun
3836
return nil, fmt.Errorf("authentication required: missing OAuth token")
3937
}
4038

41-
// Create token hash for caching
42-
tokenHash := fmt.Sprintf("%x", sha256.Sum256([]byte(tokenString)))
43-
44-
// Check cache first
45-
if cached, exists := s.cache.getCachedToken(tokenHash); exists {
46-
s.logger.Info("Using cached authentication for tool: %s (user: %s)", req.Params.Name, cached.User.Username)
47-
ctx = context.WithValue(ctx, userContextKey, cached.User)
48-
return next(ctx, req)
49-
}
50-
51-
// Log token hash for debugging (prevents sensitive data exposure)
52-
tokenHashFull := fmt.Sprintf("%x", sha256.Sum256([]byte(tokenString)))
53-
tokenHashPreview := tokenHashFull[:16] + "..."
54-
s.logger.Info("Validating token for tool %s (hash: %s)", req.Params.Name, tokenHashPreview)
55-
56-
// Validate token using configured provider (with request context for timeout/cancellation)
57-
user, err := s.validator.ValidateToken(ctx, tokenString)
39+
user, err := s.ValidateTokenCached(ctx, tokenString)
5840
if err != nil {
5941
s.logger.Error("Token validation failed for tool %s: %v", req.Params.Name, err)
60-
return nil, fmt.Errorf("authentication failed: %w", err)
42+
return nil, err
6143
}
62-
63-
// Cache the validation result (expire in 5 minutes)
64-
expiresAt := time.Now().Add(5 * time.Minute)
65-
s.cache.setCachedToken(tokenHash, user, expiresAt)
66-
67-
// Add user to context for downstream handlers
6844
ctx = context.WithValue(ctx, userContextKey, user)
69-
s.logger.Info("Authenticated user %s for tool: %s (cached for 5 minutes)", user.Username, req.Params.Name)
70-
45+
s.logger.Info("Authenticated user %s for tool: %s", user.Username, req.Params.Name)
7146
return next(ctx, req)
7247
}
7348
}

oauth.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ func (s *Server) RegisterHandlers(mux *http.ServeMux) {
104104
// This is the core validation method that SDK adapters can use.
105105
//
106106
// The method:
107-
// 1. Checks token cache (5-minute TTL)
107+
// 1. Checks token cache (keyed by SHA-256 hash)
108108
// 2. Validates token using configured provider if not cached
109-
// 3. Caches validation result for future requests
109+
// 3. Caches validation result until JWT exp (minus TokenExpiryBuffer)
110110
// 4. Returns authenticated User or error
111111
//
112112
// This method is used internally by both WrapHandler and adapter middleware.
@@ -120,16 +120,23 @@ func (s *Server) ValidateTokenCached(ctx context.Context, token string) (*User,
120120

121121
s.logger.Info("Validating token (hash: %s...)", tokenHash[:16])
122122

123-
user, err := s.validator.ValidateToken(ctx, token)
123+
user, expiry, err := s.validator.ValidateToken(ctx, token)
124124
if err != nil {
125125
s.logger.Error("Token validation failed: %v", err)
126126
return nil, fmt.Errorf("authentication failed: %w", err)
127127
}
128128

129-
expiresAt := time.Now().Add(5 * time.Minute)
130-
s.cache.setCachedToken(tokenHash, user, expiresAt)
129+
if s.config != nil {
130+
expiry = expiry.Add(-s.config.TokenExpiryBuffer)
131+
}
132+
if time.Now().After(expiry) {
133+
s.logger.Info("Token rejected: expires within buffer (hash: %s...)", tokenHash[:16])
134+
return nil, fmt.Errorf("authentication failed: token expired or expiring too soon")
135+
}
136+
137+
s.cache.setCachedToken(tokenHash, user, expiry)
131138

132-
s.logger.Info("Authenticated user %s (cached for 5 minutes)", user.Username)
139+
s.logger.Info("Authenticated user %s (cache expires at %s)", user.Username, expiry.Format(time.RFC3339))
133140
return user, nil
134141
}
135142

0 commit comments

Comments
 (0)