diff --git a/.gitignore b/.gitignore index 9ed3b07..89ca060 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ *.test + +/.idea diff --git a/build_request.go b/build_request.go index 0885c1c..675c5bd 100644 --- a/build_request.go +++ b/build_request.go @@ -78,6 +78,13 @@ func (sp *SAMLServiceProvider) buildAuthnRequest(includeSig bool) (*etree.Docume } } + for _, processor := range sp.AuthNRequestProcessors { + err := processor.Process(authnRequest) + if err != nil { + return nil, err + } + } + doc := etree.NewDocument() // Only POST binding includes in (includeSig) @@ -254,7 +261,7 @@ func (sp *SAMLServiceProvider) buildAuthBodyPostFromDocument(relayState string, return rv.Bytes(), nil } -//BuildAuthBodyPost builds the POST body to be sent to IDP. +// BuildAuthBodyPost builds the POST body to be sent to IDP. func (sp *SAMLServiceProvider) BuildAuthBodyPost(relayState string) ([]byte, error) { var doc *etree.Document var err error @@ -272,8 +279,8 @@ func (sp *SAMLServiceProvider) BuildAuthBodyPost(relayState string) ([]byte, err return sp.buildAuthBodyPostFromDocument(relayState, doc) } -//BuildAuthBodyPostFromDocument builds the POST body to be sent to IDP. -//It takes the AuthnRequest xml as input. +// BuildAuthBodyPostFromDocument builds the POST body to be sent to IDP. +// It takes the AuthnRequest xml as input. func (sp *SAMLServiceProvider) BuildAuthBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) { return sp.buildAuthBodyPostFromDocument(relayState, doc) } @@ -382,8 +389,8 @@ func (sp *SAMLServiceProvider) BuildLogoutRequestDocument(nameID string, session return sp.buildLogoutRequest(true, nameID, sessionIndex) } -//BuildLogoutBodyPostFromDocument builds the POST body to be sent to IDP. -//It takes the LogoutRequest xml as input. +// BuildLogoutBodyPostFromDocument builds the POST body to be sent to IDP. +// It takes the LogoutRequest xml as input. func (sp *SAMLServiceProvider) BuildLogoutBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) { return sp.buildLogoutBodyPostFromDocument(relayState, doc) } @@ -555,3 +562,16 @@ func signatureInputString(samlRequest, relayState, sigAlg string) string { } return buf.String() } + +type AddIdpScoping struct { + ProviderId string + Name string +} + +func (a *AddIdpScoping) Process(doc *etree.Element) error { + idpList := doc.CreateElement("samlp:Scoping").CreateElement("samlp:IDPList") + idpEntry := idpList.CreateElement("samlp:IDPEntry") + idpEntry.CreateAttr("ProviderID", a.ProviderId) + idpEntry.CreateAttr("Name", a.Name) + return nil +} diff --git a/build_request_test.go b/build_request_test.go index dbb43de..05f8b81 100644 --- a/build_request_test.go +++ b/build_request_test.go @@ -213,3 +213,34 @@ func TestIsPassiveIncluded(t *testing.T) { require.NotNil(t, attr) require.Equal(t, "true", attr.Value) } + +func TestAddIdpScopingExtension(t *testing.T) { + spURL := "https://sp.test" + extension := AddIdpScoping{ + ProviderId: "foo", + Name: "bar", + } + + sp := SAMLServiceProvider{ + AssertionConsumerServiceURL: spURL, + AudienceURI: spURL, + IdentityProviderIssuer: spURL, + IdentityProviderSSOURL: "https://idp.test/saml/sso", + + // Add IdP scoping extension + AuthNRequestProcessors: []AuthNRequestProcessor{&extension}, + } + + request, err := sp.BuildAuthRequest() + require.NoError(t, err) + + doc := etree.NewDocument() + err = doc.ReadFromString(request) + require.NoError(t, err) + + el := doc.FindElement("./AuthnRequest/Scoping/IDPList/IDPEntry") + + require.NotNil(t, el) + require.Equal(t, el.SelectAttrValue("ProviderID", ""), "foo") + require.Equal(t, el.SelectAttrValue("Name", ""), "bar") +} diff --git a/saml.go b/saml.go index 49a2fb8..adff0c5 100644 --- a/saml.go +++ b/saml.go @@ -17,6 +17,7 @@ package saml2 import ( "crypto" "encoding/base64" + "github.com/beevik/etree" "sync" "time" @@ -37,6 +38,10 @@ func (serr ErrSaml) Error() string { return "SAML error" } +type AuthNRequestProcessor interface { + Process(request *etree.Element) error +} + type SAMLServiceProvider struct { IdentityProviderSSOURL string IdentityProviderSSOBinding string @@ -74,6 +79,8 @@ type SAMLServiceProvider struct { AllowMissingAttributes bool Clock *dsig.Clock + AuthNRequestProcessors []AuthNRequestProcessor + // Required encryption key and default signing key. // Deprecated: Use SetSPKeyStore instead of setting or reading this field. SPKeyStore dsig.X509KeyStore