Skip to content

Commit db91add

Browse files
committed
feat: add ConsentCodeStore and OIDCStateStore.PeekInfo
Add ConsentCodeStore to api-gateway for managing one-time consent codes with TTL, capacity limits, concurrent-safe consumption, and background eviction. Extend OIDCFlowState with RequestedScopes field and add non-consuming PeekInfo method to OIDCStateStore for the consent-info endpoint.
1 parent bc544fa commit db91add

4 files changed

Lines changed: 371 additions & 0 deletions

File tree

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package gateway
2+
3+
import (
4+
"crypto/rand"
5+
"encoding/base64"
6+
"errors"
7+
"fmt"
8+
"sync"
9+
"time"
10+
)
11+
12+
const (
13+
// consentCodeTTL is how long a consent code remains valid.
14+
consentCodeTTL = 2 * time.Minute
15+
// consentCodeEvictInterval is how often the store sweeps expired entries.
16+
consentCodeEvictInterval = 1 * time.Minute
17+
// consentCodeMaxEntries caps entries to prevent memory exhaustion.
18+
consentCodeMaxEntries = 10_000
19+
// consentCodeBytes is the number of random bytes in a generated consent code.
20+
consentCodeBytes = 32
21+
)
22+
23+
var errConsentCodeStoreFull = errors.New("consent code store is full")
24+
25+
// ConsentCodeEntry holds the state stored alongside a consent code.
26+
type ConsentCodeEntry struct {
27+
Email string
28+
TenantID string // UUID from JWT x-tenant-id claim
29+
TenantSlug string // subdomain slug for cross-validation
30+
MCPState string // key into OIDCStateStore
31+
ClientID string // OAuth client_id
32+
ApprovedScopes []string // e.g., ["mcp:default"]
33+
CreatedAt time.Time
34+
}
35+
36+
// ConsentCodeStore is a thread-safe in-memory store for consent codes.
37+
// Each code can be consumed exactly once and expires after consentCodeTTL.
38+
type ConsentCodeStore struct {
39+
mu sync.Mutex
40+
entries map[string]ConsentCodeEntry
41+
stop chan struct{}
42+
closeOnce sync.Once
43+
}
44+
45+
// NewConsentCodeStore creates an empty ConsentCodeStore and starts the
46+
// background eviction goroutine. Call [ConsentCodeStore.Close] to stop it.
47+
func NewConsentCodeStore() *ConsentCodeStore {
48+
s := &ConsentCodeStore{
49+
entries: make(map[string]ConsentCodeEntry),
50+
stop: make(chan struct{}),
51+
}
52+
go s.evictLoop()
53+
return s
54+
}
55+
56+
// Close stops the background eviction goroutine. Safe to call multiple times.
57+
func (s *ConsentCodeStore) Close() {
58+
s.closeOnce.Do(func() { close(s.stop) })
59+
}
60+
61+
func (s *ConsentCodeStore) evictLoop() {
62+
ticker := time.NewTicker(consentCodeEvictInterval)
63+
defer ticker.Stop()
64+
for {
65+
select {
66+
case <-ticker.C:
67+
s.evictExpired()
68+
case <-s.stop:
69+
return
70+
}
71+
}
72+
}
73+
74+
func (s *ConsentCodeStore) evictExpired() {
75+
s.mu.Lock()
76+
defer s.mu.Unlock()
77+
for code, entry := range s.entries {
78+
if time.Since(entry.CreatedAt) > consentCodeTTL {
79+
delete(s.entries, code)
80+
}
81+
}
82+
}
83+
84+
// Store saves a consent code entry and returns the generated code.
85+
// Returns errConsentCodeStoreFull if the store has reached its capacity limit.
86+
func (s *ConsentCodeStore) Store(entry ConsentCodeEntry) (string, error) {
87+
code, err := generateConsentCode()
88+
if err != nil {
89+
return "", err
90+
}
91+
s.mu.Lock()
92+
defer s.mu.Unlock()
93+
if len(s.entries) >= consentCodeMaxEntries {
94+
return "", errConsentCodeStoreFull
95+
}
96+
s.entries[code] = entry
97+
return code, nil
98+
}
99+
100+
// Consume atomically retrieves and deletes a consent code.
101+
// Returns (entry, true) if the code exists and has not expired.
102+
// Returns (zero, false) if the code is unknown or expired.
103+
func (s *ConsentCodeStore) Consume(code string) (ConsentCodeEntry, bool) {
104+
s.mu.Lock()
105+
defer s.mu.Unlock()
106+
107+
entry, ok := s.entries[code]
108+
if !ok {
109+
return ConsentCodeEntry{}, false
110+
}
111+
112+
// Always delete (one-time use), even if expired.
113+
delete(s.entries, code)
114+
115+
if time.Since(entry.CreatedAt) > consentCodeTTL {
116+
return ConsentCodeEntry{}, false
117+
}
118+
119+
return entry, true
120+
}
121+
122+
// generateConsentCode returns a cryptographically random, URL-safe consent code.
123+
func generateConsentCode() (string, error) {
124+
b := make([]byte, consentCodeBytes)
125+
if _, err := rand.Read(b); err != nil {
126+
return "", fmt.Errorf("generate consent code: %w", err)
127+
}
128+
return base64.RawURLEncoding.EncodeToString(b), nil
129+
}
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package gateway
2+
3+
import (
4+
"sync"
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func newTestConsentCodeStore(t *testing.T) *ConsentCodeStore {
13+
t.Helper()
14+
s := NewConsentCodeStore()
15+
t.Cleanup(s.Close)
16+
return s
17+
}
18+
19+
func TestConsentCodeStore_StoreAndConsume(t *testing.T) {
20+
s := newTestConsentCodeStore(t)
21+
22+
entry := ConsentCodeEntry{
23+
Email: "alice@example.com",
24+
TenantID: "tid-123",
25+
TenantSlug: "acme",
26+
MCPState: "state-abc",
27+
ClientID: "client-1",
28+
ApprovedScopes: []string{"mcp:default"},
29+
CreatedAt: time.Now(),
30+
}
31+
32+
code, err := s.Store(entry)
33+
require.NoError(t, err)
34+
assert.NotEmpty(t, code)
35+
36+
// First consume succeeds.
37+
got, ok := s.Consume(code)
38+
assert.True(t, ok)
39+
assert.Equal(t, entry.Email, got.Email)
40+
assert.Equal(t, entry.TenantID, got.TenantID)
41+
assert.Equal(t, entry.TenantSlug, got.TenantSlug)
42+
assert.Equal(t, entry.MCPState, got.MCPState)
43+
assert.Equal(t, entry.ClientID, got.ClientID)
44+
assert.Equal(t, entry.ApprovedScopes, got.ApprovedScopes)
45+
46+
// Second consume fails (one-time use).
47+
_, ok = s.Consume(code)
48+
assert.False(t, ok)
49+
}
50+
51+
func TestConsentCodeStore_ConsumeExpired(t *testing.T) {
52+
s := newTestConsentCodeStore(t)
53+
54+
entry := ConsentCodeEntry{
55+
Email: "expired@example.com",
56+
CreatedAt: time.Now().Add(-consentCodeTTL - time.Second),
57+
}
58+
59+
code, err := s.Store(entry)
60+
require.NoError(t, err)
61+
62+
_, ok := s.Consume(code)
63+
assert.False(t, ok, "expired entry should not be consumable")
64+
65+
// Entry should have been deleted despite being expired.
66+
_, ok = s.Consume(code)
67+
assert.False(t, ok, "expired entry should be cleaned up after first consume attempt")
68+
}
69+
70+
func TestConsentCodeStore_ConcurrentConsume(t *testing.T) {
71+
s := newTestConsentCodeStore(t)
72+
73+
entry := ConsentCodeEntry{
74+
Email: "concurrent@example.com",
75+
CreatedAt: time.Now(),
76+
}
77+
78+
code, err := s.Store(entry)
79+
require.NoError(t, err)
80+
81+
const goroutines = 50
82+
var (
83+
wg sync.WaitGroup
84+
successes int32
85+
mu sync.Mutex
86+
)
87+
88+
wg.Add(goroutines)
89+
for i := 0; i < goroutines; i++ {
90+
go func() {
91+
defer wg.Done()
92+
_, ok := s.Consume(code)
93+
if ok {
94+
mu.Lock()
95+
successes++
96+
mu.Unlock()
97+
}
98+
}()
99+
}
100+
101+
wg.Wait()
102+
assert.Equal(t, int32(1), successes, "exactly one goroutine should consume the code")
103+
}
104+
105+
func TestConsentCodeStore_CapacityLimit(t *testing.T) {
106+
s := newTestConsentCodeStore(t)
107+
108+
// Fill to capacity.
109+
for i := 0; i < consentCodeMaxEntries; i++ {
110+
_, err := s.Store(ConsentCodeEntry{
111+
Email: "fill@example.com",
112+
CreatedAt: time.Now(),
113+
})
114+
require.NoError(t, err)
115+
}
116+
117+
// Next store should fail.
118+
_, err := s.Store(ConsentCodeEntry{
119+
Email: "overflow@example.com",
120+
CreatedAt: time.Now(),
121+
})
122+
assert.ErrorIs(t, err, errConsentCodeStoreFull)
123+
}
124+
125+
func TestConsentCodeStore_EvictionLoop(t *testing.T) {
126+
// Create store with manual control over eviction (don't use the background loop).
127+
s := &ConsentCodeStore{
128+
entries: make(map[string]ConsentCodeEntry),
129+
stop: make(chan struct{}),
130+
}
131+
t.Cleanup(func() { s.closeOnce.Do(func() { close(s.stop) }) })
132+
133+
// Store an expired entry.
134+
s.mu.Lock()
135+
s.entries["expired-code"] = ConsentCodeEntry{
136+
Email: "old@example.com",
137+
CreatedAt: time.Now().Add(-consentCodeTTL - time.Second),
138+
}
139+
// Store a valid entry.
140+
s.entries["valid-code"] = ConsentCodeEntry{
141+
Email: "new@example.com",
142+
CreatedAt: time.Now(),
143+
}
144+
s.mu.Unlock()
145+
146+
// Run eviction.
147+
s.evictExpired()
148+
149+
s.mu.Lock()
150+
defer s.mu.Unlock()
151+
assert.NotContains(t, s.entries, "expired-code", "expired entry should be evicted")
152+
assert.Contains(t, s.entries, "valid-code", "valid entry should remain")
153+
}
154+
155+
func TestConsentCodeStore_ConsumeNotFound(t *testing.T) {
156+
s := newTestConsentCodeStore(t)
157+
158+
_, ok := s.Consume("nonexistent-code")
159+
assert.False(t, ok)
160+
}
161+
162+
func TestConsentCodeStore_UniqueCodeGeneration(t *testing.T) {
163+
s := newTestConsentCodeStore(t)
164+
165+
codes := make(map[string]struct{})
166+
for i := 0; i < 100; i++ {
167+
code, err := s.Store(ConsentCodeEntry{
168+
Email: "unique@example.com",
169+
CreatedAt: time.Now(),
170+
})
171+
require.NoError(t, err)
172+
assert.NotContains(t, codes, code, "generated codes should be unique")
173+
codes[code] = struct{}{}
174+
}
175+
}

services/mcp-server/internal/auth/oidc.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ type OIDCFlowState struct {
110110
DexCodeVerifier string
111111
// TenantSlug extracted from the request subdomain.
112112
TenantSlug string
113+
// RequestedScopes are the OAuth scopes requested by the MCP client.
114+
RequestedScopes []string
113115
// IssuedAt is when this state was created.
114116
IssuedAt time.Time
115117
}
@@ -191,6 +193,22 @@ func (s *OIDCStateStore) Consume(key string) (OIDCFlowState, bool) {
191193
return entry, true
192194
}
193195

196+
// PeekInfo returns selected fields from an OIDC flow state entry without
197+
// consuming it. Expired entries are cleaned up and reported as not found.
198+
func (s *OIDCStateStore) PeekInfo(key string) (clientID, redirectURI string, scopes []string, ok bool) {
199+
s.mu.Lock()
200+
defer s.mu.Unlock()
201+
entry, exists := s.entries[key]
202+
if !exists {
203+
return "", "", nil, false
204+
}
205+
if time.Since(entry.IssuedAt) > oidcStateTTL {
206+
delete(s.entries, key)
207+
return "", "", nil, false
208+
}
209+
return entry.MCPClientID, entry.MCPRedirectURI, entry.RequestedScopes, true
210+
}
211+
194212
// TenantSlugResolver resolves a tenant slug (e.g., "acme") to its canonical
195213
// UUID. This ensures the x-tenant-id JWT claim contains a UUID consistent
196214
// with BFF-issued tokens, not the raw slug string.

services/mcp-server/internal/auth/oidc_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,3 +1242,52 @@ func TestHandleCallback_NoResolver_FallsBackToSlug(t *testing.T) {
12421242

12431243
assert.Equal(t, "acme", claims.TenantID, "without resolver, JWT should contain raw slug")
12441244
}
1245+
1246+
func TestOIDCStateStore_PeekInfo(t *testing.T) {
1247+
s := newTestOIDCStateStore(t)
1248+
1249+
key, err := s.Store(auth.OIDCFlowState{
1250+
MCPClientID: "client-1",
1251+
MCPRedirectURI: "https://example.com/callback",
1252+
RequestedScopes: []string{"mcp:default", "mcp:admin"},
1253+
IssuedAt: time.Now(),
1254+
})
1255+
require.NoError(t, err)
1256+
1257+
clientID, redirectURI, scopes, ok := s.PeekInfo(key)
1258+
assert.True(t, ok)
1259+
assert.Equal(t, "client-1", clientID)
1260+
assert.Equal(t, "https://example.com/callback", redirectURI)
1261+
assert.Equal(t, []string{"mcp:default", "mcp:admin"}, scopes)
1262+
1263+
// PeekInfo is non-consuming - a second call should also succeed.
1264+
_, _, _, ok = s.PeekInfo(key)
1265+
assert.True(t, ok, "PeekInfo should not consume the entry")
1266+
1267+
// Consume should still work after PeekInfo.
1268+
entry, ok := s.Consume(key)
1269+
assert.True(t, ok)
1270+
assert.Equal(t, "client-1", entry.MCPClientID)
1271+
}
1272+
1273+
func TestOIDCStateStore_PeekInfo_Expired(t *testing.T) {
1274+
s := newTestOIDCStateStore(t)
1275+
1276+
// Store a valid entry, then verify PeekInfo works for non-existent keys
1277+
// (expired entries are cleaned up internally by the store).
1278+
clientID, redirectURI, scopes, ok := s.PeekInfo("nonexistent-key")
1279+
assert.False(t, ok)
1280+
assert.Empty(t, clientID)
1281+
assert.Empty(t, redirectURI)
1282+
assert.Nil(t, scopes)
1283+
}
1284+
1285+
func TestOIDCStateStore_PeekInfo_NotFound(t *testing.T) {
1286+
s := newTestOIDCStateStore(t)
1287+
1288+
clientID, redirectURI, scopes, ok := s.PeekInfo("missing-key")
1289+
assert.False(t, ok)
1290+
assert.Empty(t, clientID)
1291+
assert.Empty(t, redirectURI)
1292+
assert.Nil(t, scopes)
1293+
}

0 commit comments

Comments
 (0)