Skip to content

Commit 2f323fe

Browse files
authored
feat: fix large group claim handling in azure id tokens (#1995)
Handles [large `group` claims in Azure ID tokens](https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference#groups-overage-claim) by fetching them from the ([usually](https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects?view=graph-rest-1.0&tabs=http)) designated Azure endpoint.
1 parent f94f97e commit 2f323fe

File tree

3 files changed

+341
-107
lines changed

3 files changed

+341
-107
lines changed

internal/api/provider/azure.go

+197
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@ import (
55
"encoding/base64"
66
"encoding/json"
77
"fmt"
8+
"io"
9+
"net/http"
810
"regexp"
911
"strings"
12+
"unicode/utf8"
1013

1114
"github.com/coreos/go-oidc/v3/oidc"
15+
"github.com/golang-jwt/jwt/v5"
1216
"github.com/supabase/auth/internal/conf"
1317
"golang.org/x/oauth2"
1418
)
@@ -162,3 +166,196 @@ func (g azureProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Use
162166

163167
return nil, fmt.Errorf("azure: no OIDC ID token present in response")
164168
}
169+
170+
type AzureIDTokenClaimSource struct {
171+
Endpoint string `json:"endpoint"`
172+
}
173+
174+
type AzureIDTokenClaims struct {
175+
jwt.RegisteredClaims
176+
177+
Email string `json:"email"`
178+
Name string `json:"name"`
179+
PreferredUsername string `json:"preferred_username"`
180+
XMicrosoftEmailDomainOwnerVerified any `json:"xms_edov"`
181+
182+
ClaimNames map[string]string `json:"__claim_names"`
183+
ClaimSources map[string]AzureIDTokenClaimSource `json:"__claim_sources"`
184+
}
185+
186+
// ResolveIndirectClaims resolves claims in the Azure Token that require a call to the Microsoft Graph API. This is typically to an API like this: https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects?view=graph-rest-1.0&tabs=http
187+
func (c *AzureIDTokenClaims) ResolveIndirectClaims(ctx context.Context, httpClient *http.Client, accessToken string) (map[string]any, error) {
188+
if len(c.ClaimNames) == 0 || len(c.ClaimSources) == 0 {
189+
return nil, nil
190+
}
191+
192+
result := make(map[string]any)
193+
194+
for claimName, claimSource := range c.ClaimNames {
195+
claimEndpointObject, ok := c.ClaimSources[claimSource]
196+
197+
if !ok || !strings.HasPrefix(claimEndpointObject.Endpoint, "https://") {
198+
continue
199+
}
200+
201+
claimEndpoint := claimEndpointObject.Endpoint
202+
203+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, claimEndpoint, strings.NewReader(`{"securityEnabledOnly":true}`))
204+
if err != nil {
205+
return nil, fmt.Errorf("azure: failed to create POST request to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err)
206+
}
207+
208+
req.Header.Add("Authorization", "Bearer "+accessToken)
209+
req.Header.Add("Content-Type", "application/json")
210+
211+
resp, err := httpClient.Do(req)
212+
if err != nil {
213+
return nil, fmt.Errorf("azure: failed to send POST request to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err)
214+
}
215+
216+
defer resp.Body.Close()
217+
218+
if resp.StatusCode != http.StatusOK {
219+
resBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2*1024))
220+
221+
body := "<empty>"
222+
if len(resBody) > 0 {
223+
if utf8.Valid(resBody) {
224+
body = string(resBody)
225+
} else {
226+
body = "<invalid-utf8>"
227+
}
228+
}
229+
230+
readErrString := ""
231+
if readErr != nil {
232+
readErrString = fmt.Sprintf(" with read error %q", readErr.Error())
233+
}
234+
235+
return nil, fmt.Errorf("azure: received %d but expected 200 HTTP status code when sending POST to %q (resolving overage claim %q) with response body %q%s", resp.StatusCode, claimEndpoint, claimName, body, readErrString)
236+
}
237+
238+
var responseResult struct {
239+
Value any `json:"value"`
240+
}
241+
242+
if err := json.NewDecoder(resp.Body).Decode(&responseResult); err != nil {
243+
return nil, fmt.Errorf("azure: failed to parse JSON response from POST to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err)
244+
}
245+
246+
result[claimName] = responseResult.Value
247+
}
248+
249+
return result, nil
250+
}
251+
252+
func (c *AzureIDTokenClaims) IsEmailVerified() bool {
253+
emailVerified := false
254+
255+
edov := c.XMicrosoftEmailDomainOwnerVerified
256+
257+
// If xms_edov is not set, and an email is present or xms_edov is true,
258+
// only then is the email regarded as verified.
259+
// https://learn.microsoft.com/en-us/azure/active-directory/develop/migrate-off-email-claim-authorization#using-the-xms_edov-optional-claim-to-determine-email-verification-status-and-migrate-users
260+
if edov == nil {
261+
// An email is provided, but xms_edov is not -- probably not
262+
// configured, so we must assume the email is verified as Azure
263+
// will only send out a potentially unverified email address in
264+
// single-tenanat apps.
265+
emailVerified = c.Email != ""
266+
} else {
267+
edovBool := false
268+
269+
// Azure can't be trusted with how they encode the xms_edov
270+
// claim. Sometimes it's "xms_edov": "1", sometimes "xms_edov": true.
271+
switch v := edov.(type) {
272+
case bool:
273+
edovBool = v
274+
275+
case string:
276+
edovBool = v == "1" || v == "true"
277+
278+
default:
279+
edovBool = false
280+
}
281+
282+
emailVerified = c.Email != "" && edovBool
283+
}
284+
285+
return emailVerified
286+
}
287+
288+
// removeAzureClaimsFromCustomClaims contains the list of claims to be removed
289+
// from the CustomClaims map. See:
290+
// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference
291+
var removeAzureClaimsFromCustomClaims = []string{
292+
"aud",
293+
"iss",
294+
"iat",
295+
"nbf",
296+
"exp",
297+
"c_hash",
298+
"at_hash",
299+
"aio",
300+
"nonce",
301+
"rh",
302+
"uti",
303+
"jti",
304+
"ver",
305+
"sub",
306+
"name",
307+
"preferred_username",
308+
}
309+
310+
func parseAzureIDToken(ctx context.Context, token *oidc.IDToken, accessToken string) (*oidc.IDToken, *UserProvidedData, error) {
311+
var data UserProvidedData
312+
313+
var azureClaims AzureIDTokenClaims
314+
if err := token.Claims(&azureClaims); err != nil {
315+
return nil, nil, err
316+
}
317+
318+
data.Metadata = &Claims{
319+
Issuer: token.Issuer,
320+
Subject: token.Subject,
321+
ProviderId: token.Subject,
322+
PreferredUsername: azureClaims.PreferredUsername,
323+
FullName: azureClaims.Name,
324+
CustomClaims: make(map[string]any),
325+
}
326+
327+
if azureClaims.Email != "" {
328+
data.Emails = []Email{{
329+
Email: azureClaims.Email,
330+
Verified: azureClaims.IsEmailVerified(),
331+
Primary: true,
332+
}}
333+
}
334+
335+
if err := token.Claims(&data.Metadata.CustomClaims); err != nil {
336+
return nil, nil, err
337+
}
338+
339+
resolvedClaims, err := azureClaims.ResolveIndirectClaims(ctx, http.DefaultClient, accessToken)
340+
if err != nil {
341+
return nil, nil, err
342+
}
343+
344+
if data.Metadata.CustomClaims == nil {
345+
if resolvedClaims != nil {
346+
data.Metadata.CustomClaims = make(map[string]any, len(resolvedClaims))
347+
}
348+
}
349+
350+
if data.Metadata.CustomClaims != nil {
351+
for _, claim := range removeAzureClaimsFromCustomClaims {
352+
delete(data.Metadata.CustomClaims, claim)
353+
}
354+
}
355+
356+
for k, v := range resolvedClaims {
357+
data.Metadata.CustomClaims[k] = v
358+
}
359+
360+
return token, &data, nil
361+
}

internal/api/provider/azure_test.go

+143-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
package provider
22

3-
import "testing"
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/url"
8+
"strings"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
)
413

514
func TestIsAzureIssuer(t *testing.T) {
615
positiveExamples := []string{
@@ -27,3 +36,136 @@ func TestIsAzureIssuer(t *testing.T) {
2736
}
2837
}
2938
}
39+
40+
func TestAzureResolveIndirectClaims(t *testing.T) {
41+
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
42+
w.Header().Add("Content-Type", "application/json")
43+
w.WriteHeader(http.StatusOK)
44+
45+
w.Write([]byte(`{
46+
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#Collection(Edm.String)",
47+
"value": [
48+
"fee2c45b-915a-4a64-b130-f4eb9e75525e",
49+
"4fe90ae7-065a-478b-9400-e0a0e1cbd540",
50+
"c9ee2d50-9e8a-4352-b97c-4c2c99557c22",
51+
"e0c3beaf-eeb4-43d8-abc5-94f037a65697"
52+
]
53+
}`))
54+
}))
55+
56+
defer server.Close()
57+
58+
var claims AzureIDTokenClaims
59+
60+
resolvedClaims, err := claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
61+
require.Nil(t, resolvedClaims)
62+
require.Nil(t, err)
63+
64+
claims.ClaimNames = make(map[string]string)
65+
66+
resolvedClaims, err = claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
67+
require.Nil(t, resolvedClaims)
68+
require.Nil(t, err)
69+
70+
claims.ClaimNames = map[string]string{
71+
"groups": "src1",
72+
"missing-source": "src2",
73+
"not-https": "src3",
74+
}
75+
claims.ClaimSources = map[string]AzureIDTokenClaimSource{
76+
"src1": {
77+
Endpoint: server.URL,
78+
},
79+
"src3": {
80+
Endpoint: "http://example.com",
81+
},
82+
}
83+
84+
resolvedClaims, err = claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
85+
require.NoError(t, err)
86+
require.NotNil(t, resolvedClaims)
87+
require.Equal(t, 1, len(resolvedClaims))
88+
require.Equal(t, 4, len(resolvedClaims["groups"].([]interface{})))
89+
}
90+
91+
func TestAzureResolveIndirectClaimsFailures(t *testing.T) {
92+
examples := []struct {
93+
name string
94+
urlSuffix string
95+
statusCode int
96+
body []byte
97+
expectedError string
98+
}{
99+
{
100+
name: "invalid url",
101+
urlSuffix: "\000",
102+
expectedError: "azure: failed to create POST request to \"SERVER-URL\\x00\" (resolving overage claim \"groups\"): parse \"SERVER-URL\\x00\": net/url: invalid control character in URL",
103+
},
104+
{
105+
name: "no such server",
106+
urlSuffix: "000",
107+
expectedError: "azure: failed to send POST request to \"SERVER-URL000\" (resolving overage claim \"groups\"): Post \"SERVER-URL000\": dial tcp: address PORT000: invalid port",
108+
},
109+
{
110+
name: "non 200 status code",
111+
statusCode: 500,
112+
body: []byte(`something is wrong`),
113+
expectedError: "azure: received 500 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"something is wrong\"",
114+
},
115+
{
116+
name: "non 200 status code, non utf8 valid body",
117+
statusCode: 201,
118+
body: []byte{255, 255, 255, 255},
119+
expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"<invalid-utf8>\"",
120+
},
121+
{
122+
name: "non 200 status code, empty body",
123+
statusCode: 201,
124+
body: []byte{},
125+
expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"<empty>\"",
126+
},
127+
{
128+
name: "non 200 status code, body over 2KB",
129+
statusCode: 201,
130+
body: []byte(strings.Repeat("x", 2*1024+1)),
131+
expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\"",
132+
},
133+
{
134+
name: "ok response, not json",
135+
statusCode: 200,
136+
body: []byte("not json"),
137+
expectedError: "azure: failed to parse JSON response from POST to \"SERVER-URL\" (resolving overage claim \"groups\"): invalid character 'o' in literal null (expecting 'u')",
138+
},
139+
}
140+
141+
for _, example := range examples {
142+
t.Run(example.name, func(t *testing.T) {
143+
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
144+
w.WriteHeader(example.statusCode)
145+
146+
w.Write(example.body)
147+
}))
148+
149+
defer server.Close()
150+
151+
u, _ := url.Parse(server.URL)
152+
153+
var claims AzureIDTokenClaims
154+
155+
claims.ClaimNames = map[string]string{
156+
"groups": "src1",
157+
}
158+
claims.ClaimSources = map[string]AzureIDTokenClaimSource{
159+
"src1": {
160+
Endpoint: server.URL + example.urlSuffix,
161+
},
162+
}
163+
164+
resolvedClaims, err := claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
165+
require.Nil(t, resolvedClaims)
166+
require.Error(t, err)
167+
require.Equal(t, example.expectedError, strings.ReplaceAll(strings.ReplaceAll(err.Error(), server.URL, "SERVER-URL"), u.Port(), "PORT"))
168+
})
169+
}
170+
171+
}

0 commit comments

Comments
 (0)