Skip to content
59 changes: 24 additions & 35 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package auth

import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/formancehq/go-libs/v3/collectionutils"
"github.com/formancehq/go-libs/v3/oidc"
)

Expand All @@ -33,30 +31,15 @@ func NewJWTAuth(

// Authenticate validates the JWT in the request and returns the user, if valid.
func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) {

claims, err := ClaimsFromRequest(r, ja.issuer, ja.keySet)
if err != nil {
return false, err
}

if ja.checkScopes {
scope := claims.Scopes

allowed := true //nolint:ineffassign
switch r.Method {
case http.MethodOptions, http.MethodGet, http.MethodHead, http.MethodTrace:
allowed = collectionutils.Contains(scope, ja.service+":read") ||
collectionutils.Contains(scope, ja.service+":write")
default:
allowed = collectionutils.Contains(scope, ja.service+":write")
}

if !allowed {
return false, fmt.Errorf("missing access, found scopes: '%s' need %s:read|write", strings.Join(scope, ", "), ja.service)
}
if !ja.checkScopes {
return true, nil
}

return true, nil
return checkScopes(ja.service, r.Method, claims.Scopes)
}

var (
Expand All @@ -65,32 +48,42 @@ var (
)

func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySet) (*oidc.AccessTokenClaims, error) {
claims := &oidc.AccessTokenClaims{}
if err := claimsFromRequest(r, claims, keySet); err != nil {
return claims, err
}

if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil {
return claims, err
}

if err := oidc.CheckExpiration(claims, 0); err != nil {
return claims, err
}
return claims, nil
}
Comment thread
laouji marked this conversation as resolved.

func claimsFromRequest[CLAIMS any](r *http.Request, claims CLAIMS, keySet oidc.KeySet) error {
authHeader := r.Header.Get("authorization")
if authHeader == "" {
return nil, ErrNoAuthorizationHeader
return ErrNoAuthorizationHeader
}

if !strings.HasPrefix(authHeader, "bearer") &&
!strings.HasPrefix(authHeader, "Bearer") {
return nil, ErrMalformedHeader
return ErrMalformedHeader
}

token := authHeader[6:]
token = strings.TrimSpace(token)

claims := &oidc.AccessTokenClaims{}
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
return err
}
payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil {
return nil, err
}

if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil {
return claims, err
return err
}

if _, err = oidc.CheckSignature(
Expand All @@ -100,12 +93,8 @@ func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySe
[]string{}, // Default to RS256
keySet,
); err != nil {
return claims, err
}

if err = oidc.CheckExpiration(claims, 0); err != nil {
return claims, err
return err
}

return claims, nil
return nil
}
Loading
Loading