From 093093b129ae6e67bfe4eb8a6ed19c070510a4cc Mon Sep 17 00:00:00 2001 From: Algis Dumbris Date: Thu, 4 Jun 2026 20:25:29 +0300 Subject: [PATCH] feat(teams): per-user OAuth connect flow (Path B) broker connector (spec 074, MCP-1038) Adds internal/teams/broker.OAuthConnector implementing Path B of the upstream token-brokering spec: a per-user authorization-code + PKCE connect flow against an upstream AS that does not support token exchange. - BuildAuthorizationURL: PKCE (S256) verifier/challenge + opaque per-user state, tracked in-memory with a TTL; requires explicit AS consent (confused-deputy avoidance, FR-011). - Complete: validates state (unknown/expired/one-time), exchanges the code via the bound verifier, stores the per-user credential encrypted with ObtainedVia=connect_flow (FR-010). - Deny: clears the pending flow, stores nothing (denied consent). - Refresh: transparent refresh_token grant, preserving a non-rotated refresh token (FR-012). Config: AuthBrokerConfig gains authorization_endpoint, required for the oauth_connect mode. TDD: PKCE URL build + uniqueness, PKCE roundtrip, invalid/expired/one-time state, token-endpoint error, denied consent stores nothing, refresh path. Reuses oauth.GenerateServerKey for the store key; server edition only. Related #1038 Co-Authored-By: Paperclip --- internal/config/auth_broker.go | 9 + internal/config/auth_broker_test.go | 36 +- internal/teams/broker/oauth_connector.go | 380 +++++++++++++++++ internal/teams/broker/oauth_connector_test.go | 400 ++++++++++++++++++ 4 files changed, 824 insertions(+), 1 deletion(-) create mode 100644 internal/teams/broker/oauth_connector.go create mode 100644 internal/teams/broker/oauth_connector_test.go diff --git a/internal/config/auth_broker.go b/internal/config/auth_broker.go index b05cdb0f..63f92f99 100644 --- a/internal/config/auth_broker.go +++ b/internal/config/auth_broker.go @@ -31,6 +31,10 @@ type AuthBrokerConfig struct { Mode string `json:"mode" mapstructure:"mode"` // TokenEndpoint is the IdP token endpoint used to mint the upstream credential. TokenEndpoint string `json:"token_endpoint" mapstructure:"token_endpoint"` + // AuthorizationEndpoint is the upstream AS authorize URL the user is + // redirected to for consent. Required for the oauth_connect mode (Path B, + // spec 074 FR-011); unused by token_exchange/entra_obo. + AuthorizationEndpoint string `json:"authorization_endpoint,omitempty" mapstructure:"authorization_endpoint"` // Resource is the RFC 8707 audience the resulting token is scoped to. Resource string `json:"resource,omitempty" mapstructure:"resource"` // Scopes requested for the upstream credential. @@ -77,6 +81,11 @@ func (a *AuthBrokerConfig) Validate() error { if a.TokenEndpoint == "" { return fmt.Errorf("auth_broker.token_endpoint is required") } + // The connect flow (Path B) additionally needs the upstream authorize URL + // to redirect the user to for consent. + if a.Mode == AuthBrokerModeOAuthConnect && a.AuthorizationEndpoint == "" { + return fmt.Errorf("auth_broker.authorization_endpoint is required for mode %q", AuthBrokerModeOAuthConnect) + } return nil } diff --git a/internal/config/auth_broker_test.go b/internal/config/auth_broker_test.go index d812d687..ee146088 100644 --- a/internal/config/auth_broker_test.go +++ b/internal/config/auth_broker_test.go @@ -43,6 +43,35 @@ func TestAuthBrokerConfig_ApplyDefaults(t *testing.T) { }) } +func TestAuthBroker_OAuthConnectRequiresAuthorizationEndpoint(t *testing.T) { + t.Run("missing authorization_endpoint is rejected", func(t *testing.T) { + b := &AuthBrokerConfig{ + Mode: AuthBrokerModeOAuthConnect, + TokenEndpoint: "https://idp/token", + } + err := b.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization_endpoint") + }) + + t.Run("authorization_endpoint present is accepted", func(t *testing.T) { + b := &AuthBrokerConfig{ + Mode: AuthBrokerModeOAuthConnect, + AuthorizationEndpoint: "https://idp/authorize", + TokenEndpoint: "https://idp/token", + } + require.NoError(t, b.Validate()) + }) + + t.Run("authorization_endpoint is not required for token_exchange", func(t *testing.T) { + b := &AuthBrokerConfig{ + Mode: AuthBrokerModeTokenExchange, + TokenEndpoint: "https://idp/token", + } + require.NoError(t, b.Validate()) + }) +} + func TestAuthBroker_ValidHTTPBroker(t *testing.T) { server := &ServerConfig{ Name: "github", @@ -134,9 +163,14 @@ func TestAuthBroker_MissingRequiredFields(t *testing.T) { func TestAuthBroker_AllValidModes(t *testing.T) { for _, mode := range []string{AuthBrokerModeTokenExchange, AuthBrokerModeEntraOBO, AuthBrokerModeOAuthConnect} { t.Run(mode, func(t *testing.T) { + broker := &AuthBrokerConfig{Mode: mode, TokenEndpoint: "https://idp/token"} + // The connect flow additionally requires the authorize endpoint. + if mode == AuthBrokerModeOAuthConnect { + broker.AuthorizationEndpoint = "https://idp/authorize" + } cfg := baseValidConfig(&ServerConfig{ Name: "s", Protocol: "streamable-http", URL: "https://x/mcp", - AuthBroker: &AuthBrokerConfig{Mode: mode, TokenEndpoint: "https://idp/token"}, + AuthBroker: broker, }) require.NoError(t, cfg.Validate()) }) diff --git a/internal/teams/broker/oauth_connector.go b/internal/teams/broker/oauth_connector.go new file mode 100644 index 00000000..ee9279c0 --- /dev/null +++ b/internal/teams/broker/oauth_connector.go @@ -0,0 +1,380 @@ +//go:build server + +package broker + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" + "go.uber.org/zap" +) + +// defaultStateTTL bounds how long a pending connect flow may sit between the +// authorize redirect and the callback. After this window the state is rejected +// and garbage-collected (confused-deputy / replay hardening, FR-011). +const defaultStateTTL = 10 * time.Minute + +// connectFlowObtainedVia tags credentials acquired through the per-user OAuth +// connect flow (Path B). It distinguishes them from token-exchange creds in +// the store and in audit output (FR-012). +const connectFlowObtainedVia = "connect_flow" + +// ConnectorConfig is the resolved, per-upstream configuration the OAuthConnector +// needs to drive a standard authorization-code + PKCE flow against an upstream +// authorization server. It is assembled by callers (the REST layer, T8) from +// the per-server config.AuthBrokerConfig plus the gateway's own callback URL. +type ConnectorConfig struct { + // ServerName / ServerURL identify the upstream and derive the store's + // serverKey via oauth.GenerateServerKey (matches the existing scheme). + ServerName string + ServerURL string + // AuthorizationEndpoint is the upstream AS authorize URL the user is + // redirected to for consent. + AuthorizationEndpoint string + // TokenEndpoint is the upstream AS token URL used to exchange the auth code + // and to refresh. + TokenEndpoint string + // ClientID / ClientSecret authenticate the gateway to the upstream AS. A + // public client may leave ClientSecret empty (PKCE still protects the code). + ClientID string + ClientSecret string + // Scopes requested from the upstream AS. + Scopes []string + // RedirectURI is the gateway's own callback URL registered with the AS. + RedirectURI string + // Resource is the optional RFC 8707 audience the resulting token is scoped + // to. + Resource string +} + +// validate checks the fields required to drive a connect flow. +func (c ConnectorConfig) validate() error { + switch { + case c.AuthorizationEndpoint == "": + return fmt.Errorf("oauth connector: authorization_endpoint is required") + case c.TokenEndpoint == "": + return fmt.Errorf("oauth connector: token_endpoint is required") + case c.ClientID == "": + return fmt.Errorf("oauth connector: client_id is required") + case c.RedirectURI == "": + return fmt.Errorf("oauth connector: redirect_uri is required") + } + return nil +} + +// pendingFlow tracks one in-flight connect flow between the authorize redirect +// and the callback. It binds the opaque state to the initiating user and the +// PKCE verifier so the callback can be matched back to its initiator +// (per-user state tracking, FR-011). +type pendingFlow struct { + userID string + verifier string + createdAt time.Time +} + +// OAuthConnector implements Path B of spec 074: a per-user, authorization-code +// + PKCE connect flow against an upstream authorization server that does not +// support token exchange. It issues authorize URLs, handles callbacks, persists +// the resulting per-user upstream credential encrypted (ObtainedVia=connect_flow), +// and refreshes transparently via the refresh token. +// +// One connector instance serves a single upstream; the store, however, is +// shared and isolates records per user. +type OAuthConnector struct { + store CredentialStore + cfg ConnectorConfig + serverKey string + client *http.Client + logger *zap.Logger + + mu sync.Mutex + pending map[string]*pendingFlow + + // now and stateTTL are injectable for tests. + now func() time.Time + stateTTL time.Duration +} + +// NewOAuthConnector builds a connector for one upstream. It returns an error if +// the configuration is missing fields required to run a connect flow. +func NewOAuthConnector(store CredentialStore, cfg ConnectorConfig, logger *zap.Logger) (*OAuthConnector, error) { + if err := cfg.validate(); err != nil { + return nil, err + } + if logger == nil { + logger = zap.NewNop() + } + return &OAuthConnector{ + store: store, + cfg: cfg, + serverKey: oauth.GenerateServerKey(cfg.ServerName, cfg.ServerURL), + client: &http.Client{Timeout: 30 * time.Second}, + logger: logger.Named("oauth-connector").With(zap.String("server", cfg.ServerName)), + pending: make(map[string]*pendingFlow), + now: time.Now, + stateTTL: defaultStateTTL, + }, nil +} + +// ServerKey returns the store key this connector persists credentials under. +func (c *OAuthConnector) ServerKey() string { return c.serverKey } + +// BuildAuthorizationURL starts a connect flow for userID. It generates a PKCE +// verifier/challenge and an opaque state, records the pending flow, and returns +// the upstream authorize URL (to which the gateway redirects the user) plus the +// state token. Explicit per-user consent at the AS plus the unguessable state +// is the confused-deputy avoidance required by FR-011. +func (c *OAuthConnector) BuildAuthorizationURL(userID string) (authURL, state string, err error) { + if userID == "" { + return "", "", fmt.Errorf("oauth connector: userID is required") + } + verifier, err := randomURLSafe(32) + if err != nil { + return "", "", fmt.Errorf("oauth connector: generate verifier: %w", err) + } + state, err = randomURLSafe(32) + if err != nil { + return "", "", fmt.Errorf("oauth connector: generate state: %w", err) + } + challenge := codeChallengeS256(verifier) + + c.mu.Lock() + c.gcExpiredLocked() + c.pending[state] = &pendingFlow{userID: userID, verifier: verifier, createdAt: c.now()} + c.mu.Unlock() + + params := url.Values{ + "response_type": {"code"}, + "client_id": {c.cfg.ClientID}, + "redirect_uri": {c.cfg.RedirectURI}, + "state": {state}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + if len(c.cfg.Scopes) > 0 { + params.Set("scope", strings.Join(c.cfg.Scopes, " ")) + } + if c.cfg.Resource != "" { + params.Set("resource", c.cfg.Resource) + } + + sep := "?" + if strings.Contains(c.cfg.AuthorizationEndpoint, "?") { + sep = "&" + } + return c.cfg.AuthorizationEndpoint + sep + params.Encode(), state, nil +} + +// Complete handles a successful callback. It validates state (must be a known, +// unexpired, one-time pending flow), exchanges the code for an upstream token +// using the bound PKCE verifier, and stores the per-user credential encrypted +// with ObtainedVia=connect_flow. The state is consumed regardless of outcome so +// it cannot be replayed. +func (c *OAuthConnector) Complete(ctx context.Context, state, code string) (*UpstreamCredential, error) { + flow, err := c.consume(state) + if err != nil { + return nil, err + } + if code == "" { + return nil, fmt.Errorf("oauth connector: empty authorization code") + } + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {c.cfg.RedirectURI}, + "client_id": {c.cfg.ClientID}, + "code_verifier": {flow.verifier}, + } + tok, err := c.postToken(ctx, form) + if err != nil { + return nil, err + } + + cred := c.credentialFromToken(tok, "") + if err := c.store.Put(flow.userID, c.serverKey, cred); err != nil { + return nil, fmt.Errorf("oauth connector: persist credential: %w", err) + } + c.logger.Info("stored per-user upstream credential via connect flow", + zap.String("user_id", flow.userID)) + return cred, nil +} + +// Deny handles a denied or failed callback (e.g. the AS returned +// error=access_denied). It clears the pending flow and stores nothing. +func (c *OAuthConnector) Deny(state, reason string) error { + c.mu.Lock() + delete(c.pending, state) + c.mu.Unlock() + c.logger.Info("connect flow denied by user", zap.String("reason", reason)) + return nil +} + +// Refresh mints a fresh access token for userID from the stored refresh token +// and re-persists the credential. It is the transparent auto-refresh path +// (FR-012). An absent or empty refresh token is an error. +func (c *OAuthConnector) Refresh(ctx context.Context, userID string) (*UpstreamCredential, error) { + existing, err := c.store.Get(userID, c.serverKey) + if err != nil { + return nil, fmt.Errorf("oauth connector: load credential: %w", err) + } + if existing.RefreshToken == "" { + return nil, fmt.Errorf("oauth connector: no refresh token for user %q", userID) + } + + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {existing.RefreshToken}, + "client_id": {c.cfg.ClientID}, + } + if len(c.cfg.Scopes) > 0 { + form.Set("scope", strings.Join(c.cfg.Scopes, " ")) + } + if c.cfg.Resource != "" { + form.Set("resource", c.cfg.Resource) + } + tok, err := c.postToken(ctx, form) + if err != nil { + return nil, err + } + + // Preserve the prior refresh token when the AS does not rotate it. + cred := c.credentialFromToken(tok, existing.RefreshToken) + if err := c.store.Put(userID, c.serverKey, cred); err != nil { + return nil, fmt.Errorf("oauth connector: persist refreshed credential: %w", err) + } + c.logger.Debug("refreshed per-user upstream credential", zap.String("user_id", userID)) + return cred, nil +} + +// tokenResponse is the subset of the OAuth token endpoint response we consume. +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` +} + +// postToken sends a form-encoded request to the upstream token endpoint and +// decodes the response. The gateway authenticates with client_secret when set. +func (c *OAuthConnector) postToken(ctx context.Context, form url.Values) (*tokenResponse, error) { + if c.cfg.ClientSecret != "" { + form.Set("client_secret", c.cfg.ClientSecret) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("oauth connector: build token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("oauth connector: token request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("oauth connector: read token response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("oauth connector: token endpoint returned %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var tok tokenResponse + if err := json.Unmarshal(body, &tok); err != nil { + return nil, fmt.Errorf("oauth connector: parse token response: %w", err) + } + if tok.AccessToken == "" { + return nil, fmt.Errorf("oauth connector: token response missing access_token") + } + return &tok, nil +} + +// credentialFromToken maps a token response into a stored UpstreamCredential. +// fallbackRefresh is used when the response omits a refresh token (so a +// non-rotating AS does not drop the user's refresh capability). +func (c *OAuthConnector) credentialFromToken(tok *tokenResponse, fallbackRefresh string) *UpstreamCredential { + tokenType := tok.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + refresh := tok.RefreshToken + if refresh == "" { + refresh = fallbackRefresh + } + var expiresAt time.Time + if tok.ExpiresIn > 0 { + expiresAt = c.now().Add(time.Duration(tok.ExpiresIn) * time.Second).UTC() + } + var scopes []string + if tok.Scope != "" { + scopes = strings.Fields(tok.Scope) + } + return &UpstreamCredential{ + Type: "oauth2", + AccessToken: tok.AccessToken, + RefreshToken: refresh, + ExpiresAt: expiresAt, + Scopes: scopes, + TokenType: tokenType, + Audience: c.cfg.Resource, + ObtainedVia: connectFlowObtainedVia, + UpdatedAt: c.now().UTC(), + } +} + +// consume validates and removes a pending flow by state. It rejects unknown and +// expired states; a returned flow has been deleted so state is single-use. +func (c *OAuthConnector) consume(state string) (*pendingFlow, error) { + c.mu.Lock() + defer c.mu.Unlock() + flow, ok := c.pending[state] + if !ok { + return nil, fmt.Errorf("oauth connector: unknown or already-used state") + } + delete(c.pending, state) + if c.now().Sub(flow.createdAt) > c.stateTTL { + return nil, fmt.Errorf("oauth connector: state expired") + } + return flow, nil +} + +// gcExpiredLocked drops expired pending flows. Caller holds c.mu. +func (c *OAuthConnector) gcExpiredLocked() { + cutoff := c.now().Add(-c.stateTTL) + for k, v := range c.pending { + if v.createdAt.Before(cutoff) { + delete(c.pending, k) + } + } +} + +// randomURLSafe returns nBytes of cryptographically-random data, base64url +// (no padding) encoded — suitable for PKCE verifiers and opaque state tokens. +func randomURLSafe(nBytes int) (string, error) { + b := make([]byte, nBytes) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// codeChallengeS256 computes the PKCE S256 challenge for a verifier. +func codeChallengeS256(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} diff --git a/internal/teams/broker/oauth_connector_test.go b/internal/teams/broker/oauth_connector_test.go new file mode 100644 index 00000000..9195b040 --- /dev/null +++ b/internal/teams/broker/oauth_connector_test.go @@ -0,0 +1,400 @@ +//go:build server + +package broker + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "go.uber.org/zap" +) + +// connectorTestConfig returns a ConnectorConfig pointed at the given token and +// authorization endpoints. The authorization endpoint never receives traffic in +// tests (the browser would); only its host/path is reflected into the URL. +func connectorTestConfig(tokenEndpoint string) ConnectorConfig { + return ConnectorConfig{ + ServerName: "github-mcp", + ServerURL: "https://api.github.com/mcp", + AuthorizationEndpoint: "https://auth.example.com/authorize", + TokenEndpoint: tokenEndpoint, + ClientID: "gateway-client-id", + ClientSecret: "gateway-client-secret", + Scopes: []string{"repo", "read:user"}, + RedirectURI: "https://gw.example.com/api/v1/user/credentials/callback", + Resource: "https://api.github.com/mcp", + } +} + +// s256 returns the base64url-encoded SHA-256 of the given verifier string, +// i.e. the expected PKCE code_challenge for code_challenge_method=S256. +func s256(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} + +// newConnectorTestStore returns an enabled in-memory credential store. +func newConnectorTestStore(t *testing.T) CredentialStore { + t.Helper() + db := openTestDB(t) + return newTestStore(t, db, newTestKey(t)) +} + +func newTestConnector(t *testing.T, cfg ConnectorConfig) *OAuthConnector { + t.Helper() + c, err := NewOAuthConnector(newConnectorTestStore(t), cfg, zap.NewNop()) + if err != nil { + t.Fatalf("NewOAuthConnector: %v", err) + } + return c +} + +// mockTokenServer stands in for the upstream AS token endpoint. It records the +// last received form values and replies with a canned token response. +type mockTokenServer struct { + srv *httptest.Server + lastForm url.Values + accessToken string + refreshToken string + expiresIn int + status int +} + +func newMockTokenServer(t *testing.T) *mockTokenServer { + t.Helper() + m := &mockTokenServer{ + accessToken: "upstream-access-token", + refreshToken: "upstream-refresh-token", + expiresIn: 3600, + status: http.StatusOK, + } + m.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + m.lastForm = r.PostForm + if m.status != http.StatusOK { + w.WriteHeader(m.status) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + return + } + w.Header().Set("Content-Type", "application/json") + resp := map[string]interface{}{ + "access_token": m.accessToken, + "token_type": "Bearer", + "expires_in": m.expiresIn, + "scope": "repo read:user", + } + if m.refreshToken != "" { + resp["refresh_token"] = m.refreshToken + } + _ = json.NewEncoder(w).Encode(resp) + })) + t.Cleanup(m.srv.Close) + return m +} + +func TestOAuthConnector_BuildAuthorizationURL(t *testing.T) { + c := newTestConnector(t, connectorTestConfig("https://unused.example.com/token")) + + authURL, state, err := c.BuildAuthorizationURL("user-alice") + if err != nil { + t.Fatalf("BuildAuthorizationURL: %v", err) + } + if state == "" { + t.Fatal("expected non-empty state") + } + + u, err := url.Parse(authURL) + if err != nil { + t.Fatalf("parse authURL: %v", err) + } + if got := u.Scheme + "://" + u.Host + u.Path; got != "https://auth.example.com/authorize" { + t.Errorf("authorize endpoint = %q, want https://auth.example.com/authorize", got) + } + q := u.Query() + if q.Get("response_type") != "code" { + t.Errorf("response_type = %q, want code", q.Get("response_type")) + } + if q.Get("client_id") != "gateway-client-id" { + t.Errorf("client_id = %q", q.Get("client_id")) + } + if q.Get("redirect_uri") != "https://gw.example.com/api/v1/user/credentials/callback" { + t.Errorf("redirect_uri = %q", q.Get("redirect_uri")) + } + if q.Get("scope") != "repo read:user" { + t.Errorf("scope = %q, want %q", q.Get("scope"), "repo read:user") + } + if q.Get("state") != state { + t.Errorf("state in URL = %q, returned %q", q.Get("state"), state) + } + if q.Get("code_challenge") == "" { + t.Error("expected non-empty code_challenge") + } + if q.Get("code_challenge_method") != "S256" { + t.Errorf("code_challenge_method = %q, want S256", q.Get("code_challenge_method")) + } + if q.Get("resource") != "https://api.github.com/mcp" { + t.Errorf("resource = %q", q.Get("resource")) + } +} + +func TestOAuthConnector_BuildAuthorizationURL_UniquePerFlow(t *testing.T) { + c := newTestConnector(t, connectorTestConfig("https://unused.example.com/token")) + url1, s1, _ := c.BuildAuthorizationURL("user-a") + url2, s2, _ := c.BuildAuthorizationURL("user-b") + if s1 == s2 { + t.Error("expected distinct state per flow") + } + ch1 := mustQuery(t, url1, "code_challenge") + ch2 := mustQuery(t, url2, "code_challenge") + if ch1 == ch2 { + t.Error("expected distinct PKCE challenge per flow") + } +} + +func TestOAuthConnector_Complete_StoresEncryptedToken(t *testing.T) { + m := newMockTokenServer(t) + cfg := connectorTestConfig(m.srv.URL) + store := newConnectorTestStore(t) + c, err := NewOAuthConnector(store, cfg, zap.NewNop()) + if err != nil { + t.Fatalf("NewOAuthConnector: %v", err) + } + + authURL, state, err := c.BuildAuthorizationURL("user-alice") + if err != nil { + t.Fatalf("BuildAuthorizationURL: %v", err) + } + challenge := mustQuery(t, authURL, "code_challenge") + + cred, err := c.Complete(context.Background(), state, "auth-code-xyz") + if err != nil { + t.Fatalf("Complete: %v", err) + } + if cred.AccessToken != "upstream-access-token" { + t.Errorf("AccessToken = %q", cred.AccessToken) + } + if cred.RefreshToken != "upstream-refresh-token" { + t.Errorf("RefreshToken = %q", cred.RefreshToken) + } + if cred.ObtainedVia != "connect_flow" { + t.Errorf("ObtainedVia = %q, want connect_flow", cred.ObtainedVia) + } + if cred.ExpiresAt.IsZero() { + t.Error("expected non-zero ExpiresAt from expires_in") + } + + // PKCE roundtrip: token endpoint must have received the verifier matching + // the challenge from the authorize URL. + gotVerifier := m.lastForm.Get("code_verifier") + if gotVerifier == "" { + t.Fatal("token endpoint received no code_verifier") + } + if s256(gotVerifier) != challenge { + t.Errorf("PKCE mismatch: S256(verifier)=%q challenge=%q", s256(gotVerifier), challenge) + } + if m.lastForm.Get("grant_type") != "authorization_code" { + t.Errorf("grant_type = %q", m.lastForm.Get("grant_type")) + } + if m.lastForm.Get("code") != "auth-code-xyz" { + t.Errorf("code = %q", m.lastForm.Get("code")) + } + + // Stored per-user, retrievable, ObtainedVia preserved through the encrypted + // round-trip. + serverKey := c.ServerKey() + stored, err := store.Get("user-alice", serverKey) + if err != nil { + t.Fatalf("store.Get: %v", err) + } + if stored.AccessToken != "upstream-access-token" || stored.ObtainedVia != "connect_flow" { + t.Errorf("stored cred wrong: %+v", stored) + } +} + +func TestOAuthConnector_Complete_InvalidState(t *testing.T) { + m := newMockTokenServer(t) + store := newConnectorTestStore(t) + c, _ := NewOAuthConnector(store, connectorTestConfig(m.srv.URL), zap.NewNop()) + + if _, err := c.Complete(context.Background(), "bogus-state", "code"); err == nil { + t.Fatal("expected error for unknown state") + } + if m.lastForm != nil { + t.Error("token endpoint should not be called for an invalid state") + } +} + +func TestOAuthConnector_Complete_ExpiredState(t *testing.T) { + m := newMockTokenServer(t) + store := newConnectorTestStore(t) + c, _ := NewOAuthConnector(store, connectorTestConfig(m.srv.URL), zap.NewNop()) + + base := time.Now() + c.now = func() time.Time { return base } + _, state, err := c.BuildAuthorizationURL("user-alice") + if err != nil { + t.Fatalf("BuildAuthorizationURL: %v", err) + } + // Advance past the state TTL. + c.now = func() time.Time { return base.Add(c.stateTTL + time.Minute) } + + if _, err := c.Complete(context.Background(), state, "code"); err == nil { + t.Fatal("expected error for expired state") + } + if _, err := store.Get("user-alice", c.ServerKey()); err == nil { + t.Error("nothing should be stored for an expired flow") + } +} + +func TestOAuthConnector_Complete_StateIsOneTime(t *testing.T) { + m := newMockTokenServer(t) + c, _ := NewOAuthConnector(newConnectorTestStore(t), connectorTestConfig(m.srv.URL), zap.NewNop()) + + _, state, _ := c.BuildAuthorizationURL("user-alice") + if _, err := c.Complete(context.Background(), state, "code"); err != nil { + t.Fatalf("first Complete: %v", err) + } + if _, err := c.Complete(context.Background(), state, "code"); err == nil { + t.Fatal("expected error reusing a consumed state") + } +} + +func TestOAuthConnector_Deny_StoresNothing(t *testing.T) { + m := newMockTokenServer(t) + store := newConnectorTestStore(t) + c, _ := NewOAuthConnector(store, connectorTestConfig(m.srv.URL), zap.NewNop()) + + _, state, _ := c.BuildAuthorizationURL("user-alice") + if err := c.Deny(state, "access_denied"); err != nil { + t.Fatalf("Deny: %v", err) + } + if _, err := store.Get("user-alice", c.ServerKey()); err == nil { + t.Error("denied consent must store nothing") + } + // State is cleared: a follow-up Complete must fail. + if _, err := c.Complete(context.Background(), state, "code"); err == nil { + t.Error("expected error completing a denied/cleared flow") + } + if m.lastForm != nil { + t.Error("token endpoint must not be called on denial") + } +} + +func TestOAuthConnector_Refresh(t *testing.T) { + m := newMockTokenServer(t) + m.accessToken = "refreshed-access-token" + m.refreshToken = "" // emulate AS that does not rotate the refresh token + store := newConnectorTestStore(t) + c, _ := NewOAuthConnector(store, connectorTestConfig(m.srv.URL), zap.NewNop()) + + // Seed an existing connect-flow credential with a refresh token. + seed := &UpstreamCredential{ + Type: "oauth2", + AccessToken: "old-access-token", + RefreshToken: "seed-refresh-token", + ExpiresAt: time.Now().Add(-time.Minute), // expired + ObtainedVia: "connect_flow", + } + if err := store.Put("user-alice", c.ServerKey(), seed); err != nil { + t.Fatalf("seed Put: %v", err) + } + + cred, err := c.Refresh(context.Background(), "user-alice") + if err != nil { + t.Fatalf("Refresh: %v", err) + } + if cred.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken = %q, want refreshed-access-token", cred.AccessToken) + } + // Refresh token preserved when the AS omits a new one. + if cred.RefreshToken != "seed-refresh-token" { + t.Errorf("RefreshToken = %q, want preserved seed-refresh-token", cred.RefreshToken) + } + if cred.ObtainedVia != "connect_flow" { + t.Errorf("ObtainedVia = %q", cred.ObtainedVia) + } + if m.lastForm.Get("grant_type") != "refresh_token" { + t.Errorf("grant_type = %q, want refresh_token", m.lastForm.Get("grant_type")) + } + if m.lastForm.Get("refresh_token") != "seed-refresh-token" { + t.Errorf("sent refresh_token = %q", m.lastForm.Get("refresh_token")) + } + + // Persisted. + stored, err := store.Get("user-alice", c.ServerKey()) + if err != nil { + t.Fatalf("store.Get: %v", err) + } + if stored.AccessToken != "refreshed-access-token" { + t.Errorf("stored AccessToken = %q", stored.AccessToken) + } +} + +func TestOAuthConnector_Refresh_NoRefreshToken(t *testing.T) { + m := newMockTokenServer(t) + store := newConnectorTestStore(t) + c, _ := NewOAuthConnector(store, connectorTestConfig(m.srv.URL), zap.NewNop()) + + seed := &UpstreamCredential{AccessToken: "at", ObtainedVia: "connect_flow"} // no refresh token + if err := store.Put("user-alice", c.ServerKey(), seed); err != nil { + t.Fatalf("seed Put: %v", err) + } + if _, err := c.Refresh(context.Background(), "user-alice"); err == nil { + t.Fatal("expected error refreshing a credential with no refresh token") + } +} + +func TestOAuthConnector_Complete_TokenEndpointError(t *testing.T) { + m := newMockTokenServer(t) + m.status = http.StatusBadRequest + store := newConnectorTestStore(t) + c, _ := NewOAuthConnector(store, connectorTestConfig(m.srv.URL), zap.NewNop()) + + _, state, _ := c.BuildAuthorizationURL("user-alice") + if _, err := c.Complete(context.Background(), state, "code"); err == nil { + t.Fatal("expected error when token endpoint returns 400") + } + if _, err := store.Get("user-alice", c.ServerKey()); err == nil { + t.Error("nothing should be stored on token-exchange failure") + } +} + +func TestNewOAuthConnector_Validation(t *testing.T) { + store := newConnectorTestStore(t) + base := connectorTestConfig("https://idp/token") + cases := map[string]func(*ConnectorConfig){ + "missing authorization_endpoint": func(c *ConnectorConfig) { c.AuthorizationEndpoint = "" }, + "missing token_endpoint": func(c *ConnectorConfig) { c.TokenEndpoint = "" }, + "missing client_id": func(c *ConnectorConfig) { c.ClientID = "" }, + "missing redirect_uri": func(c *ConnectorConfig) { c.RedirectURI = "" }, + } + for name, mutate := range cases { + t.Run(name, func(t *testing.T) { + cfg := base + mutate(&cfg) + if _, err := NewOAuthConnector(store, cfg, zap.NewNop()); err == nil { + t.Errorf("expected validation error for %s", name) + } + }) + } +} + +func mustQuery(t *testing.T, rawURL, key string) string { + t.Helper() + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("parse %q: %v", rawURL, err) + } + v := u.Query().Get(key) + if v == "" { + t.Fatalf("missing query param %q in %q", key, rawURL) + } + return v +}