diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 84f1ec5e3268..7c5471f78150 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -63,6 +63,7 @@ const ( RouteCallback = RouteBase + "/callback/{provider}" RouteCallbackGeneric = RouteBase + "/callback" RouteOrganizationCallback = RouteBase + "/organization/{organization}/callback/{provider}" + RouteThirdPartyLoginInit = RouteBase + "/third-party-login" ) var ( @@ -217,6 +218,13 @@ func (s *Strategy) RegisterPublicRoutes(r *httprouterx.RouterPublic) { // by the browser. So here we just redirect the request to the same location rewriting the // form fields to query params. This second GET request should have the cookies attached. r.POST(RouteCallback, s.redirectToGET) + + // Third-party login initiation (OpenID Connect spec Section 4). + // CSRF is exempted because the request originates from an external party. + s.d.CSRFHandler().IgnorePath(RouteThirdPartyLoginInit) + wrappedThirdParty := strategy.IsDisabled(s.d, s.ID().String(), s.HandleThirdPartyLoginInit) + r.GET(RouteThirdPartyLoginInit, wrappedThirdParty) + r.POST(RouteThirdPartyLoginInit, wrappedThirdParty) } func (s *Strategy) RegisterAdminRoutes(*httprouterx.RouterAdmin) {} diff --git a/selfservice/strategy/oidc/strategy_third_party_login.go b/selfservice/strategy/oidc/strategy_third_party_login.go new file mode 100644 index 000000000000..253a7c3fd851 --- /dev/null +++ b/selfservice/strategy/oidc/strategy_third_party_login.go @@ -0,0 +1,166 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oidc + +import ( + "context" + "net/http" + "net/url" + "strings" + "time" + + "github.com/pkg/errors" + + "github.com/ory/herodot" + "github.com/ory/kratos/continuity" + "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/selfservice/flow/login" + "github.com/ory/kratos/x/redir" + "github.com/ory/x/otelx" +) + +// HandleThirdPartyLoginInit implements OpenID Connect Third-Party Login +// Initiation (spec Section 4). An external party redirects the user here with +// an `iss` parameter identifying the OIDC provider. Kratos looks up the +// matching provider, creates a login flow, and redirects directly to the +// provider's authorization endpoint — no login UI is shown. +func (s *Strategy) HandleThirdPartyLoginInit(w http.ResponseWriter, r *http.Request) { + var err error + ctx := r.Context() + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.HandleThirdPartyLoginInit") + defer otelx.End(span, &err) + r = r.WithContext(ctx) + + if err = r.ParseForm(); err != nil { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, + errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to parse form: %s", err))) + return + } + + iss := r.FormValue("iss") + loginHint := r.FormValue("login_hint") + targetLinkURI := r.FormValue("target_link_uri") + + if iss == "" { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, + errors.WithStack(herodot.ErrBadRequest.WithReason("The `iss` parameter is required."))) + return + } + + issURL, parseErr := url.Parse(iss) + if parseErr != nil || issURL.Host == "" || + (issURL.Scheme != "https" && !s.d.Config().IsInsecureDevMode(ctx)) { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, + errors.WithStack(herodot.ErrBadRequest.WithReasonf( + "The `iss` parameter must be a valid HTTPS URL, got: %q", iss))) + return + } + + provider, _, err := s.findProviderByIssuer(ctx, iss) + if err != nil { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, err) + return + } + + conf := s.d.Config() + var validatedTargetURI *url.URL + if targetLinkURI != "" { + if validatedTargetURI, err = redir.SecureRedirectTo(r, + conf.SelfServiceBrowserDefaultReturnTo(ctx), + redir.SecureRedirectReturnTo(targetLinkURI), + redir.SecureRedirectAllowURLs(conf.SelfServiceBrowserAllowedReturnToDomains(ctx)), + redir.SecureRedirectAllowSelfServiceURLs(conf.SelfPublicURL(ctx)), + ); err != nil { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, + errors.WithStack(herodot.ErrBadRequest.WithReasonf( + "The `target_link_uri` is not allowed: %s", err))) + return + } + + q := r.URL.Query() + q.Set("return_to", targetLinkURI) + r.URL.RawQuery = q.Encode() + } + + loginFlow, _, err := s.d.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser) + if err != nil { + if errors.Is(err, login.ErrAlreadyLoggedIn) { + returnTo := conf.SelfServiceBrowserDefaultReturnTo(ctx) + if validatedTargetURI != nil { + returnTo = validatedTargetURI + } + http.Redirect(w, r, returnTo.String(), http.StatusSeeOther) + return + } + if errors.Is(err, flow.ErrCompletedByStrategy) { + return + } + s.d.SelfServiceErrorManager().Forward(ctx, w, r, err) + return + } + if loginFlow == nil { + // PreLoginHook already wrote the response. + return + } + + state, pkce, err := s.GenerateState(ctx, provider, loginFlow) + if err != nil { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, err) + return + } + + if err = s.d.ContinuityManager().Pause(ctx, w, r, sessionName, + continuity.WithPayload(&AuthCodeContainer{ + State: state, + FlowID: loginFlow.ID.String(), + }), + continuity.WithLifespan(time.Minute*30)); err != nil { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, err) + return + } + + loginFlow.Active = s.ID() + if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, loginFlow); err != nil { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, + errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithWrap(err))) + return + } + + up := make(map[string]string) + if loginHint != "" { + up["login_hint"] = loginHint + } + + codeURL, err := getAuthRedirectURL(ctx, provider, loginFlow, state, up, pkce) + if err != nil { + s.d.SelfServiceErrorManager().Forward(ctx, w, r, err) + return + } + + http.Redirect(w, r, codeURL, http.StatusSeeOther) +} + +// findProviderByIssuer looks up a configured OIDC provider whose IssuerURL +// matches the given issuer string (trailing-slash normalized). +func (s *Strategy) findProviderByIssuer(ctx context.Context, issuer string) (Provider, *Configuration, error) { + conf, err := s.Config(ctx) + if err != nil { + return nil, nil, err + } + + issuer = strings.TrimRight(issuer, "/") + for _, p := range conf.Providers { + if strings.TrimRight(p.IssuerURL, "/") == issuer { + provider, err := conf.Provider(p.ID, s.d) + if err != nil { + return nil, nil, err + } + return provider, &p, nil + } + } + + return nil, nil, errors.WithStack( + herodot.ErrNotFound.WithReasonf("No configured OpenID Connect provider matches the issuer %q", issuer), + ) +} diff --git a/selfservice/strategy/oidc/strategy_third_party_login_test.go b/selfservice/strategy/oidc/strategy_third_party_login_test.go new file mode 100644 index 000000000000..a1b12317a7d2 --- /dev/null +++ b/selfservice/strategy/oidc/strategy_third_party_login_test.go @@ -0,0 +1,291 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oidc_test + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/identity" + "github.com/ory/kratos/pkg" + "github.com/ory/kratos/pkg/testhelpers" + "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/x/configx" + "github.com/ory/x/httprouterx" +) + +func TestThirdPartyLoginInit(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Mock OIDC discovery server over plain HTTP. Dev mode is enabled so + // the handler accepts HTTP issuers. + mockOIDC := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/.well-known/openid-configuration": + issuer := "http://" + r.Host + fmt.Fprintf(w, `{ + "issuer": %q, + "authorization_endpoint": %q, + "token_endpoint": %q, + "jwks_uri": %q, + "code_challenge_methods_supported": ["S256"] + }`, + issuer, + issuer+"/authorize", + issuer+"/token", + issuer+"/jwks", + ) + default: + http.NotFound(w, r) + } + })) + t.Cleanup(mockOIDC.Close) + mockOIDCIssuer := mockOIDC.URL // http://127.0.0.1: + + conf, reg := pkg.NewFastRegistryWithMocks(t, + configx.WithValues(map[string]any{ + "dev": true, + config.ViperKeyIdentitySchemas: config.Schemas{ + {ID: "default", URL: "file://./stub/registration.schema.json"}, + }, + config.ViperKeyDefaultIdentitySchemaID: "default", + config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String()): []config.SelfServiceHook{{Name: "session"}}, + }), + ) + + returnTS := newReturnTS(t, reg) + _ = newUI(t, reg) + errTS := testhelpers.NewErrorTestServer(t, reg) + + routerP := httprouterx.NewTestRouterPublic(t) + routerA := httprouterx.NewTestRouterAdminWithPrefix(t) + ts, _ := testhelpers.NewKratosServerWithRouters(t, reg, routerP, routerA) + + viperSetProviderConfig(t, conf, oidc.Configuration{ + ID: "test-provider", + Provider: "generic", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + IssuerURL: mockOIDCIssuer, + Mapper: "file://./stub/oidc.hydra.jsonnet", + }) + + thirdPartyURL := ts.URL + oidc.RouteThirdPartyLoginInit + + // noRedirectClient follows all redirects within the test infrastructure. + noRedirectClient := func(t *testing.T) *http.Client { + return &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + t.Logf("Redirect: %s", req.URL.String()) + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + return nil + }, + } + } + + // stopOnExternalRedirect stops following redirects once the target leaves + // the test infrastructure servers. + stopOnExternalRedirect := func(t *testing.T) *http.Client { + return &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + for _, allowed := range []string{ts.URL, errTS.URL, returnTS.URL} { + u, _ := url.Parse(allowed) + if req.URL.Host == u.Host { + return nil + } + } + return http.ErrUseLastResponse + }, + } + } + + t.Run("case=should fail when iss is missing", func(t *testing.T) { + res, err := noRedirectClient(t).Get(thirdPartyURL) + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Contains(t, res.Request.URL.String(), errTS.URL, "%s", body) + assert.Equal(t, int64(http.StatusBadRequest), gjson.GetBytes(body, "code").Int(), "%s", body) + assert.Contains(t, gjson.GetBytes(body, "reason").String(), "`iss` parameter is required", "%s", body) + }) + + t.Run("case=should fail when iss is not HTTPS in production mode", func(t *testing.T) { + // Temporarily disable dev mode to test HTTPS enforcement. + conf.MustSet(ctx, "dev", false) + t.Cleanup(func() { conf.MustSet(ctx, "dev", true) }) + + res, err := noRedirectClient(t).Get(thirdPartyURL + "?iss=http://example.com") + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Contains(t, res.Request.URL.String(), errTS.URL, "%s", body) + assert.Equal(t, int64(http.StatusBadRequest), gjson.GetBytes(body, "code").Int(), "%s", body) + assert.Contains(t, gjson.GetBytes(body, "reason").String(), "HTTPS", "%s", body) + }) + + t.Run("case=should fail when iss is not a valid URL", func(t *testing.T) { + res, err := noRedirectClient(t).Get(thirdPartyURL + "?iss=not-a-url") + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Contains(t, res.Request.URL.String(), errTS.URL, "%s", body) + assert.Equal(t, int64(http.StatusBadRequest), gjson.GetBytes(body, "code").Int(), "%s", body) + assert.Contains(t, gjson.GetBytes(body, "reason").String(), "HTTPS", "%s", body) + }) + + t.Run("case=should fail when issuer is unknown", func(t *testing.T) { + res, err := noRedirectClient(t).Get(thirdPartyURL + "?iss=https://unknown-issuer.example.com") + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Contains(t, res.Request.URL.String(), errTS.URL, "%s", body) + assert.Equal(t, int64(http.StatusNotFound), gjson.GetBytes(body, "code").Int(), "%s", body) + assert.Contains(t, gjson.GetBytes(body, "reason").String(), "No configured OpenID Connect provider", "%s", body) + }) + + t.Run("case=should fail when target_link_uri is not allowed", func(t *testing.T) { + issuer := url.QueryEscape(mockOIDCIssuer) + target := url.QueryEscape("https://evil.example.com/steal-session") + res, err := noRedirectClient(t).Get(thirdPartyURL + "?iss=" + issuer + "&target_link_uri=" + target) + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Contains(t, res.Request.URL.String(), errTS.URL, "%s", body) + assert.Equal(t, int64(http.StatusBadRequest), gjson.GetBytes(body, "code").Int(), "%s", body) + assert.Contains(t, gjson.GetBytes(body, "reason").String(), "target_link_uri", "%s", body) + }) + + t.Run("case=should redirect to OIDC provider on valid request", func(t *testing.T) { + issuer := url.QueryEscape(mockOIDCIssuer) + client := stopOnExternalRedirect(t) + res, err := client.Get(thirdPartyURL + "?iss=" + issuer) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusSeeOther, res.StatusCode) + location := res.Header.Get("Location") + require.NotEmpty(t, location, "Expected Location header with auth redirect URL") + + locURL, err := url.Parse(location) + require.NoError(t, err) + + // The auth redirect should point to the mock OIDC server's + // authorization_endpoint (which is its own URL + /authorize). + mockURL, _ := url.Parse(mockOIDCIssuer) + assert.Equal(t, mockURL.Host, locURL.Host) + assert.Equal(t, "/authorize", locURL.Path) + assert.Equal(t, "test-client-id", locURL.Query().Get("client_id")) + assert.Equal(t, "code", locURL.Query().Get("response_type")) + assert.NotEmpty(t, locURL.Query().Get("state")) + }) + + t.Run("case=should pass login_hint to provider", func(t *testing.T) { + issuer := url.QueryEscape(mockOIDCIssuer) + client := stopOnExternalRedirect(t) + res, err := client.Get(thirdPartyURL + "?iss=" + issuer + "&login_hint=user@example.com") + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusSeeOther, res.StatusCode) + location := res.Header.Get("Location") + require.NotEmpty(t, location) + + locURL, err := url.Parse(location) + require.NoError(t, err) + assert.Equal(t, "user@example.com", locURL.Query().Get("login_hint")) + }) + + t.Run("case=should accept valid target_link_uri", func(t *testing.T) { + issuer := url.QueryEscape(mockOIDCIssuer) + target := url.QueryEscape(returnTS.URL + "/after-login") + client := stopOnExternalRedirect(t) + res, err := client.Get(thirdPartyURL + "?iss=" + issuer + "&target_link_uri=" + target) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusSeeOther, res.StatusCode) + location := res.Header.Get("Location") + require.NotEmpty(t, location) + + locURL, err := url.Parse(location) + require.NoError(t, err) + mockURL, _ := url.Parse(mockOIDCIssuer) + assert.Equal(t, mockURL.Host, locURL.Host) + }) + + t.Run("case=should handle trailing slash in issuer", func(t *testing.T) { + issuer := url.QueryEscape(mockOIDCIssuer + "/") + client := stopOnExternalRedirect(t) + res, err := client.Get(thirdPartyURL + "?iss=" + issuer) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusSeeOther, res.StatusCode) + location := res.Header.Get("Location") + require.NotEmpty(t, location) + + locURL, err := url.Parse(location) + require.NoError(t, err) + mockURL, _ := url.Parse(mockOIDCIssuer) + assert.Equal(t, mockURL.Host, locURL.Host) + }) + + t.Run("case=should work with POST method", func(t *testing.T) { + client := stopOnExternalRedirect(t) + res, err := client.PostForm(thirdPartyURL, url.Values{ + "iss": {mockOIDCIssuer}, + "login_hint": {"user@example.com"}, + }) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusSeeOther, res.StatusCode) + location := res.Header.Get("Location") + require.NotEmpty(t, location) + + locURL, err := url.Parse(location) + require.NoError(t, err) + mockURL, _ := url.Parse(mockOIDCIssuer) + assert.Equal(t, mockURL.Host, locURL.Host) + assert.Equal(t, "user@example.com", locURL.Query().Get("login_hint")) + }) + + t.Run("case=should return 404 when OIDC strategy is disabled", func(t *testing.T) { + baseKey := fmt.Sprintf("%s.%s", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeOIDC) + conf.MustSet(ctx, baseKey+".enabled", false) + t.Cleanup(func() { + conf.MustSet(ctx, baseKey+".enabled", true) + }) + + res, err := noRedirectClient(t).Get(thirdPartyURL + "?iss=https://example.com") + require.NoError(t, err) + defer res.Body.Close() + + assert.Equal(t, http.StatusNotFound, res.StatusCode) + }) +}