Skip to content

Commit 635a206

Browse files
committed
feat: add RS256 signing keys backed by AWS KMS
1 parent 169ad67 commit 635a206

10 files changed

Lines changed: 284 additions & 11 deletions

File tree

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ RUN apk add --no-cache make git
88
WORKDIR /go/src/github.com/supabase/auth
99

1010
# Pulling dependencies
11-
COPY ./Makefile ./go.* ./
11+
COPY ./Makefile ./go.* ./.git ./
1212
COPY ./internal/forks/godotenv ./internal/forks/godotenv
1313
RUN make deps
1414

go.mod

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ require (
3535

3636
require (
3737
github.com/ProjectZKM/Ziren/crates/go-runtime/zkvm_runtime v0.0.0-20251001021608-1fe7b43fc4d6 // indirect
38+
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
39+
github.com/aws/aws-sdk-go-v2/config v1.32.18 // indirect
40+
github.com/aws/aws-sdk-go-v2/credentials v1.19.17 // indirect
41+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
42+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
43+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
44+
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
45+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
46+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
47+
github.com/aws/aws-sdk-go-v2/service/kms v1.52.0 // indirect
48+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
49+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
50+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0 // indirect
51+
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
52+
github.com/aws/smithy-go v1.25.1 // indirect
3853
github.com/bits-and-blooms/bitset v1.20.0 // indirect
3954
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
4055
github.com/consensys/gnark-crypto v0.18.1 // indirect

go.sum

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@ github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyR
1717
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM=
1818
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
1919
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
20+
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
21+
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
22+
github.com/aws/aws-sdk-go-v2/config v1.32.18 h1:Hcia46bxhGgF3BaSnG8nSNCWmqTK6bj9xN9/FJ3WK6Q=
23+
github.com/aws/aws-sdk-go-v2/config v1.32.18/go.mod h1:zEjCAYmxqDadH1WX8CdBvmLKhUEUVFgKRQG38zjDmrY=
24+
github.com/aws/aws-sdk-go-v2/credentials v1.19.17 h1:gP2nkGsS+KMvF/jfFz2Vv2qiiOqWKyPACSzPsqHgoW8=
25+
github.com/aws/aws-sdk-go-v2/credentials v1.19.17/go.mod h1:Bsew3S/moG5iT77giPj1q8wb/s0RE5/QfH+ASjYtuQc=
26+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
27+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
28+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
29+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
30+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
31+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
32+
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
33+
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
34+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
35+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
36+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
37+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
38+
github.com/aws/aws-sdk-go-v2/service/kms v1.52.0 h1:QNtg+Mtj1zmepk568+UKBD5DFfqh+ESTUUqQT27JkQc=
39+
github.com/aws/aws-sdk-go-v2/service/kms v1.52.0/go.mod h1:Y0+uxvxz6ib4KktRdK0V4X45Vcs/JyYoz8H71pO8xeI=
40+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
41+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
42+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
43+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
44+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0 h1:nDARhv/oF55bcxF7rCI/4PDxOKnVXVWwDuDwCs2I2SQ=
45+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
46+
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
47+
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
48+
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
49+
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
2050
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
2151
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
2252
github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 h1:vMqcPzLT1/mbYew0gM6EJy4/sCNy9lY9rmlFO+pPwhY=

hack/kms-rsa-to-jwk.js

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#!/usr/bin/env node
2+
3+
import { execFileSync } from "node:child_process";
4+
import { webcrypto } from "node:crypto";
5+
6+
const keyId = process.argv[2];
7+
const compact = process.argv[3] === '--compact';
8+
9+
if (!keyId) {
10+
console.error("Usage: kms-rsa-to-jwk.js <key-arn> [--compact]");
11+
process.exit(1);
12+
}
13+
14+
// arn:partition:kms:region:account:key/uuid
15+
const arnParts = keyId.split(":");
16+
if (arnParts.length < 6 || arnParts[2] !== "kms") {
17+
throw new Error(`Invalid KMS ARN: ${keyId}`);
18+
}
19+
20+
const region = arnParts[3];
21+
22+
const publicKeyB64 = execFileSync(
23+
"aws",
24+
[
25+
"kms",
26+
"get-public-key",
27+
"--key-id",
28+
keyId,
29+
"--query",
30+
"PublicKey",
31+
"--output",
32+
"text",
33+
"--region",
34+
region,
35+
],
36+
{ encoding: "utf8" },
37+
).trim();
38+
39+
const spki = Buffer.from(publicKeyB64, "base64");
40+
41+
const key = await webcrypto.subtle.importKey(
42+
"spki",
43+
spki,
44+
{
45+
name: "RSASSA-PKCS1-v1_5",
46+
hash: "SHA-256",
47+
},
48+
true,
49+
["verify"],
50+
);
51+
52+
const jwk = await webcrypto.subtle.exportKey("jwk", key);
53+
54+
console.log(
55+
JSON.stringify(
56+
{
57+
...jwk,
58+
ext: undefined,
59+
use: "sig",
60+
key_ops: ['sign', 'verify'],
61+
'aws:kms:arn': keyId,
62+
63+
//kty: jwk.kty,
64+
//use: "sig",
65+
//alg: "RS256",
66+
//kid: keyId,
67+
//n: jwk.n,
68+
//e: jwk.e,
69+
},
70+
null,
71+
compact ? 0 : 2,
72+
),
73+
);

internal/api/auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e
8282
token, err := p.ParseWithClaims(bearer, &AccessTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
8383
if kid, ok := token.Header["kid"]; ok {
8484
if kidStr, ok := kid.(string); ok {
85-
key, err := conf.FindPublicKeyByKid(kidStr, &config.JWT)
85+
key, err := conf.FindPublicKeyByKid(ctx, kidStr, &config.JWT)
8686
if err != nil {
8787
return nil, err
8888
}

internal/api/oauthserver/handlers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon
435435
nonce = *authorization.Nonce
436436
}
437437

438-
idToken, err := tokenService.GenerateIDToken(tokens.GenerateIDTokenParams{
438+
idToken, err := tokenService.GenerateIDToken(ctx, tokens.GenerateIDTokenParams{
439439
User: user,
440440
ClientID: client.ID,
441441
Nonce: nonce,

internal/conf/awskmsjwk/kms.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package awskmsjwk
2+
3+
import (
4+
"context"
5+
6+
"github.com/aws/aws-sdk-go-v2/service/kms"
7+
)
8+
9+
type KMSAPI interface {
10+
Sign(ctx context.Context, in *kms.SignInput, optFns ...func(*kms.Options)) (*kms.SignOutput, error)
11+
}

internal/conf/awskmsjwk/rs256.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package awskmsjwk
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"errors"
7+
8+
"github.com/aws/aws-sdk-go-v2/service/kms"
9+
kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types"
10+
"github.com/golang-jwt/jwt/v5"
11+
"github.com/sirupsen/logrus"
12+
)
13+
14+
var ErrNotRS256Key = errors.New("awskmsjwk: key needs to be *RS256Key")
15+
16+
type RS256Key struct {
17+
Ctx context.Context
18+
KMS KMSAPI
19+
20+
KeyID string
21+
Raw any
22+
}
23+
24+
type signingMethodKMSRS256 struct{}
25+
26+
var SigningMethodRS256KMS jwt.SigningMethod = &signingMethodKMSRS256{}
27+
28+
func (m *signingMethodKMSRS256) Alg() string {
29+
return jwt.SigningMethodRS256.Alg() // "RS256"
30+
}
31+
32+
func (m *signingMethodKMSRS256) Sign(signingString string, key any) ([]byte, error) {
33+
k, ok := key.(*RS256Key)
34+
if !ok {
35+
return nil, ErrNotRS256Key
36+
}
37+
38+
// JWT RS256 signs SHA256(base64url(header) + "." + base64url(payload)).
39+
// Use DIGEST so large JWTs do not hit KMS RAW message size limits.
40+
digest := sha256.Sum256([]byte(signingString))
41+
42+
out, err := k.KMS.Sign(k.Ctx, &kms.SignInput{
43+
KeyId: &k.KeyID,
44+
Message: digest[:],
45+
MessageType: kmstypes.MessageTypeDigest,
46+
SigningAlgorithm: kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha256,
47+
})
48+
if err != nil {
49+
logrus.WithError(err).Error("Unable to sign RS256 JWT with AWS KMS key %q", k.KeyID)
50+
51+
return nil, err
52+
}
53+
54+
return out.Signature, nil
55+
}
56+
57+
func (m *signingMethodKMSRS256) Verify(signingString string, sig []byte, key any) error {
58+
k, ok := key.(*RS256Key)
59+
if !ok {
60+
return ErrNotRS256Key
61+
}
62+
63+
return jwt.SigningMethodRS256.Verify(signingString, sig, k.Raw)
64+
}

internal/conf/jwk.go

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
package conf
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"slices"
8+
"strings"
9+
"sync"
710

11+
"github.com/aws/aws-sdk-go-v2/config"
12+
"github.com/aws/aws-sdk-go-v2/service/kms"
813
"github.com/golang-jwt/jwt/v5"
914
"github.com/lestrrat-go/jwx/v2/jwk"
15+
"github.com/supabase/auth/internal/conf/awskmsjwk"
1016
)
1117

1218
type JwtKeysDecoder map[string]JwkInfo
@@ -68,6 +74,12 @@ func (j *JwtKeysDecoder) decodePublicKey(
6874
return err
6975
}
7076

77+
if _, isKMS := privJwk.Get("aws:kms:arn"); isKMS {
78+
if err := pubJwk.Remove("aws:kms:arn"); err != nil {
79+
return err
80+
}
81+
}
82+
7183
config[pubJwk.KeyID()] = JwkInfo{
7284
PublicKey: pubJwk,
7385
PrivateKey: privJwk,
@@ -123,11 +135,73 @@ func GetSigningJwk(config *JWTConfiguration) (jwk.Key, error) {
123135
return nil, fmt.Errorf("no signing key found")
124136
}
125137

126-
func GetSigningKey(k jwk.Key) (any, error) {
138+
var kmsClients sync.Map // map[string]func() (*kms.Client, error)
139+
140+
func getKMSClient(region string) (*kms.Client, error) {
141+
fn, _ := kmsClients.LoadOrStore(region, sync.OnceValues(func() (*kms.Client, error) {
142+
cfg, err := config.LoadDefaultConfig(
143+
context.Background(),
144+
config.WithRegion(region),
145+
)
146+
if err != nil {
147+
return nil, err
148+
}
149+
150+
return kms.NewFromConfig(cfg), nil
151+
}))
152+
153+
return fn.(func() (*kms.Client, error))()
154+
}
155+
156+
func GetSigningKey(ctx context.Context, k jwk.Key) (any, error) {
157+
if value, isKMS := k.Get("aws:kms:arn"); isKMS && k.KeyOps()[0] == jwk.KeyOpSign {
158+
kmsARN, ok := value.(string)
159+
if !ok {
160+
return nil, fmt.Errorf("conf: jwk key has aws:kms:arn but is not a string %v", value)
161+
}
162+
163+
parts := strings.SplitN(kmsARN, ":", 5)
164+
kmsClient, err := getKMSClient(parts[3])
165+
if err != nil {
166+
return nil, err
167+
}
168+
169+
pub, err := k.PublicKey()
170+
if err != nil {
171+
return nil, err
172+
}
173+
174+
if err := pub.Set(jwk.KeyUsageKey, "sig"); err != nil {
175+
return nil, err
176+
}
177+
178+
if err := pub.Set(jwk.KeyOpsKey, jwk.KeyOperationList{jwk.KeyOpVerify}); err != nil {
179+
return nil, err
180+
}
181+
182+
if err := pub.Remove("aws:kms:arn"); err != nil {
183+
return nil, err
184+
}
185+
186+
var raw any
187+
if err := pub.Raw(&raw); err != nil {
188+
return nil, err
189+
}
190+
191+
return &awskmsjwk.RS256Key{
192+
Ctx: ctx,
193+
KMS: kmsClient,
194+
195+
KeyID: kmsARN,
196+
Raw: raw,
197+
}, nil
198+
}
199+
127200
var key any
128201
if err := k.Raw(&key); err != nil {
129202
return nil, err
130203
}
204+
131205
return key, nil
132206
}
133207

@@ -138,6 +212,10 @@ func GetSigningAlg(k jwk.Key) jwt.SigningMethod {
138212

139213
switch (k).Algorithm().String() {
140214
case "RS256":
215+
if _, isKMS := k.Get("aws:kms:arn"); isKMS {
216+
return awskmsjwk.SigningMethodRS256KMS
217+
}
218+
141219
return jwt.SigningMethodRS256
142220
case "RS512":
143221
return jwt.SigningMethodRS512
@@ -153,9 +231,9 @@ func GetSigningAlg(k jwk.Key) jwt.SigningMethod {
153231
return jwt.SigningMethodHS256
154232
}
155233

156-
func FindPublicKeyByKid(kid string, config *JWTConfiguration) (any, error) {
234+
func FindPublicKeyByKid(ctx context.Context, kid string, config *JWTConfiguration) (any, error) {
157235
if k, ok := config.Keys[kid]; ok {
158-
key, err := GetSigningKey(k.PublicKey)
236+
key, err := GetSigningKey(ctx, k.PublicKey)
159237
if err != nil {
160238
return nil, err
161239
}

0 commit comments

Comments
 (0)