Skip to content
44 changes: 44 additions & 0 deletions auth/additional_checks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package auth

import (
"fmt"
"net/http"

"github.com/formancehq/go-libs/v3/oidc"
)

type AdditionalCheck func(*http.Request, *oidc.AccessTokenClaims) error

// OrganizationIDProvider should give the authorizer the ability
// to know what orgID (if any) is associated with the resource the requester is attempting to access
// if no orgID is required, a blank string can be returned
type OrganizationIDProvider func(*http.Request) (orgID string, err error)

func CheckOrganizationIDClaim(fn OrganizationIDProvider) AdditionalCheck {
return func(r *http.Request, rawClaims *oidc.AccessTokenClaims) error {
if rawClaims == nil {
return fmt.Errorf("claims cannot be nil")
}
claims := &oidc.OrganizationAwareAccessTokenClaims{AccessTokenClaims: *rawClaims}

expectedOrgID, err := fn(r)
if err != nil {
return err
}

// if the endpoint doesn't require a particular orgID we consider it valid
if expectedOrgID == "" {
return nil
}

orgID := claims.GetOrganizationID()
if orgID == "" {
return oidc.ErrOrgIDNotPresent
}

if expectedOrgID != "" && orgID != expectedOrgID {
return oidc.ErrOrgIDInvalid
}
return nil
}
}
82 changes: 41 additions & 41 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,54 @@ package auth

import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/formancehq/go-libs/v3/collectionutils"
"github.com/formancehq/go-libs/v3/oidc"
)

type JWTAuth struct {
issuer string
checkScopes bool
service string
keySet oidc.KeySet
issuer string
checkScopes bool
service string
keySet oidc.KeySet
additionalChecks []AdditionalCheck
}

func NewJWTAuth(
keySet oidc.KeySet,
issuer string,
service string,
checkScopes bool,
additionalChecks []AdditionalCheck,
) *JWTAuth {
return &JWTAuth{
issuer: issuer,
checkScopes: checkScopes,
service: service,
keySet: keySet,
issuer: issuer,
checkScopes: checkScopes,
service: service,
keySet: keySet,
additionalChecks: additionalChecks,
}
}

// Authenticate validates the JWT in the request and returns the user, if valid.
func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) {

claims, err := ClaimsFromRequest(r, ja.issuer, ja.keySet)
if err != nil {
return false, err
}

if ja.checkScopes {
scope := claims.Scopes

allowed := true //nolint:ineffassign
switch r.Method {
case http.MethodOptions, http.MethodGet, http.MethodHead, http.MethodTrace:
allowed = collectionutils.Contains(scope, ja.service+":read") ||
collectionutils.Contains(scope, ja.service+":write")
default:
allowed = collectionutils.Contains(scope, ja.service+":write")
}

if !allowed {
return false, fmt.Errorf("missing access, found scopes: '%s' need %s:read|write", strings.Join(scope, ", "), ja.service)
for _, check := range ja.additionalChecks {
err := check(r, claims)
if err != nil {
return false, err
}
}

return true, nil
if !ja.checkScopes {
return true, nil
}
return checkScopes(ja.service, r.Method, claims.Scopes)
}

var (
Expand All @@ -65,32 +58,43 @@ var (
)

func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySet) (*oidc.AccessTokenClaims, error) {
claims := &oidc.AccessTokenClaims{}
if err := claimsFromRequest(r, claims, keySet); err != nil {
return claims, err
}

if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil {
return claims, err
}

if err := oidc.CheckExpiration(claims, 0); err != nil {
return claims, err
}

return claims, nil
}
Comment thread
laouji marked this conversation as resolved.

func claimsFromRequest[CLAIMS any](r *http.Request, claims CLAIMS, keySet oidc.KeySet) error {
authHeader := r.Header.Get("authorization")
if authHeader == "" {
return nil, ErrNoAuthorizationHeader
return ErrNoAuthorizationHeader
}

if !strings.HasPrefix(authHeader, "bearer") &&
!strings.HasPrefix(authHeader, "Bearer") {
return nil, ErrMalformedHeader
return ErrMalformedHeader
}

token := authHeader[6:]
token = strings.TrimSpace(token)

claims := &oidc.AccessTokenClaims{}
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
return err
}
payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil {
return nil, err
}

if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil {
return claims, err
return err
}

if _, err = oidc.CheckSignature(
Expand All @@ -100,12 +104,8 @@ func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySe
[]string{}, // Default to RS256
keySet,
); err != nil {
return claims, err
return err
}

if err = oidc.CheckExpiration(claims, 0); err != nil {
return claims, err
}

return claims, nil
return nil
}
Loading
Loading