@@ -27,12 +27,13 @@ import (
2727// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
2828type ExternalProviderClaims struct {
2929 AuthMicroserviceClaims
30- Provider string `json:"provider"`
31- InviteToken string `json:"invite_token,omitempty"`
32- Referrer string `json:"referrer,omitempty"`
33- FlowStateID string `json:"flow_state_id"`
34- LinkingTargetID string `json:"linking_target_id,omitempty"`
35- EmailOptional bool `json:"email_optional,omitempty"`
30+ Provider string `json:"provider"`
31+ InviteToken string `json:"invite_token,omitempty"`
32+ Referrer string `json:"referrer,omitempty"`
33+ FlowStateID string `json:"flow_state_id"`
34+ OAuthClientStateID string `json:"oauth_client_state_id,omitempty"`
35+ LinkingTargetID string `json:"linking_target_id,omitempty"`
36+ EmailOptional bool `json:"email_optional,omitempty"`
3637}
3738
3839// ExternalProviderRedirect redirects the request to the oauth provider
@@ -90,6 +91,32 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
9091 flowStateID = flowState .ID .String ()
9192 }
9293
94+ authUrlParams := make ([]oauth2.AuthCodeOption , 0 )
95+ query .Del ("scopes" )
96+ query .Del ("provider" )
97+ query .Del ("code_challenge" )
98+ query .Del ("code_challenge_method" )
99+ for key := range query {
100+ if key == "workos_provider" {
101+ // See https://workos.com/docs/reference/sso/authorize/get
102+ authUrlParams = append (authUrlParams , oauth2 .SetAuthURLParam ("provider" , query .Get (key )))
103+ } else {
104+ authUrlParams = append (authUrlParams , oauth2 .SetAuthURLParam (key , query .Get (key )))
105+ }
106+ }
107+
108+ oauthClientStateID := ""
109+ if oauthProvider , ok := p .(provider.OAuthProvider ); ok && oauthProvider .RequiresPKCE () {
110+ codeVerifier := oauth2 .GenerateVerifier ()
111+ oauthClientState := models .NewOAuthClientState (providerType , & codeVerifier )
112+ err := db .Create (oauthClientState )
113+ if err != nil {
114+ return "" , err
115+ }
116+ oauthClientStateID = oauthClientState .ID .String ()
117+ authUrlParams = append (authUrlParams , oauth2 .S256ChallengeOption (codeVerifier ))
118+ }
119+
93120 claims := ExternalProviderClaims {
94121 AuthMicroserviceClaims : AuthMicroserviceClaims {
95122 RegisteredClaims : jwt.RegisteredClaims {
@@ -98,11 +125,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
98125 SiteURL : config .SiteURL ,
99126 InstanceID : uuid .Nil .String (),
100127 },
101- Provider : providerType ,
102- InviteToken : inviteToken ,
103- Referrer : redirectURL ,
104- FlowStateID : flowStateID ,
105- EmailOptional : pConfig .EmailOptional ,
128+ Provider : providerType ,
129+ InviteToken : inviteToken ,
130+ Referrer : redirectURL ,
131+ FlowStateID : flowStateID ,
132+ OAuthClientStateID : oauthClientStateID ,
133+ EmailOptional : pConfig .EmailOptional ,
106134 }
107135
108136 if linkingTargetUser != nil {
@@ -115,20 +143,6 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
115143 return "" , apierrors .NewInternalServerError ("Error creating state" ).WithInternalError (err )
116144 }
117145
118- authUrlParams := make ([]oauth2.AuthCodeOption , 0 )
119- query .Del ("scopes" )
120- query .Del ("provider" )
121- query .Del ("code_challenge" )
122- query .Del ("code_challenge_method" )
123- for key := range query {
124- if key == "workos_provider" {
125- // See https://workos.com/docs/reference/sso/authorize/get
126- authUrlParams = append (authUrlParams , oauth2 .SetAuthURLParam ("provider" , query .Get (key )))
127- } else {
128- authUrlParams = append (authUrlParams , oauth2 .SetAuthURLParam (key , query .Get (key )))
129- }
130- }
131-
132146 authURL := p .AuthCodeURL (tokenString , authUrlParams ... )
133147
134148 return authURL , nil
@@ -565,6 +579,13 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag
565579 if claims .FlowStateID != "" {
566580 ctx = withFlowStateID (ctx , claims .FlowStateID )
567581 }
582+ if claims .OAuthClientStateID != "" {
583+ oauthClientStateID , err := uuid .FromString (claims .OAuthClientStateID )
584+ if err != nil {
585+ return nil , apierrors .NewBadRequestError (apierrors .ErrorCodeBadOAuthState , "OAuth callback with invalid state (oauth_client_state_id must be UUID)" )
586+ }
587+ ctx = withOAuthClientStateID (ctx , oauthClientStateID )
588+ }
568589 if claims .LinkingTargetID != "" {
569590 linkingTargetUserID , err := uuid .FromString (claims .LinkingTargetID )
570591 if err != nil {
@@ -634,7 +655,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
634655 p , err = provider .NewLinkedinProvider (pConfig , scopes )
635656 case "linkedin_oidc" :
636657 pConfig = config .External .LinkedinOIDC
637- p , err = provider .NewLinkedinOIDCProvider (pConfig , scopes )
658+ p , err = provider .NewLinkedinOIDCProvider (ctx , pConfig , scopes )
638659 case "notion" :
639660 pConfig = config .External .Notion
640661 p , err = provider .NewNotionProvider (pConfig )
@@ -656,9 +677,12 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
656677 case "twitter" :
657678 pConfig = config .External .Twitter
658679 p , err = provider .NewTwitterProvider (pConfig , scopes )
680+ case "x" :
681+ pConfig = config .External .X
682+ p , err = provider .NewXProvider (pConfig , scopes )
659683 case "vercel_marketplace" :
660684 pConfig = config .External .VercelMarketplace
661- p , err = provider .NewVercelMarketplaceProvider (pConfig , scopes )
685+ p , err = provider .NewVercelMarketplaceProvider (ctx , pConfig , scopes )
662686 case "workos" :
663687 pConfig = config .External .WorkOS
664688 p , err = provider .NewWorkOSProvider (pConfig )
0 commit comments