-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathconfig.go
More file actions
390 lines (336 loc) · 11 KB
/
config.go
File metadata and controls
390 lines (336 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
package oauth
import (
"fmt"
"strconv"
"strings"
"github.com/tuannvm/oauth-mcp-proxy/provider"
)
// Config holds OAuth configuration for the MCP proxy: which validation mode
// and token provider to use, the upstream OIDC settings, and redirect-URI
// policy. Call Validate before use; it enforces the per-mode and per-provider
// requirements noted on the fields below (and may fill in Mode).
type Config struct {
// OAuth settings
// Mode is "native" or "proxy". When empty, Validate sets it to "proxy" if
// ClientID is non-empty and to "native" otherwise.
Mode string
// Provider is required; one of "hmac", "okta", "google", "azure".
Provider string
// RedirectURIs is the redirect URI allowlist (single or comma-separated).
// In proxy mode Validate checks every non-blank entry with ValidateRedirectURI.
RedirectURIs string
// FixedRedirectURI is an optional fixed redirect URI used for proxying
// callbacks. In proxy mode at least one of RedirectURIs / FixedRedirectURI
// must be set.
FixedRedirectURI string
// AllowedClientRedirectDomains is an optional comma-separated list of domain
// suffixes allowed for client redirect URIs in fixed redirect mode (in
// addition to localhost). Validate does not inspect this field.
AllowedClientRedirectDomains string
// OIDC configuration
// Issuer is required for the "okta", "google" and "azure" providers and must
// pass ValidateIssuerURL.
Issuer string
// Audience is required by Validate for every provider.
Audience string
// ClientID is required in proxy mode.
ClientID string
// ClientSecret is not checked by Validate.
ClientSecret string
Scopes []string // OIDC scopes
// Server configuration
// ServerURL is the full URL of the MCP server; required in proxy mode.
ServerURL string
// Security
// JWTSecret is used by the HMAC provider and for state signing; it must be
// non-empty when Provider is "hmac".
JWTSecret []byte
// Optional - Logging
// Logger allows custom logging implementation. If nil, uses default logger
// that outputs to log.Printf with level prefixes ([INFO], [ERROR], etc.).
// Implement the Logger interface (Debug, Info, Warn, Error methods) to
// integrate with your application's logging system (e.g., zap, logrus).
Logger Logger
// SkipAudienceCheck is forwarded to the token validator to disable audience
// validation. Note that Validate still requires Audience to be non-empty.
SkipAudienceCheck bool
// ValidatorIssuer is the issuer URL to use for issuer validation.
// This should only be set if the issuer in the token differs from the standard issuer URL.
ValidatorIssuer string
}
// Validate checks that the configuration is complete and consistent and
// returns the first problem found.
//
// Side effect: an empty Mode is filled in ("proxy" when ClientID is set,
// "native" otherwise) before any check runs.
//
// Checks, in order:
//   - Mode must be "native" or "proxy".
//   - Provider must be set and be one of "hmac", "okta", "google", "azure".
//   - "hmac" needs a non-empty JWTSecret; the OIDC providers need an Issuer
//     that passes ValidateIssuerURL.
//   - Audience must be set.
//   - Proxy mode only: ClientID and ServerURL must be set, RedirectURIs or
//     FixedRedirectURI must be set, and every listed URI must pass
//     ValidateRedirectURI.
func (c *Config) Validate() error {
	if c.Mode == "" {
		c.Mode = "native"
		if c.ClientID != "" {
			c.Mode = "proxy"
		}
	}

	switch c.Mode {
	case "native", "proxy":
		// accepted
	default:
		return fmt.Errorf("mode must be 'native' or 'proxy', got: %s", c.Mode)
	}

	switch c.Provider {
	case "":
		return fmt.Errorf("provider is required")
	case "hmac":
		if len(c.JWTSecret) == 0 {
			return fmt.Errorf("JWTSecret is required for HMAC provider")
		}
	case "okta", "google", "azure":
		if c.Issuer == "" {
			return fmt.Errorf("issuer is required for OIDC provider")
		}
		// Validate the issuer URL for OIDC providers to guard against MITM.
		if err := ValidateIssuerURL(c.Issuer); err != nil {
			return fmt.Errorf("invalid issuer URL: %w", err)
		}
	default:
		return fmt.Errorf("unknown provider: %s (supported: hmac, okta, google, azure)", c.Provider)
	}

	if c.Audience == "" {
		return fmt.Errorf("audience is required")
	}

	// Everything below applies to proxy mode only.
	if c.Mode != "proxy" {
		return nil
	}

	switch {
	case c.ClientID == "":
		return fmt.Errorf("proxy mode requires ClientID")
	case c.ServerURL == "":
		return fmt.Errorf("proxy mode requires ServerURL")
	case c.RedirectURIs == "" && c.FixedRedirectURI == "":
		return fmt.Errorf("proxy mode requires RedirectURIs or FixedRedirectURI")
	}

	if c.RedirectURIs != "" {
		accepted := 0
		for _, raw := range strings.Split(c.RedirectURIs, ",") {
			candidate := strings.TrimSpace(raw)
			if candidate == "" {
				continue // tolerate stray commas / blank entries
			}
			if err := ValidateRedirectURI(candidate); err != nil {
				return fmt.Errorf("invalid redirect URI '%s': %w", candidate, err)
			}
			accepted++
		}
		// A list made only of blanks counts as empty unless a fixed URI backs it up.
		if accepted == 0 && c.FixedRedirectURI == "" {
			return fmt.Errorf("proxy mode requires at least one valid redirect URI")
		}
	}

	if c.FixedRedirectURI == "" {
		return nil
	}
	if err := ValidateRedirectURI(c.FixedRedirectURI); err != nil {
		return fmt.Errorf("invalid fixed redirect URI: %w", err)
	}
	return nil
}
// SetupOAuth builds and initializes the token validator selected by
// cfg.Provider, logging through cfg.Logger or, when that is nil, the
// package default logger.
//
// Deprecated: Use WithOAuth() for new code, which provides complete OAuth setup
// including middleware and HTTP handlers. This function only creates a validator
// and requires manual wiring.
//
// Modern usage:
//
//	oauthOption, _ := oauth.WithOAuth(mux, &oauth.Config{...})
//	mcpServer := server.NewMCPServer("name", "1.0.0", oauthOption)
func SetupOAuth(cfg *Config) (provider.TokenValidator, error) {
	var log Logger = &defaultLogger{}
	if cfg.Logger != nil {
		log = cfg.Logger
	}

	v, err := createValidator(cfg, log)
	if err != nil {
		return nil, fmt.Errorf("failed to create OAuth validator: %w", err)
	}

	log.Info("OAuth authentication enabled with provider: %s", cfg.Provider)
	return v, nil
}
// createValidator picks the TokenValidator implementation for cfg.Provider
// (HMAC for "hmac", OIDC for "okta"/"google"/"azure") and initializes it
// with the provider-relevant subset of cfg. An unrecognized provider, or an
// Initialize failure, is returned as an error.
func createValidator(cfg *Config, logger Logger) (provider.TokenValidator, error) {
	var v provider.TokenValidator
	switch cfg.Provider {
	case "hmac":
		v = &provider.HMACValidator{}
	case "okta", "google", "azure":
		v = &provider.OIDCValidator{}
	default:
		return nil, fmt.Errorf("unknown OAuth provider: %s", cfg.Provider)
	}

	err := v.Initialize(&provider.Config{
		Provider:          cfg.Provider,
		Issuer:            cfg.Issuer,
		Audience:          cfg.Audience,
		JWTSecret:         cfg.JWTSecret,
		Logger:            logger,
		SkipAudienceCheck: cfg.SkipAudienceCheck,
		ValidatorIssuer:   cfg.ValidatorIssuer,
	})
	if err != nil {
		return nil, err
	}
	return v, nil
}
// CreateOAuth2Handler builds an OAuth2Handler for the HTTP endpoints from cfg
// and the given version string. A nil logger is replaced by the package
// default logger.
func CreateOAuth2Handler(cfg *Config, version string, logger Logger) *OAuth2Handler {
	handlerLogger := logger
	if handlerLogger == nil {
		handlerLogger = &defaultLogger{}
	}
	return NewOAuth2Handler(NewOAuth2ConfigFromConfig(cfg, version), handlerLogger)
}
// ConfigBuilder provides a fluent API for constructing OAuth Config.
// Obtain one from NewConfigBuilder: the With* methods write through the
// config pointer, so a zero-value ConfigBuilder (nil config) would panic.
type ConfigBuilder struct {
config *Config // settings accumulated by the With* methods
host string // host used to derive ServerURL when none is set explicitly
port string // port used to derive ServerURL when none is set explicitly
useTLS bool // when true the derived ServerURL uses https instead of http
}
// NewConfigBuilder returns a ConfigBuilder with an empty Config and the
// default address localhost:8080 over plain HTTP, which is used to derive
// ServerURL if none is set.
func NewConfigBuilder() *ConfigBuilder {
	b := new(ConfigBuilder)
	b.config = new(Config)
	b.host, b.port = "localhost", "8080"
	return b
}
// WithMode records the OAuth mode, "native" or "proxy". If left empty,
// Config.Validate chooses one based on whether a client ID is configured.
func (b *ConfigBuilder) WithMode(mode string) *ConfigBuilder {
	cfg := b.config
	cfg.Mode = mode
	return b
}

// WithProvider records the token provider: "hmac", "okta", "google" or "azure".
func (b *ConfigBuilder) WithProvider(provider string) *ConfigBuilder {
	cfg := b.config
	cfg.Provider = provider
	return b
}

// WithRedirectURIs records the redirect URI allowlist (one URI or a
// comma-separated list).
func (b *ConfigBuilder) WithRedirectURIs(uris string) *ConfigBuilder {
	cfg := b.config
	cfg.RedirectURIs = uris
	return b
}

// WithFixedRedirectURI records the fixed redirect URI used when proxying callbacks.
func (b *ConfigBuilder) WithFixedRedirectURI(uri string) *ConfigBuilder {
	cfg := b.config
	cfg.FixedRedirectURI = uri
	return b
}

// WithAllowedClientRedirectDomains records the comma-separated domain
// suffixes permitted for client redirect URIs in fixed redirect mode.
func (b *ConfigBuilder) WithAllowedClientRedirectDomains(domains string) *ConfigBuilder {
	cfg := b.config
	cfg.AllowedClientRedirectDomains = domains
	return b
}
// WithIssuer records the OIDC issuer URL.
func (b *ConfigBuilder) WithIssuer(issuer string) *ConfigBuilder {
	cfg := b.config
	cfg.Issuer = issuer
	return b
}

// WithAudience records the expected token audience.
func (b *ConfigBuilder) WithAudience(audience string) *ConfigBuilder {
	cfg := b.config
	cfg.Audience = audience
	return b
}

// WithClientID records the OAuth client ID (required in proxy mode).
func (b *ConfigBuilder) WithClientID(clientID string) *ConfigBuilder {
	cfg := b.config
	cfg.ClientID = clientID
	return b
}

// WithClientSecret records the OAuth client secret.
func (b *ConfigBuilder) WithClientSecret(secret string) *ConfigBuilder {
	cfg := b.config
	cfg.ClientSecret = secret
	return b
}
// WithJWTSecret sets Config.JWTSecret, the key used by the HMAC provider and
// for state signing.
//
// The bytes are copied: a caller that later overwrites or zeroes its buffer
// does not alter the builder's configuration. A nil secret stays nil.
func (b *ConfigBuilder) WithJWTSecret(secret []byte) *ConfigBuilder {
	if secret != nil {
		// Defensive copy; a []byte shares its backing array with the caller.
		secret = append(make([]byte, 0, len(secret)), secret...)
	}
	b.config.JWTSecret = secret
	return b
}

// WithScopes sets the OIDC scopes.
//
// The slice is copied so later changes to the caller's slice do not leak
// into the configuration. A nil slice stays nil and an empty non-nil slice
// stays empty and non-nil.
func (b *ConfigBuilder) WithScopes(scopes []string) *ConfigBuilder {
	if scopes != nil {
		scopes = append(make([]string, 0, len(scopes)), scopes...)
	}
	b.config.Scopes = scopes
	return b
}
// WithLogger installs a custom Logger; a nil Logger means the default logger is used.
func (b *ConfigBuilder) WithLogger(logger Logger) *ConfigBuilder {
	cfg := b.config
	cfg.Logger = logger
	return b
}

// WithSkipAudienceCheck toggles whether token audience validation is skipped.
func (b *ConfigBuilder) WithSkipAudienceCheck(skipAudienceCheck bool) *ConfigBuilder {
	cfg := b.config
	cfg.SkipAudienceCheck = skipAudienceCheck
	return b
}

// WithValidatorIssuer records the issuer URL to validate tokens against when
// it differs from the configured Issuer.
func (b *ConfigBuilder) WithValidatorIssuer(validatorIssuer string) *ConfigBuilder {
	cfg := b.config
	cfg.ValidatorIssuer = validatorIssuer
	return b
}
// WithServerURL sets the complete server URL. A non-empty value takes
// precedence over the host/port/TLS settings.
func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder {
	cfg := b.config
	cfg.ServerURL = url
	return b
}

// WithHost sets the host used to derive ServerURL when it is not set explicitly.
func (b *ConfigBuilder) WithHost(host string) *ConfigBuilder {
	b.host = host
	return b
}

// WithPort sets the port used to derive ServerURL when it is not set explicitly.
func (b *ConfigBuilder) WithPort(port string) *ConfigBuilder {
	b.port = port
	return b
}

// WithTLS selects https (true) or http (false) for a derived ServerURL.
func (b *ConfigBuilder) WithTLS(useTLS bool) *ConfigBuilder {
	b.useTLS = useTLS
	return b
}
// Build validates the accumulated settings and returns them as a new Config.
//
// If no ServerURL was set explicitly, one is derived from the builder's
// host, port and TLS settings via AutoDetectServerURL. Validate runs on the
// result and may fill in its Mode.
//
// The returned Config is an independent copy with its own slices. Subsequent
// With* calls therefore do not modify it, and a later Build re-derives
// ServerURL from the current host/port/TLS settings instead of reusing a
// value filled in by an earlier Build.
func (b *ConfigBuilder) Build() (*Config, error) {
	cfg := *b.config

	// Detach slice fields so the result shares no backing arrays with the builder.
	if cfg.Scopes != nil {
		cfg.Scopes = append(make([]string, 0, len(cfg.Scopes)), cfg.Scopes...)
	}
	if cfg.JWTSecret != nil {
		cfg.JWTSecret = append(make([]byte, 0, len(cfg.JWTSecret)), cfg.JWTSecret...)
	}

	if cfg.ServerURL == "" {
		cfg.ServerURL = AutoDetectServerURL(b.host, b.port, b.useTLS)
	}
	if err := cfg.Validate(); err != nil {
		return nil, err
	}
	return &cfg, nil
}
// AutoDetectServerURL constructs a base server URL of the form
// scheme://host:port, using "https" when useTLS is true and "http" otherwise.
//
// A bare IPv6 literal host (one containing ':' that is not already wrapped
// in brackets, e.g. "::1") is bracketed so the result is a well-formed URL
// ("http://[::1]:8080"). Any other host, including an already-bracketed IPv6
// literal, is used verbatim. No further validation is performed.
func AutoDetectServerURL(host, port string, useTLS bool) string {
	scheme := "http"
	if useTLS {
		scheme = "https"
	}
	// URLs require IPv6 literals in brackets; otherwise the port is ambiguous.
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	return scheme + "://" + host + ":" + port
}
// FromEnv creates and validates a Config from environment variables.
//
// Variables read:
//   - MCP_URL: full server URL. When empty, it is derived from MCP_HOST
//     (default "localhost") and MCP_PORT (default "8080"). The scheme is https
//     when both HTTPS_CERT_FILE and HTTPS_KEY_FILE are set and http otherwise.
//   - OAUTH_MODE, OAUTH_PROVIDER, OAUTH_REDIRECT_URIS,
//     OAUTH_FIXED_REDIRECT_URI, OAUTH_ALLOWED_CLIENT_REDIRECT_DOMAINS.
//   - OIDC_ISSUER, OIDC_AUDIENCE, OIDC_CLIENT_ID, OIDC_CLIENT_SECRET,
//     OIDC_VALIDATOR_ISSUER.
//   - OIDC_SCOPES: whitespace-separated scope list (repeated, leading or
//     trailing whitespace is ignored).
//   - OIDC_SKIP_AUDIENCE_CHECK: boolean; empty or unparsable means false.
//   - JWT_SECRET.
//
// The error from Config.Validate is returned if the result is incomplete.
func FromEnv() (*Config, error) {
	useTLS := getEnv("HTTPS_CERT_FILE", "") != "" && getEnv("HTTPS_KEY_FILE", "") != ""

	// strings.Fields (rather than Split on " ") avoids empty scope entries
	// from doubled, leading or trailing whitespace.
	scopes := []string{}
	if scopesEnv := getEnv("OIDC_SCOPES", ""); scopesEnv != "" {
		scopes = strings.Fields(scopesEnv)
	}

	// ServerURL derivation from host/port/TLS is left to Build, which fills
	// it in only when MCP_URL is empty.
	return NewConfigBuilder().
		WithMode(getEnv("OAUTH_MODE", "")).
		WithProvider(getEnv("OAUTH_PROVIDER", "")).
		WithRedirectURIs(getEnv("OAUTH_REDIRECT_URIS", "")).
		WithFixedRedirectURI(getEnv("OAUTH_FIXED_REDIRECT_URI", "")).
		WithAllowedClientRedirectDomains(getEnv("OAUTH_ALLOWED_CLIENT_REDIRECT_DOMAINS", "")).
		WithIssuer(getEnv("OIDC_ISSUER", "")).
		WithAudience(getEnv("OIDC_AUDIENCE", "")).
		WithClientID(getEnv("OIDC_CLIENT_ID", "")).
		WithClientSecret(getEnv("OIDC_CLIENT_SECRET", "")).
		WithScopes(scopes).
		WithSkipAudienceCheck(parseBoolEnv("OIDC_SKIP_AUDIENCE_CHECK", false)).
		WithValidatorIssuer(getEnv("OIDC_VALIDATOR_ISSUER", "")).
		WithServerURL(getEnv("MCP_URL", "")).
		WithHost(getEnv("MCP_HOST", "localhost")).
		WithPort(getEnv("MCP_PORT", "8080")).
		WithTLS(useTLS).
		WithJWTSecret([]byte(getEnv("JWT_SECRET", ""))).
		Build()
}
// parseBoolEnv reads the variable key and interprets it with strconv.ParseBool.
// defaultVal is returned when the variable is empty or cannot be parsed;
// a malformed value is not reported.
func parseBoolEnv(key string, defaultVal bool) bool {
	if raw := getEnv(key, ""); raw != "" {
		if v, err := strconv.ParseBool(raw); err == nil {
			return v
		}
	}
	return defaultVal
}