@@ -19,22 +19,32 @@ import (
1919 "github.com/supabase/auth/internal/utilities"
2020)
2121
22- // loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider
23- // with that ID (or resource ID) and adds it to the context.
22+ // loadSSOProvider looks for an idp_id and first checks it for a "resource_"
23+ // prefix, if present the provider is loaded by resource_id. Otherwise the
24+ // provider is loaded by id.
2425func (a * API ) loadSSOProvider (w http.ResponseWriter , r * http.Request ) (context.Context , error ) {
2526 ctx := r .Context ()
2627 db := a .db .WithContext (ctx )
2728
28- idpParam := chi .URLParam (r , "idp_id" )
29+ var (
30+ provider * models.SSOProvider
31+ err error
32+ )
2933
30- idpID , err := uuid .FromString (idpParam )
31- if err != nil {
32- // idpParam is not UUIDv4
33- return nil , apierrors .NewNotFoundError (apierrors .ErrorCodeSSOProviderNotFound , "SSO Identity Provider not found" )
34+ const resourcePrefix = "resource_"
35+ idpParam := chi .URLParam (r , "idp_id" )
36+ switch {
37+ case strings .HasPrefix (idpParam , resourcePrefix ):
38+ resourceID := strings .TrimPrefix (idpParam , resourcePrefix )
39+ provider , err = models .FindSSOProviderByResourceID (db , resourceID )
40+ default :
41+ idpID , idpErr := uuid .FromString (idpParam )
42+ if idpErr != nil {
43+ return nil , apierrors .NewNotFoundError (apierrors .ErrorCodeSSOProviderNotFound , "SSO Identity Provider not found" )
44+ }
45+ provider , err = models .FindSSOProviderByID (db , idpID )
3446 }
3547
36- // idpParam is a UUIDv4
37- provider , err := models .FindSSOProviderByID (db , idpID )
3848 if err != nil {
3949 if models .IsNotFoundError (err ) {
4050 return nil , apierrors .NewNotFoundError (apierrors .ErrorCodeSSOProviderNotFound , "SSO Identity Provider not found" )
@@ -44,17 +54,16 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C
4454 }
4555
4656 observability .LogEntrySetField (r , "sso_provider_id" , provider .ID .String ())
47-
4857 return withSSOProvider (r .Context (), provider ), nil
4958}
5059
51- // adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does
60+ // adminSSOProvidersList lists all SSO Identity Providers in the system. Does
5261// not deal with pagination at this time.
5362func (a * API ) adminSSOProvidersList (w http.ResponseWriter , r * http.Request ) error {
5463 ctx := r .Context ()
5564 db := a .db .WithContext (ctx )
5665
57- providers , err := models .FindAllSAMLProviders (db )
66+ providers , err := models .FindAllSSOProvidersByFilter (db , r . URL . Query () )
5867 if err != nil {
5968 return err
6069 }
@@ -77,6 +86,9 @@ type CreateSSOProviderParams struct {
7786 Domains []string `json:"domains"`
7887 AttributeMapping models.SAMLAttributeMapping `json:"attribute_mapping"`
7988 NameIDFormat string `json:"name_id_format"`
89+
90+ ResourceID * string `json:"resource_id,omitempty"`
91+ Disabled * bool `json:"disabled,omitempty"`
8092}
8193
8294func (p * CreateSSOProviderParams ) validate (forUpdate bool ) error {
@@ -223,17 +235,23 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er
223235 }
224236
225237 provider := & models.SSOProvider {
238+
226239 // TODO handle Name, Description, Attribute Mapping
227240 SAMLProvider : models.SAMLProvider {
228241 EntityID : metadata .EntityID ,
229242 MetadataXML : string (rawMetadata ),
230243 },
231244 }
232245
246+ if params .ResourceID != nil {
247+ provider .ResourceID = params .ResourceID
248+ }
249+ if params .Disabled != nil {
250+ provider .Disabled = params .Disabled
251+ }
233252 if params .MetadataURL != "" {
234253 provider .SAMLProvider .MetadataURL = & params .MetadataURL
235254 }
236-
237255 if params .NameIDFormat != "" {
238256 provider .SAMLProvider .NameIDFormat = & params .NameIDFormat
239257 }
@@ -374,6 +392,28 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er
374392 }
375393 }
376394
395+ if params .ResourceID != nil {
396+ resourceID := * params .ResourceID
397+ switch {
398+ case resourceID == "" && provider .ResourceID != nil :
399+ provider .ResourceID = nil
400+ modified = true
401+ case resourceID != "" &&
402+ (provider .ResourceID == nil ||
403+ * provider .ResourceID != resourceID ):
404+ provider .ResourceID = & resourceID
405+ modified = true
406+ }
407+ }
408+
409+ if params .Disabled != nil {
410+ disabled := * params .Disabled
411+ if provider .Disabled == nil || * provider .Disabled != disabled {
412+ provider .Disabled = & disabled
413+ modified = true
414+ }
415+ }
416+
377417 if modified {
378418 if err := db .Transaction (func (tx * storage.Connection ) error {
379419 if terr := tx .Eager ().Update (provider ); terr != nil {
0 commit comments