Skip to content

Commit 7fedd92

Browse files
authored
feat(auth): Create OrganizationID aware authenticator for use in http middleware (#537)
* Add additional issuer test * Fetch claims from request using generics instead of strict typing * Create organization aware authorizer struct to be used in auth middleware * Create fx module for org aware authenticator * Use public ClaimsFromRequest func in JWTAuth Authenticate function * Return http.StatusForbidden in middleware when orgID specific errors are seen * Use AdditionalChecks array as part of auth module configuration * Create additional FXModuleFromFlags function that allows appending custom checks
1 parent ea66e37 commit 7fedd92

12 files changed

Lines changed: 789 additions & 153 deletions

auth/additional_checks.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package auth
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
7+
"github.com/formancehq/go-libs/v3/oidc"
8+
)
9+
10+
type AdditionalCheck func(*http.Request, *oidc.AccessTokenClaims) error
11+
12+
// OrganizationIDProvider should give the authorizer the ability
13+
// to know what orgID (if any) is associated with the resource the requester is attempting to access
14+
// if no orgID is required, a blank string can be returned
15+
type OrganizationIDProvider func(*http.Request) (orgID string, err error)
16+
17+
func CheckOrganizationIDClaim(fn OrganizationIDProvider) AdditionalCheck {
18+
return func(r *http.Request, rawClaims *oidc.AccessTokenClaims) error {
19+
if rawClaims == nil {
20+
return fmt.Errorf("claims cannot be nil")
21+
}
22+
claims := &oidc.OrganizationAwareAccessTokenClaims{AccessTokenClaims: *rawClaims}
23+
24+
expectedOrgID, err := fn(r)
25+
if err != nil {
26+
return err
27+
}
28+
29+
// if the endpoint doesn't require a particular orgID we consider it valid
30+
if expectedOrgID == "" {
31+
return nil
32+
}
33+
34+
orgID := claims.GetOrganizationID()
35+
if orgID == "" {
36+
return oidc.ErrOrgIDNotPresent
37+
}
38+
39+
if expectedOrgID != "" && orgID != expectedOrgID {
40+
return oidc.ErrOrgIDInvalid
41+
}
42+
return nil
43+
}
44+
}

auth/auth.go

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,61 +2,54 @@ package auth
22

33
import (
44
"errors"
5-
"fmt"
65
"net/http"
76
"strings"
87

9-
"github.com/formancehq/go-libs/v3/collectionutils"
108
"github.com/formancehq/go-libs/v3/oidc"
119
)
1210

1311
type JWTAuth struct {
14-
issuer string
15-
checkScopes bool
16-
service string
17-
keySet oidc.KeySet
12+
issuer string
13+
checkScopes bool
14+
service string
15+
keySet oidc.KeySet
16+
additionalChecks []AdditionalCheck
1817
}
1918

2019
func NewJWTAuth(
2120
keySet oidc.KeySet,
2221
issuer string,
2322
service string,
2423
checkScopes bool,
24+
additionalChecks []AdditionalCheck,
2525
) *JWTAuth {
2626
return &JWTAuth{
27-
issuer: issuer,
28-
checkScopes: checkScopes,
29-
service: service,
30-
keySet: keySet,
27+
issuer: issuer,
28+
checkScopes: checkScopes,
29+
service: service,
30+
keySet: keySet,
31+
additionalChecks: additionalChecks,
3132
}
3233
}
3334

3435
// Authenticate validates the JWT in the request and returns the user, if valid.
3536
func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) {
36-
3737
claims, err := ClaimsFromRequest(r, ja.issuer, ja.keySet)
3838
if err != nil {
3939
return false, err
4040
}
4141

42-
if ja.checkScopes {
43-
scope := claims.Scopes
44-
45-
allowed := true //nolint:ineffassign
46-
switch r.Method {
47-
case http.MethodOptions, http.MethodGet, http.MethodHead, http.MethodTrace:
48-
allowed = collectionutils.Contains(scope, ja.service+":read") ||
49-
collectionutils.Contains(scope, ja.service+":write")
50-
default:
51-
allowed = collectionutils.Contains(scope, ja.service+":write")
52-
}
53-
54-
if !allowed {
55-
return false, fmt.Errorf("missing access, found scopes: '%s' need %s:read|write", strings.Join(scope, ", "), ja.service)
42+
for _, check := range ja.additionalChecks {
43+
err := check(r, claims)
44+
if err != nil {
45+
return false, err
5646
}
5747
}
5848

59-
return true, nil
49+
if !ja.checkScopes {
50+
return true, nil
51+
}
52+
return checkScopes(ja.service, r.Method, claims.Scopes)
6053
}
6154

6255
var (
@@ -65,32 +58,43 @@ var (
6558
)
6659

6760
func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySet) (*oidc.AccessTokenClaims, error) {
61+
claims := &oidc.AccessTokenClaims{}
62+
if err := claimsFromRequest(r, claims, keySet); err != nil {
63+
return claims, err
64+
}
65+
66+
if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil {
67+
return claims, err
68+
}
69+
70+
if err := oidc.CheckExpiration(claims, 0); err != nil {
71+
return claims, err
72+
}
73+
74+
return claims, nil
75+
}
6876

77+
func claimsFromRequest[CLAIMS any](r *http.Request, claims CLAIMS, keySet oidc.KeySet) error {
6978
authHeader := r.Header.Get("authorization")
7079
if authHeader == "" {
71-
return nil, ErrNoAuthorizationHeader
80+
return ErrNoAuthorizationHeader
7281
}
7382

7483
if !strings.HasPrefix(authHeader, "bearer") &&
7584
!strings.HasPrefix(authHeader, "Bearer") {
76-
return nil, ErrMalformedHeader
85+
return ErrMalformedHeader
7786
}
7887

7988
token := authHeader[6:]
8089
token = strings.TrimSpace(token)
8190

82-
claims := &oidc.AccessTokenClaims{}
8391
decrypted, err := oidc.DecryptToken(token)
8492
if err != nil {
85-
return nil, err
93+
return err
8694
}
8795
payload, err := oidc.ParseToken(decrypted, &claims)
8896
if err != nil {
89-
return nil, err
90-
}
91-
92-
if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil {
93-
return claims, err
97+
return err
9498
}
9599

96100
if _, err = oidc.CheckSignature(
@@ -100,12 +104,8 @@ func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySe
100104
[]string{}, // Default to RS256
101105
keySet,
102106
); err != nil {
103-
return claims, err
107+
return err
104108
}
105109

106-
if err = oidc.CheckExpiration(claims, 0); err != nil {
107-
return claims, err
108-
}
109-
110-
return claims, nil
110+
return nil
111111
}

0 commit comments

Comments
 (0)