Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import (
// Config holds OAuth configuration
type Config struct {
// OAuth settings
Mode string // "native" or "proxy"
Provider string // "hmac", "okta", "google", "azure"
RedirectURIs string // Redirect URIs allowlist (single or comma-separated)
FixedRedirectURI string // Optional fixed redirect URI used for proxying callbacks
Mode string // "native" or "proxy"
Provider string // "hmac", "okta", "google", "azure"
RedirectURIs string // Redirect URIs allowlist (single or comma-separated)
FixedRedirectURI string // Optional fixed redirect URI used for proxying callbacks
AllowedClientRedirectDomains string // Optional comma-separated list of domain suffixes allowed for client redirect URIs in fixed redirect mode (in addition to localhost)

// OIDC configuration
Issuer string
Expand Down Expand Up @@ -204,6 +205,12 @@ func (b *ConfigBuilder) WithFixedRedirectURI(uri string) *ConfigBuilder {
return b
}

// WithAllowedClientRedirectDomains sets allowed client redirect domains
func (b *ConfigBuilder) WithAllowedClientRedirectDomains(domains string) *ConfigBuilder {
b.config.AllowedClientRedirectDomains = domains
return b
}

// WithIssuer sets the OIDC issuer
func (b *ConfigBuilder) WithIssuer(issuer string) *ConfigBuilder {
b.config.Issuer = issuer
Expand Down Expand Up @@ -327,6 +334,7 @@ func FromEnv() (*Config, error) {
WithProvider(getEnv("OAUTH_PROVIDER", "")).
WithRedirectURIs(getEnv("OAUTH_REDIRECT_URIS", "")).
WithFixedRedirectURI(getEnv("OAUTH_FIXED_REDIRECT_URI", "")).
WithAllowedClientRedirectDomains(getEnv("OAUTH_ALLOWED_CLIENT_REDIRECT_DOMAINS", "")).
WithIssuer(getEnv("OIDC_ISSUER", "")).
WithAudience(getEnv("OIDC_AUDIENCE", "")).
WithClientID(getEnv("OIDC_CLIENT_ID", "")).
Expand Down
98 changes: 72 additions & 26 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ type OAuth2Config struct {
// FixedRedirectURI is an optional fixed redirect URI used when proxying callbacks
FixedRedirectURI string

// AllowedClientRedirectDomains is an optional comma-separated list of domain suffixes
// that are allowed for client redirect URIs in fixed redirect mode (in addition to localhost).
AllowedClientRedirectDomains string

// OIDC configuration
Issuer string
Audience string
Expand Down Expand Up @@ -186,22 +190,23 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
}

return &OAuth2Config{
Enabled: true,
Mode: cfg.Mode,
Provider: cfg.Provider,
RedirectURIs: cfg.RedirectURIs,
FixedRedirectURI: cfg.FixedRedirectURI,
Issuer: cfg.Issuer,
Audience: cfg.Audience,
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Scopes: scopes,
MCPHost: mcpHost,
MCPPort: mcpPort,
MCPURL: mcpURL,
Scheme: scheme,
Version: version,
stateSigningKey: cfg.JWTSecret,
Enabled: true,
Mode: cfg.Mode,
Provider: cfg.Provider,
RedirectURIs: cfg.RedirectURIs,
FixedRedirectURI: cfg.FixedRedirectURI,
AllowedClientRedirectDomains: cfg.AllowedClientRedirectDomains,
Issuer: cfg.Issuer,
Audience: cfg.Audience,
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Scopes: scopes,
MCPHost: mcpHost,
MCPPort: mcpPort,
MCPURL: mcpURL,
Scheme: scheme,
Version: version,
stateSigningKey: cfg.JWTSecret,
}
}

Expand Down Expand Up @@ -353,15 +358,15 @@ func (h *OAuth2Handler) HandleAuthorize(w http.ResponseWriter, r *http.Request)
return
}

// Security: For fixed redirect mode, only allow localhost or loopback addresses
// This prevents open redirect attacks while still supporting development tools
if !isLocalhostURI(clientRedirectURI) {
h.logger.Warn("SECURITY: Fixed redirect mode only allows localhost URIs, rejecting: %s from %s", clientRedirectURI, r.RemoteAddr)
http.Error(w, "Fixed redirect mode only allows localhost redirect URIs for security. Use allowlist mode for production.", http.StatusBadRequest)
// Security: For fixed redirect mode, only allow localhost or explicitly configured domain suffixes.
// This prevents open redirect attacks while still supporting development tools and trusted hosts.
if !h.isAllowedClientRedirectURI(clientRedirectURI) {
h.logger.Warn("SECURITY: Fixed redirect mode only allows localhost or configured domains, rejecting: %s from %s", clientRedirectURI, r.RemoteAddr)
http.Error(w, fmt.Sprintf("Invalid redirect_uri for fixed redirect mode: %s", clientRedirectURI), http.StatusBadRequest)
return
}
redirectURI = strings.TrimSpace(h.config.FixedRedirectURI)
h.logger.Info("OAuth2: Validated localhost redirect URI for proxy: %s", clientRedirectURI)
h.logger.Info("OAuth2: Validated client redirect URI for proxy: %s", clientRedirectURI)
// For fixed redirect mode, create signed state with client redirect URI
// Create state data with redirect URI
stateData := map[string]string{
Expand Down Expand Up @@ -466,14 +471,14 @@ func (h *OAuth2Handler) HandleCallback(w http.ResponseWriter, r *http.Request) {

if hasState && hasRedirect {
// Re-validate redirect URI for defense in depth
// Even though state is HMAC-signed, validate the redirect URI is localhost
if !isLocalhostURI(originalRedirectURI) {
h.logger.Warn("SECURITY: Callback redirect URI is not localhost (possible key compromise): %s", originalRedirectURI)
// Even though state is HMAC-signed, validate the redirect URI is localhost or an allowed domain
if !h.isAllowedClientRedirectURI(originalRedirectURI) {
h.logger.Warn("SECURITY: Callback redirect URI is not allowed (possible key compromise): %s", originalRedirectURI)
http.Error(w, "Invalid redirect URI in state", http.StatusBadRequest)
return
}

h.logger.Info("OAuth2: State verified, proxying callback to localhost client: %s", originalRedirectURI)
h.logger.Info("OAuth2: State verified, proxying callback to client redirect URI: %s", originalRedirectURI)

// Build proxy callback URL
proxyURL := fmt.Sprintf("%s?code=%s&state=%s", originalRedirectURI, code, originalState)
Expand Down Expand Up @@ -801,6 +806,47 @@ func isLocalhostURI(uri string) bool {
return hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1"
}

// isAllowedClientRedirectURI checks whether a client redirect URI is allowed in fixed redirect mode.
func (h *OAuth2Handler) isAllowedClientRedirectURI(uri string) bool {
Comment thread
WhammyLeaf marked this conversation as resolved.
// Always allow localhost URIs for development tools
if isLocalhostURI(uri) {
return true
}

// For non-localhost URIs, require explicit domain suffix configuration
if h.config.AllowedClientRedirectDomains == "" {
return false
}

parsedURI, err := url.Parse(uri)
if err != nil {
return false
}

// Only allow HTTPS for non-localhost URIs
if parsedURI.Scheme != "https" {
return false
}

host := strings.ToLower(parsedURI.Hostname())
if host == "" {
return false
}

// Check if host matches any configured suffix (exact match or subdomain)
for _, suffix := range strings.Split(h.config.AllowedClientRedirectDomains, ",") {
suffix = strings.TrimSpace(strings.ToLower(suffix))
if suffix == "" {
continue
}
if host == suffix || strings.HasSuffix(host, "."+suffix) {
Comment thread
WhammyLeaf marked this conversation as resolved.
return true
}
}

return false
}

// isValidRedirectURI validates redirect URI against allowlist for security
func (h *OAuth2Handler) isValidRedirectURI(uri string) bool {
if h.config.RedirectURIs == "" {
Expand Down
92 changes: 92 additions & 0 deletions security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,98 @@ func TestRedirectURIValidation(t *testing.T) {
}
}

func TestIsAllowedClientRedirectURI(t *testing.T) {
tests := []struct {
name string
allowed string
uri string
isAllowed bool
}{
{
name: "Localhost HTTP allowed without domains",
allowed: "",
uri: "http://localhost:8080/callback",
isAllowed: true,
},
{
name: "Localhost IPv4 allowed without domains",
allowed: "",
uri: "http://127.0.0.1:3000/callback",
isAllowed: true,
},
{
name: "Localhost IPv6 allowed without domains",
allowed: "",
uri: "http://[::1]:9000/callback",
isAllowed: true,
},
{
name: "Non-localhost without allowed domains rejected",
allowed: "",
uri: "https://example.com/callback",
isAllowed: false,
},
{
name: "HTTPS exact domain match allowed",
allowed: "example.com",
uri: "https://example.com/callback",
isAllowed: true,
},
{
name: "HTTPS subdomain match allowed",
allowed: "example.com",
uri: "https://app.example.com/callback",
isAllowed: true,
},
{
name: "Multiple domains with spaces allowed",
allowed: "example.com, dummy.com ",
uri: "https://client1.dummy.com/proxy/40073/callback",
isAllowed: true,
},
{
name: "Partial suffix does not match",
allowed: "example.com",
uri: "https://evil-example.com/callback",
isAllowed: false,
},
{
name: "HTTP non-localhost rejected even if domain configured",
allowed: "example.com",
uri: "http://example.com/callback",
isAllowed: false,
},
{
name: "Non-HTTPS scheme rejected",
allowed: "example.com",
uri: "custom://example.com/callback",
isAllowed: false,
},
{
name: "Invalid URI rejected",
allowed: "example.com",
uri: "not-a-valid-uri",
isAllowed: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := &OAuth2Handler{
config: &OAuth2Config{
AllowedClientRedirectDomains: tt.allowed,
},
logger: &defaultLogger{},
}

got := handler.isAllowedClientRedirectURI(tt.uri)
if got != tt.isAllowed {
t.Errorf("isAllowedClientRedirectURI(%q) = %v, want %v (allowed=%q)", tt.uri, got, tt.isAllowed, tt.allowed)
}
})
}
}

func TestOAuthParameterValidation(t *testing.T) {
handler := &OAuth2Handler{logger: &defaultLogger{}}

Expand Down
Loading