Skip to content

Commit 983f59a

Browse files
hfChris Stockton
andauthored
feat: add RS256 signing keys backed by AWS KMS (#2571)
Adds support for RSA signing keys backed by AWS KMS, which are the cheapest type of key. You specify `aws:kms:arn` as a claim in the private key's JWK and it all flows from there. It uses the ambient credentials of the process to talk to KMS. --------- Co-authored-by: Chris Stockton <chris.stockton@supabase.io>
1 parent 5b95ff8 commit 983f59a

17 files changed

Lines changed: 339 additions & 78 deletions

File tree

go.mod

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ require (
3535

3636
require (
3737
github.com/ProjectZKM/Ziren/crates/go-runtime/zkvm_runtime v0.0.0-20251001021608-1fe7b43fc4d6 // indirect
38+
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
39+
github.com/aws/aws-sdk-go-v2/config v1.32.18 // indirect
40+
github.com/aws/aws-sdk-go-v2/credentials v1.19.17 // indirect
41+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
42+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
43+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
44+
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
45+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
46+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
47+
github.com/aws/aws-sdk-go-v2/service/kms v1.52.0 // indirect
48+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
49+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
50+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0 // indirect
51+
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
52+
github.com/aws/smithy-go v1.25.1 // indirect
3853
github.com/bits-and-blooms/bitset v1.20.0 // indirect
3954
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
4055
github.com/consensys/gnark-crypto v0.18.1 // indirect

go.sum

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@ github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyR
1717
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM=
1818
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
1919
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
20+
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
21+
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
22+
github.com/aws/aws-sdk-go-v2/config v1.32.18 h1:Hcia46bxhGgF3BaSnG8nSNCWmqTK6bj9xN9/FJ3WK6Q=
23+
github.com/aws/aws-sdk-go-v2/config v1.32.18/go.mod h1:zEjCAYmxqDadH1WX8CdBvmLKhUEUVFgKRQG38zjDmrY=
24+
github.com/aws/aws-sdk-go-v2/credentials v1.19.17 h1:gP2nkGsS+KMvF/jfFz2Vv2qiiOqWKyPACSzPsqHgoW8=
25+
github.com/aws/aws-sdk-go-v2/credentials v1.19.17/go.mod h1:Bsew3S/moG5iT77giPj1q8wb/s0RE5/QfH+ASjYtuQc=
26+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
27+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
28+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
29+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
30+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
31+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
32+
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
33+
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
34+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
35+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
36+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
37+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
38+
github.com/aws/aws-sdk-go-v2/service/kms v1.52.0 h1:QNtg+Mtj1zmepk568+UKBD5DFfqh+ESTUUqQT27JkQc=
39+
github.com/aws/aws-sdk-go-v2/service/kms v1.52.0/go.mod h1:Y0+uxvxz6ib4KktRdK0V4X45Vcs/JyYoz8H71pO8xeI=
40+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
41+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
42+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
43+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
44+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0 h1:nDARhv/oF55bcxF7rCI/4PDxOKnVXVWwDuDwCs2I2SQ=
45+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
46+
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
47+
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
48+
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
49+
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
2050
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
2151
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
2252
github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 h1:vMqcPzLT1/mbYew0gM6EJy4/sCNy9lY9rmlFO+pPwhY=

hack/kms-rsa-to-jwk.js

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#!/usr/bin/env node
2+
3+
import { execFileSync } from "node:child_process";
4+
import { webcrypto } from "node:crypto";
5+
6+
const keyId = process.argv[2];
7+
const compact = process.argv[3] === '--compact';
8+
9+
if (!keyId) {
10+
console.error("Usage: kms-rsa-to-jwk.js <key-arn> [--compact]");
11+
process.exit(1);
12+
}
13+
14+
// arn:partition:kms:region:account:key/uuid
15+
const arnParts = keyId.split(":");
16+
if (arnParts.length < 6 || arnParts[2] !== "kms") {
17+
throw new Error(`Invalid KMS ARN: ${keyId}`);
18+
}
19+
20+
const region = arnParts[3];
21+
22+
const publicKeyB64 = execFileSync(
23+
"aws",
24+
[
25+
"kms",
26+
"get-public-key",
27+
"--key-id",
28+
keyId,
29+
"--query",
30+
"PublicKey",
31+
"--output",
32+
"text",
33+
"--region",
34+
region,
35+
],
36+
{ encoding: "utf8" },
37+
).trim();
38+
39+
const spki = Buffer.from(publicKeyB64, "base64");
40+
41+
const key = await webcrypto.subtle.importKey(
42+
"spki",
43+
spki,
44+
{
45+
name: "RSASSA-PKCS1-v1_5",
46+
hash: "SHA-256",
47+
},
48+
true,
49+
["verify"],
50+
);
51+
52+
const jwk = await webcrypto.subtle.exportKey("jwk", key);
53+
54+
console.log(
55+
JSON.stringify(
56+
{
57+
...jwk,
58+
ext: undefined,
59+
use: "sig",
60+
key_ops: ['sign', 'verify'],
61+
'aws:kms:arn': keyId,
62+
63+
//kty: jwk.kty,
64+
//use: "sig",
65+
//alg: "RS256",
66+
//kid: keyId,
67+
//n: jwk.n,
68+
//e: jwk.e,
69+
},
70+
null,
71+
compact ? 0 : 2,
72+
),
73+
);

internal/api/auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e
8282
token, err := p.ParseWithClaims(bearer, &AccessTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
8383
if kid, ok := token.Header["kid"]; ok {
8484
if kidStr, ok := kid.(string); ok {
85-
key, err := conf.FindPublicKeyByKid(kidStr, &config.JWT)
85+
key, err := conf.FindPublicKeyByKid(ctx, kidStr, &config.JWT)
8686
if err != nil {
8787
return nil, err
8888
}

internal/api/auth_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"context"
45
"encoding/json"
56
"net/http"
67
"net/http/httptest"
@@ -167,7 +168,7 @@ func (ts *AuthTestSuite) TestParseJWTClaims() {
167168
jwk, err := conf.GetSigningJwk(&ts.Config.JWT)
168169
require.NoError(ts.T(), err)
169170
signingMethod := conf.GetSigningAlg(jwk)
170-
signingKey, err := conf.GetSigningKey(jwk)
171+
signingKey, err := ts.Config.JWT.SigningKey(context.Background())
171172
require.NoError(ts.T(), err)
172173

173174
userJwtToken := jwt.NewWithClaims(signingMethod, userClaims)

internal/api/e2e_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ func TestE2EHooks(t *testing.T) {
980980
) (any, error) {
981981
if kid, ok := token.Header["kid"]; ok {
982982
if kidStr, ok := kid.(string); ok {
983-
return conf.FindPublicKeyByKid(kidStr, &globalCfg.JWT)
983+
return conf.FindPublicKeyByKid(context.Background(), kidStr, &globalCfg.JWT)
984984
}
985985
}
986986
if alg, ok := token.Header["alg"]; ok {
@@ -1055,7 +1055,7 @@ func TestE2EHooks(t *testing.T) {
10551055
func(token *jwt.Token) (any, error) {
10561056
if kid, ok := token.Header["kid"]; ok {
10571057
if kidStr, ok := kid.(string); ok {
1058-
return conf.FindPublicKeyByKid(kidStr, &globalCfg.JWT)
1058+
return conf.FindPublicKeyByKid(context.Background(), kidStr, &globalCfg.JWT)
10591059
}
10601060
}
10611061
if alg, ok := token.Header["alg"]; ok {

internal/api/oauthserver/handlers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon
435435
nonce = *authorization.Nonce
436436
}
437437

438-
idToken, err := tokenService.GenerateIDToken(tokens.GenerateIDTokenParams{
438+
idToken, err := tokenService.GenerateIDToken(ctx, tokens.GenerateIDTokenParams{
439439
User: user,
440440
ClientID: client.ID,
441441
Nonce: nonce,

internal/conf/awskmsjwk/kms.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package awskmsjwk
2+
3+
import (
4+
"context"
5+
6+
"github.com/aws/aws-sdk-go-v2/service/kms"
7+
)
8+
9+
type KMSAPI interface {
10+
Sign(ctx context.Context, in *kms.SignInput, optFns ...func(*kms.Options)) (*kms.SignOutput, error)
11+
}

internal/conf/awskmsjwk/rs256.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package awskmsjwk
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"errors"
7+
8+
"github.com/aws/aws-sdk-go-v2/service/kms"
9+
kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types"
10+
"github.com/golang-jwt/jwt/v5"
11+
"github.com/sirupsen/logrus"
12+
)
13+
14+
var ErrNotRS256Key = errors.New("awskmsjwk: key needs to be *RS256Key")
15+
16+
type RS256Key struct {
17+
Ctx context.Context
18+
KMS KMSAPI
19+
20+
KeyID string
21+
Raw any
22+
}
23+
24+
type signingMethodKMSRS256 struct{}
25+
26+
var SigningMethodRS256KMS jwt.SigningMethod = &signingMethodKMSRS256{}
27+
28+
func (m *signingMethodKMSRS256) Alg() string {
29+
return jwt.SigningMethodRS256.Alg() // "RS256"
30+
}
31+
32+
func (m *signingMethodKMSRS256) Sign(signingString string, key any) ([]byte, error) {
33+
k, ok := key.(*RS256Key)
34+
if !ok {
35+
return nil, ErrNotRS256Key
36+
}
37+
38+
// JWT RS256 signs SHA256(base64url(header) + "." + base64url(payload)).
39+
// Use DIGEST so large JWTs do not hit KMS RAW message size limits.
40+
digest := sha256.Sum256([]byte(signingString))
41+
42+
out, err := k.KMS.Sign(k.Ctx, &kms.SignInput{
43+
KeyId: &k.KeyID,
44+
Message: digest[:],
45+
MessageType: kmstypes.MessageTypeDigest,
46+
SigningAlgorithm: kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha256,
47+
})
48+
if err != nil {
49+
logrus.WithError(err).Errorf("Unable to sign RS256 JWT with AWS KMS key %q", k.KeyID)
50+
51+
return nil, err
52+
}
53+
54+
return out.Signature, nil
55+
}
56+
57+
func (m *signingMethodKMSRS256) Verify(signingString string, sig []byte, key any) error {
58+
k, ok := key.(*RS256Key)
59+
if !ok {
60+
return ErrNotRS256Key
61+
}
62+
63+
return jwt.SigningMethodRS256.Verify(signingString, sig, k.Raw)
64+
}

internal/conf/configuration.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package conf
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/base64"
67
"encoding/json"
78
"errors"
@@ -141,6 +142,8 @@ type JWTConfiguration struct {
141142
KeyID string `json:"key_id" split_words:"true"`
142143
Keys JwtKeysDecoder `json:"keys"`
143144
ValidMethods []string `json:"-" split_words:"true"`
145+
146+
SigningKey func(context.Context) (any, error) `json:"-"`
144147
}
145148

146149
type MFAFactorTypeConfiguration struct {
@@ -1043,13 +1046,22 @@ func (config *GlobalConfiguration) ApplyDefaults() error {
10431046
if err := config.applyDefaultsJWT([]byte(config.JWT.Secret)); err != nil {
10441047
return err
10451048
}
1049+
} else {
1050+
jwk, err := GetSigningJwk(&config.JWT)
1051+
if err != nil {
1052+
return err
1053+
}
1054+
sk, err := getSigningKey(jwk)
1055+
if err != nil {
1056+
return err
1057+
}
1058+
config.JWT.SigningKey = sk
10461059
}
10471060

10481061
if config.JWT.ValidMethods == nil {
10491062
config.JWT.ValidMethods = []string{}
10501063
for _, key := range config.JWT.Keys {
1051-
alg := GetSigningAlg(key.PublicKey)
1052-
config.JWT.ValidMethods = append(config.JWT.ValidMethods, alg.Alg())
1064+
config.JWT.ValidMethods = append(config.JWT.ValidMethods, key.PublicKey.Algorithm().String())
10531065
}
10541066

10551067
}
@@ -1193,6 +1205,16 @@ func (config *GlobalConfiguration) applyDefaultsJWTPrivateKey(privKey jwk.Key) e
11931205
PublicKey: pubKey,
11941206
PrivateKey: privKey,
11951207
}
1208+
1209+
var key any
1210+
if err := privKey.Raw(&key); err != nil {
1211+
return err
1212+
}
1213+
1214+
config.JWT.SigningKey = func(ctx context.Context) (any, error) {
1215+
return key, nil
1216+
}
1217+
11961218
return nil
11971219
}
11981220

0 commit comments

Comments
 (0)