Skip to content

Commit 0a6fdaf

Browse files
authored
Merge pull request #514 from smallstep/mariano/delete-key
Implement DeleteKey on tpmkms
2 parents 88c45f1 + a889b14 commit 0a6fdaf

File tree

4 files changed

+164
-13
lines changed

4 files changed

+164
-13
lines changed

keyutil/key.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ func generateECKey(crv string) (crypto.Signer, error) {
219219
}
220220

221221
func generateRSAKey(bits int) (crypto.Signer, error) {
222-
if min := MinRSAKeyBytes * 8; !insecureMode.isSet() && bits < min {
223-
return nil, errors.Errorf("the size of the RSA key should be at least %d bits", min)
222+
if minBits := MinRSAKeyBytes * 8; !insecureMode.isSet() && bits < minBits {
223+
return nil, errors.Errorf("the size of the RSA key should be at least %d bits", minBits)
224224
}
225225

226226
key, err := rsa.GenerateKey(rand.Reader, bits)

kms/tpmkms/tpmkms.go

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,35 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
433433
}, nil
434434
}
435435

436+
// DeleteKey deletes a key identified by name from the TPMKMS.
437+
//
438+
// # Experimental
439+
//
440+
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
441+
// release.
442+
func (k *TPMKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
443+
if req.Name == "" {
444+
return fmt.Errorf("deleteKeyRequest 'name' cannot be empty")
445+
}
446+
properties, err := parseNameURI(req.Name)
447+
if err != nil {
448+
return fmt.Errorf("failed parsing %q: %w", req.Name, err)
449+
}
450+
451+
ctx := context.Background()
452+
if properties.ak {
453+
if err := k.tpm.DeleteAK(ctx, properties.name); err != nil {
454+
return notFoundError(err)
455+
}
456+
} else {
457+
if err := k.tpm.DeleteKey(ctx, properties.name); err != nil {
458+
return notFoundError(err)
459+
}
460+
}
461+
462+
return nil
463+
}
464+
436465
// CreateSigner creates a signer using a key present in the TPM KMS.
437466
//
438467
// The `signingKey` in the [apiv1.CreateSignerRequest] can be used to specify
@@ -460,7 +489,7 @@ func (k *TPMKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er
460489
switch {
461490
case properties.name != "":
462491
ctx := context.Background()
463-
key, err := k.tpm.GetKey(ctx, properties.name)
492+
key, err := k.getKey(ctx, properties.name)
464493
if err != nil {
465494
return nil, err
466495
}
@@ -518,7 +547,7 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,
518547
switch {
519548
case properties.name != "":
520549
if properties.ak {
521-
ak, err := k.tpm.GetAK(ctx, properties.name)
550+
ak, err := k.getAK(ctx, properties.name)
522551
if err != nil {
523552
return nil, err
524553
}
@@ -529,7 +558,7 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,
529558
return akPub, nil
530559
}
531560

532-
key, err := k.tpm.GetKey(ctx, properties.name)
561+
key, err := k.getKey(ctx, properties.name)
533562
if err != nil {
534563
return nil, err
535564
}
@@ -598,13 +627,13 @@ func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([
598627
ctx := context.Background()
599628
var chain []*x509.Certificate
600629
if properties.ak {
601-
ak, err := k.tpm.GetAK(ctx, properties.name)
630+
ak, err := k.getAK(ctx, properties.name)
602631
if err != nil {
603632
return nil, err
604633
}
605634
chain = ak.CertificateChain()
606635
} else {
607-
key, err := k.tpm.GetKey(ctx, properties.name)
636+
key, err := k.getKey(ctx, properties.name)
608637
if err != nil {
609638
return nil, err
610639
}
@@ -741,7 +770,7 @@ func (k *TPMKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)
741770

742771
ctx := context.Background()
743772
if properties.ak {
744-
ak, err := k.tpm.GetAK(ctx, properties.name)
773+
ak, err := k.getAK(ctx, properties.name)
745774
if err != nil {
746775
return err
747776
}
@@ -750,7 +779,7 @@ func (k *TPMKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)
750779
return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err)
751780
}
752781
} else {
753-
key, err := k.tpm.GetKey(ctx, properties.name)
782+
key, err := k.getKey(ctx, properties.name)
754783
if err != nil {
755784
return err
756785
}
@@ -898,15 +927,15 @@ func (k *TPMKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
898927

899928
ctx := context.Background()
900929
if properties.ak {
901-
ak, err := k.tpm.GetAK(ctx, properties.name)
930+
ak, err := k.getAK(ctx, properties.name)
902931
if err != nil {
903932
return err
904933
}
905934
if err := ak.SetCertificateChain(ctx, nil); err != nil {
906935
return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err)
907936
}
908937
} else {
909-
key, err := k.tpm.GetKey(ctx, properties.name)
938+
key, err := k.getKey(ctx, properties.name)
910939
if err != nil {
911940
return err
912941
}
@@ -1060,7 +1089,7 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.
10601089
var key *tpm.Key
10611090
akName := properties.name
10621091
if !properties.ak {
1063-
key, err = k.tpm.GetKey(ctx, properties.name)
1092+
key, err = k.getKey(ctx, properties.name)
10641093
if err != nil {
10651094
return nil, err
10661095
}
@@ -1070,7 +1099,7 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.
10701099
akName = key.AttestedBy()
10711100
}
10721101

1073-
ak, err := k.tpm.GetAK(ctx, akName)
1102+
ak, err := k.getAK(ctx, akName)
10741103
if err != nil {
10751104
return nil, err
10761105
}
@@ -1227,6 +1256,31 @@ func (k *TPMKMS) hasValidIdentity(ak *tpm.AK, ekURL *url.URL) error {
12271256
return ErrIdentityCertificateInvalid
12281257
}
12291258

1259+
func (k *TPMKMS) getAK(ctx context.Context, name string) (*tpm.AK, error) {
1260+
ak, err := k.tpm.GetAK(ctx, name)
1261+
if err != nil {
1262+
return nil, notFoundError(err)
1263+
}
1264+
return ak, nil
1265+
}
1266+
1267+
func (k *TPMKMS) getKey(ctx context.Context, name string) (*tpm.Key, error) {
1268+
key, err := k.tpm.GetKey(ctx, name)
1269+
if err != nil {
1270+
return nil, notFoundError(err)
1271+
}
1272+
return key, nil
1273+
}
1274+
1275+
func notFoundError(err error) error {
1276+
if errors.Is(err, tpm.ErrNotFound) {
1277+
return apiv1.NotFoundError{
1278+
Message: err.Error(),
1279+
}
1280+
}
1281+
return err
1282+
}
1283+
12301284
// generateKeyID generates a key identifier from the
12311285
// SHA256 hash of the public key.
12321286
func generateKeyID(pub crypto.PublicKey) ([]byte, error) {

kms/tpmkms/tpmkms_simulator_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,73 @@ func TestTPMKMS_CreateKey(t *testing.T) {
433433
}
434434
}
435435

436+
func TestTPMKMS_DeleteKey(t *testing.T) {
437+
okTPM := newSimulatedTPM(t,
438+
withAK("ak1"), withAK("ak2"),
439+
withKey("key1"), withKey("key2"),
440+
)
441+
442+
validatePending := func(t *testing.T, k *TPMKMS) {
443+
_, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak2;ak=true"})
444+
assert.NoError(t, err)
445+
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key2"})
446+
assert.NoError(t, err)
447+
}
448+
449+
type fields struct {
450+
tpm *tpmp.TPM
451+
}
452+
type args struct {
453+
req *apiv1.DeleteKeyRequest
454+
}
455+
tests := []struct {
456+
name string
457+
fields fields
458+
args args
459+
assertion assert.ErrorAssertionFunc
460+
validate func(*testing.T, *TPMKMS)
461+
}{
462+
{"ok", fields{okTPM}, args{&apiv1.DeleteKeyRequest{
463+
Name: "tpmkms:name=key1",
464+
}}, assert.NoError, func(t *testing.T, k *TPMKMS) {
465+
_, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak1;ak=true"})
466+
assert.NoError(t, err)
467+
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak2;ak=true"})
468+
assert.NoError(t, err)
469+
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key1"})
470+
assert.ErrorIs(t, err, apiv1.NotFoundError{})
471+
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key2"})
472+
assert.NoError(t, err)
473+
}},
474+
{"ok ak", fields{okTPM}, args{&apiv1.DeleteKeyRequest{
475+
Name: "tpmkms:name=ak1;ak=true",
476+
}}, assert.NoError, func(t *testing.T, k *TPMKMS) {
477+
_, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak1;ak=true"})
478+
assert.ErrorIs(t, err, apiv1.NotFoundError{})
479+
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=ak2;ak=true"})
480+
assert.NoError(t, err)
481+
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key1"})
482+
assert.ErrorIs(t, err, apiv1.NotFoundError{})
483+
_, err = k.GetPublicKey(&apiv1.GetPublicKeyRequest{Name: "tpmkms:name=key2"})
484+
assert.NoError(t, err)
485+
}},
486+
{"fail name", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: ""}}, assert.Error, validatePending},
487+
{"fail not ak", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "tpmkms:name=ak2"}}, assert.Error, validatePending},
488+
{"fail not key", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "tpmkms:name=key2;ak=true"}}, assert.Error, validatePending},
489+
{"fail missing other", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "tpmkms:name=missing"}}, assert.Error, validatePending},
490+
{"fail uri", fields{okTPM}, args{&apiv1.DeleteKeyRequest{Name: "kms:name=key2"}}, assert.Error, validatePending},
491+
}
492+
for _, tt := range tests {
493+
t.Run(tt.name, func(t *testing.T) {
494+
k := &TPMKMS{
495+
tpm: tt.fields.tpm,
496+
}
497+
tt.assertion(t, k.DeleteKey(tt.args.req))
498+
tt.validate(t, k)
499+
})
500+
}
501+
}
502+
436503
func TestTPMKMS_CreateSigner(t *testing.T) {
437504
tpmWithKey := newSimulatedTPM(t, withKey("key1"))
438505

kms/tpmkms/tpmkms_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ package tpmkms
33
import (
44
"context"
55
"encoding/asn1"
6+
"errors"
7+
"fmt"
68
"os"
79
"testing"
810

911
"github.com/stretchr/testify/assert"
1012
"github.com/stretchr/testify/require"
1113
"go.step.sm/crypto/kms/apiv1"
14+
"go.step.sm/crypto/tpm"
1215
"go.step.sm/crypto/tpm/tss2"
1316
)
1417

@@ -108,3 +111,30 @@ func Test_parseTSS2(t *testing.T) {
108111
})
109112
}
110113
}
114+
115+
func Test_notFoundError(t *testing.T) {
116+
type args struct {
117+
err error
118+
}
119+
tests := []struct {
120+
name string
121+
args args
122+
assertion assert.ErrorAssertionFunc
123+
}{
124+
{"nil", args{nil}, assert.NoError},
125+
{"tpm not found", args{tpm.ErrNotFound}, func(tt assert.TestingT, err error, i ...interface{}) bool {
126+
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, i...)
127+
}},
128+
{"tpm not found wrapped", args{fmt.Errorf("some error: %w", tpm.ErrNotFound)}, func(tt assert.TestingT, err error, i ...interface{}) bool {
129+
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, i...)
130+
}},
131+
{"other", args{tpm.ErrExists}, func(tt assert.TestingT, err error, i ...interface{}) bool {
132+
return assert.False(t, errors.Is(err, apiv1.NotFoundError{}), i...)
133+
}},
134+
}
135+
for _, tt := range tests {
136+
t.Run(tt.name, func(t *testing.T) {
137+
tt.assertion(t, notFoundError(tt.args.err))
138+
})
139+
}
140+
}

0 commit comments

Comments
 (0)