11package conf
22
33import (
4+ "context"
45 "encoding/json"
56 "fmt"
67 "slices"
8+ "strings"
9+ "sync"
710
11+ "github.com/aws/aws-sdk-go-v2/config"
12+ "github.com/aws/aws-sdk-go-v2/service/kms"
813 "github.com/golang-jwt/jwt/v5"
914 "github.com/lestrrat-go/jwx/v2/jwk"
15+ "github.com/supabase/auth/internal/conf/awskmsjwk"
1016)
1117
1218type JwtKeysDecoder map [string ]JwkInfo
@@ -68,6 +74,12 @@ func (j *JwtKeysDecoder) decodePublicKey(
6874 return err
6975 }
7076
77+ if _ , isKMS := privJwk .Get ("aws:kms:arn" ); isKMS {
78+ if err := pubJwk .Remove ("aws:kms:arn" ); err != nil {
79+ return err
80+ }
81+ }
82+
7183 config [pubJwk .KeyID ()] = JwkInfo {
7284 PublicKey : pubJwk ,
7385 PrivateKey : privJwk ,
@@ -123,11 +135,73 @@ func GetSigningJwk(config *JWTConfiguration) (jwk.Key, error) {
123135 return nil , fmt .Errorf ("no signing key found" )
124136}
125137
126- func GetSigningKey (k jwk.Key ) (any , error ) {
138+ var kmsClients sync.Map // map[string]func() (*kms.Client, error)
139+
140+ func getKMSClient (region string ) (* kms.Client , error ) {
141+ fn , _ := kmsClients .LoadOrStore (region , sync .OnceValues (func () (* kms.Client , error ) {
142+ cfg , err := config .LoadDefaultConfig (
143+ context .Background (),
144+ config .WithRegion (region ),
145+ )
146+ if err != nil {
147+ return nil , err
148+ }
149+
150+ return kms .NewFromConfig (cfg ), nil
151+ }))
152+
153+ return fn .(func () (* kms.Client , error ))()
154+ }
155+
156+ func GetSigningKey (ctx context.Context , k jwk.Key ) (any , error ) {
157+ if value , isKMS := k .Get ("aws:kms:arn" ); isKMS && k .KeyOps ()[0 ] == jwk .KeyOpSign {
158+ kmsARN , ok := value .(string )
159+ if ! ok {
160+ return nil , fmt .Errorf ("conf: jwk key has aws:kms:arn but is not a string %v" , value )
161+ }
162+
163+ parts := strings .SplitN (kmsARN , ":" , 5 )
164+ kmsClient , err := getKMSClient (parts [3 ])
165+ if err != nil {
166+ return nil , err
167+ }
168+
169+ pub , err := k .PublicKey ()
170+ if err != nil {
171+ return nil , err
172+ }
173+
174+ if err := pub .Set (jwk .KeyUsageKey , "sig" ); err != nil {
175+ return nil , err
176+ }
177+
178+ if err := pub .Set (jwk .KeyOpsKey , jwk.KeyOperationList {jwk .KeyOpVerify }); err != nil {
179+ return nil , err
180+ }
181+
182+ if err := pub .Remove ("aws:kms:arn" ); err != nil {
183+ return nil , err
184+ }
185+
186+ var raw any
187+ if err := pub .Raw (& raw ); err != nil {
188+ return nil , err
189+ }
190+
191+ return & awskmsjwk.RS256Key {
192+ Ctx : ctx ,
193+ KMS : kmsClient ,
194+
195+ KeyID : kmsARN ,
196+ Raw : raw ,
197+ }, nil
198+ }
199+
127200 var key any
128201 if err := k .Raw (& key ); err != nil {
129202 return nil , err
130203 }
204+
131205 return key , nil
132206}
133207
@@ -138,6 +212,10 @@ func GetSigningAlg(k jwk.Key) jwt.SigningMethod {
138212
139213 switch (k ).Algorithm ().String () {
140214 case "RS256" :
215+ if _ , isKMS := k .Get ("aws:kms:arn" ); isKMS {
216+ return awskmsjwk .SigningMethodRS256KMS
217+ }
218+
141219 return jwt .SigningMethodRS256
142220 case "RS512" :
143221 return jwt .SigningMethodRS512
@@ -153,9 +231,9 @@ func GetSigningAlg(k jwk.Key) jwt.SigningMethod {
153231 return jwt .SigningMethodHS256
154232}
155233
156- func FindPublicKeyByKid (kid string , config * JWTConfiguration ) (any , error ) {
234+ func FindPublicKeyByKid (ctx context. Context , kid string , config * JWTConfiguration ) (any , error ) {
157235 if k , ok := config .Keys [kid ]; ok {
158- key , err := GetSigningKey (k .PublicKey )
236+ key , err := GetSigningKey (ctx , k .PublicKey )
159237 if err != nil {
160238 return nil , err
161239 }
0 commit comments