Skip to content

Commit 55df62d

Browse files
committed
adjust
1 parent 635a206 commit 55df62d

3 files changed

Lines changed: 37 additions & 32 deletions

File tree

internal/conf/configuration.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package conf
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/base64"
67
"encoding/json"
78
"errors"
@@ -141,6 +142,8 @@ type JWTConfiguration struct {
141142
KeyID string `json:"key_id" split_words:"true"`
142143
Keys JwtKeysDecoder `json:"keys"`
143144
ValidMethods []string `json:"-" split_words:"true"`
145+
146+
SigningKey func(context.Context) (any, error) `json:"-"`
144147
}
145148

146149
type MFAFactorTypeConfiguration struct {
@@ -1012,13 +1015,22 @@ func (config *GlobalConfiguration) ApplyDefaults() error {
10121015
if err := config.applyDefaultsJWT([]byte(config.JWT.Secret)); err != nil {
10131016
return err
10141017
}
1018+
} else {
1019+
jwk, err := GetSigningJwk(&config.JWT)
1020+
if err != nil {
1021+
return err
1022+
}
1023+
sk, err := getSigningKey(jwk)
1024+
if err != nil {
1025+
return err
1026+
}
1027+
config.JWT.SigningKey = sk
10151028
}
10161029

10171030
if config.JWT.ValidMethods == nil {
10181031
config.JWT.ValidMethods = []string{}
10191032
for _, key := range config.JWT.Keys {
1020-
alg := GetSigningAlg(key.PublicKey)
1021-
config.JWT.ValidMethods = append(config.JWT.ValidMethods, alg.Alg())
1033+
config.JWT.ValidMethods = append(config.JWT.ValidMethods, key.PublicKey.Algorithm().String())
10221034
}
10231035

10241036
}

internal/conf/jwk.go

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"fmt"
77
"slices"
88
"strings"
9-
"sync"
109

1110
"github.com/aws/aws-sdk-go-v2/config"
1211
"github.com/aws/aws-sdk-go-v2/service/kms"
@@ -135,37 +134,26 @@ func GetSigningJwk(config *JWTConfiguration) (jwk.Key, error) {
135134
return nil, fmt.Errorf("no signing key found")
136135
}
137136

138-
var kmsClients sync.Map // map[string]func() (*kms.Client, error)
139-
140-
func getKMSClient(region string) (*kms.Client, error) {
141-
fn, _ := kmsClients.LoadOrStore(region, sync.OnceValues(func() (*kms.Client, error) {
142-
cfg, err := config.LoadDefaultConfig(
143-
context.Background(),
144-
config.WithRegion(region),
145-
)
146-
if err != nil {
147-
return nil, err
148-
}
149-
150-
return kms.NewFromConfig(cfg), nil
151-
}))
152-
153-
return fn.(func() (*kms.Client, error))()
154-
}
155-
156-
func GetSigningKey(ctx context.Context, k jwk.Key) (any, error) {
137+
func getSigningKey(k jwk.Key) (func(context.Context) (any, error), error) {
157138
if value, isKMS := k.Get("aws:kms:arn"); isKMS && k.KeyOps()[0] == jwk.KeyOpSign {
158139
kmsARN, ok := value.(string)
159140
if !ok {
160141
return nil, fmt.Errorf("conf: jwk key has aws:kms:arn but is not a string %v", value)
161142
}
162143

163144
parts := strings.SplitN(kmsARN, ":", 5)
164-
kmsClient, err := getKMSClient(parts[3])
145+
region := parts[3]
146+
147+
cfg, err := config.LoadDefaultConfig(
148+
context.Background(),
149+
config.WithRegion(region),
150+
)
165151
if err != nil {
166152
return nil, err
167153
}
168154

155+
kmsClient := kms.NewFromConfig(cfg)
156+
169157
pub, err := k.PublicKey()
170158
if err != nil {
171159
return nil, err
@@ -188,21 +176,26 @@ func GetSigningKey(ctx context.Context, k jwk.Key) (any, error) {
188176
return nil, err
189177
}
190178

191-
return &awskmsjwk.RS256Key{
192-
Ctx: ctx,
193-
KMS: kmsClient,
179+
return func(ctx context.Context) (any, error) {
180+
return &awskmsjwk.RS256Key{
181+
Ctx: ctx,
182+
KMS: kmsClient,
194183

195-
KeyID: kmsARN,
196-
Raw: raw,
184+
KeyID: kmsARN,
185+
Raw: raw,
186+
}, nil
197187
}, nil
188+
198189
}
199190

200191
var key any
201192
if err := k.Raw(&key); err != nil {
202193
return nil, err
203194
}
204195

205-
return key, nil
196+
return func(ctx context.Context) (any, error) {
197+
return key, nil
198+
}, nil
206199
}
207200

208201
func GetSigningAlg(k jwk.Key) jwt.SigningMethod {
@@ -233,8 +226,8 @@ func GetSigningAlg(k jwk.Key) jwt.SigningMethod {
233226

234227
func FindPublicKeyByKid(ctx context.Context, kid string, config *JWTConfiguration) (any, error) {
235228
if k, ok := config.Keys[kid]; ok {
236-
key, err := GetSigningKey(ctx, k.PublicKey)
237-
if err != nil {
229+
var key any
230+
if err := k.PublicKey.Raw(&key); err != nil {
238231
return nil, err
239232
}
240233
return key, nil

internal/tokens/service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ func SignJWT(ctx context.Context, config *conf.JWTConfiguration, claims jwt.Clai
972972
}
973973
// this serializes the aud claim to a string
974974
jwt.MarshalSingleStringAsArray = false
975-
signingKey, err := conf.GetSigningKey(ctx, signingJwk)
975+
signingKey, err := config.SigningKey(ctx)
976976
if err != nil {
977977
return "", err
978978
}

0 commit comments

Comments
 (0)