From 1c8f93d5aa7b34fb4b4bfa1d96a489a437519864 Mon Sep 17 00:00:00 2001 From: Rafael Mendez Date: Thu, 20 Nov 2025 19:48:40 +0000 Subject: [PATCH] Adding SnapshotLock Capabilities --- docs/snapshot.md | 41 +- .../manifests/classes/snapshotclass.yaml | 2 +- hack/e2e/kops/patch-cluster.yaml | 3 +- pkg/cloud/cloud.go | 16 +- pkg/cloud/cloud_test.go | 53 +++ pkg/cloud/interface.go | 1 + pkg/cloud/mock_cloud.go | 15 + pkg/cloud/mock_ec2.go | 20 + pkg/driver/constants.go | 12 + pkg/driver/controller.go | 45 ++- pkg/driver/controller_test.go | 277 ++++++++++++++ pkg/util/ec2_interface.go | 1 + tests/e2e/requires_aws_api.go | 350 ++++++++++++++++++ tests/sanity/fake_sanity_cloud.go | 4 + 14 files changed, 832 insertions(+), 8 deletions(-) diff --git a/docs/snapshot.md b/docs/snapshot.md index 5643cdf3bf..e0616726e3 100644 --- a/docs/snapshot.md +++ b/docs/snapshot.md @@ -2,7 +2,11 @@ | Parameter | Description of value | |------------------------------------|-----------------------------------------------------------| | fastSnapshotRestoreAvailabilityZones | Comma separated list of availability zones | -| outpostArn | Arn of the outpost you wish to have the snapshot saved to | +| outpostArn | Arn of the outpost you wish to have the snapshot saved to | +| snapshotLockMode | Lock mode (governance/compliance) | +| snapshotLockDuration | Lock duration in days | +| snapshotLockExpirationDate | Lock expiration date (RFC3339 format) | +| snapshotLockCoolOffPeriod | Cool-off period in hours (compliance mode only) | The AWS EBS CSI Driver supports [tagging](tagging.md) through `VolumeSnapshotClass.parameters` (in v1.6.0 and later). ## Prerequisites @@ -44,6 +48,41 @@ parameters: The driver will attempt to check if the availability zones provided are supported for fast snapshot restore before attempting to create the snapshot. If the `EnableFastSnapshotRestores` API call fails, the driver will hard-fail the request and delete the snapshot. 
This is to ensure that the snapshot is not left in an inconsistent state. +# Snapshot Lock + +The EBS CSI Driver provides support for [EBS Snapshot Lock](https://docs.aws.amazon.com/ebs/latest/userguide/ebs-snapshot-lock.html) via `VolumeSnapshotClass.parameters`. Snapshot locking protects snapshots from accidental or malicious deletion. A locked snapshot can't be deleted. + +**Example - Lock in Governance Mode with Specified Duration** +```yaml +apiVersion: snapshot.storage.k8s.io/v1 +kind: VolumeSnapshotClass +metadata: + name: csi-aws-vsc-locked +driver: ebs.csi.aws.com +deletionPolicy: Delete +parameters: + snapshotLockMode: "governance" + snapshotLockDuration: "7" +``` + +**Example - Lock in Compliance Mode with Expiration Date and Cool Off Period** +```yaml +apiVersion: snapshot.storage.k8s.io/v1 +kind: VolumeSnapshotClass +metadata: + name: csi-aws-vsc-compliance +driver: ebs.csi.aws.com +deletionPolicy: Delete +parameters: + snapshotLockMode: "compliance" + snapshotLockExpirationDate: "2030-12-31T23:59:59Z" + snapshotLockCoolOffPeriod: "24" +``` + +## Failure Mode + +If the `LockSnapshot` API call fails, the driver will hard-fail the request and delete the snapshot. This ensures that the snapshot is not left in an unlocked state when locking was explicitly requested. 
+ # Amazon EBS Local Snapshots on Outposts diff --git a/examples/kubernetes/snapshot/manifests/classes/snapshotclass.yaml b/examples/kubernetes/snapshot/manifests/classes/snapshotclass.yaml index 7a9df00080..d6d2deecdc 100644 --- a/examples/kubernetes/snapshot/manifests/classes/snapshotclass.yaml +++ b/examples/kubernetes/snapshot/manifests/classes/snapshotclass.yaml @@ -17,4 +17,4 @@ kind: VolumeSnapshotClass metadata: name: csi-aws-vsc driver: ebs.csi.aws.com -deletionPolicy: Delete +deletionPolicy: Delete \ No newline at end of file diff --git a/hack/e2e/kops/patch-cluster.yaml b/hack/e2e/kops/patch-cluster.yaml index b734e49e0d..fed66ac033 100644 --- a/hack/e2e/kops/patch-cluster.yaml +++ b/hack/e2e/kops/patch-cluster.yaml @@ -86,7 +86,8 @@ spec: "Effect": "Allow", "Action": [ "ec2:CreateVolume", - "ec2:EnableFastSnapshotRestores" + "ec2:EnableFastSnapshotRestores", + "ec2:LockSnapshot" ], "Resource": "arn:aws:ec2:*:*:snapshot/*" }, diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index b2394fee4d..049c9c9fe0 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -309,12 +309,17 @@ type ListSnapshotsResponse struct { NextToken string } -// SnapshotOptions represents parameters to create an EBS volume. +// SnapshotOptions represents parameters to create an EBS snapshot. type SnapshotOptions struct { Tags map[string]string OutpostArn string } +// SnapshotLockOptions represents the snapshot lock specific parameters for locking an EBS snapshot. +type SnapshotLockOptions struct { + LockSnapshotInput ec2.LockSnapshotInput +} + // ec2ListSnapshotsResponse is a helper struct returned from the AWS API calling function to the main ListSnapshots function. 
type ec2ListSnapshotsResponse struct { Snapshots []types.Snapshot @@ -1872,6 +1877,15 @@ func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOpt }, nil } +func (c *cloud) LockSnapshot(ctx context.Context, lockSnapshotInput ec2.LockSnapshotInput) (*ec2.LockSnapshotOutput, error) { + klog.InfoS("Attempting to lock Snapshot", "request parameters: ", lockSnapshotInput) + response, err := c.ec2.LockSnapshot(ctx, &lockSnapshotInput) + if err != nil { + return nil, err + } + return response, nil +} + func (c *cloud) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) { request := &ec2.DeleteSnapshotInput{} request.SnapshotId = aws.String(snapshotID) diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index b23de66992..f946eb7cb9 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -5320,3 +5320,56 @@ func TestCheckIfIopsIncreaseOnExpansion(t *testing.T) { }) } } + +func TestLockSnapshot(t *testing.T) { + testCases := []struct { + name string + input ec2.LockSnapshotInput + mockError error + expectErr bool + }{ + { + name: "success: API call succeeds", + input: ec2.LockSnapshotInput{ + SnapshotId: aws.String("snap-test-id"), + LockMode: types.LockModeGovernance, + LockDuration: aws.Int32(1), + }, + mockError: nil, + expectErr: false, + }, + { + name: "fail: AWS API error is propagated", + input: ec2.LockSnapshotInput{ + SnapshotId: aws.String("snap-test-id"), + }, + mockError: errors.New("InvalidSnapshot.NotFound"), + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockEC2 := NewMockEC2API(mockCtrl) + c := newCloud(mockEC2) + + ctx := context.Background() + + if tc.mockError != nil { + mockEC2.EXPECT().LockSnapshot(ctx, &tc.input).Return(nil, tc.mockError) + } else { + mockEC2.EXPECT().LockSnapshot(ctx, &tc.input).Return(&ec2.LockSnapshotOutput{}, nil) + } + + _, err := c.LockSnapshot(ctx, tc.input) + + 
if tc.expectErr { + require.Error(t, err) + require.Equal(t, tc.mockError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/cloud/interface.go b/pkg/cloud/interface.go index 4297b26fbf..29232450e7 100644 --- a/pkg/cloud/interface.go +++ b/pkg/cloud/interface.go @@ -42,4 +42,5 @@ type Cloud interface { AvailabilityZones(ctx context.Context) (map[string]struct{}, error) DryRun(ctx context.Context) error GetInstancesPatching(ctx context.Context, nodeIDs []string) ([]*types.Instance, error) + LockSnapshot(ctx context.Context, lockOptions ec2.LockSnapshotInput) (*ec2.LockSnapshotOutput, error) } diff --git a/pkg/cloud/mock_cloud.go b/pkg/cloud/mock_cloud.go index e63591e96e..b77b0bb6b3 100644 --- a/pkg/cloud/mock_cloud.go +++ b/pkg/cloud/mock_cloud.go @@ -289,6 +289,21 @@ func (mr *MockCloudMockRecorder) ListSnapshots(ctx, volumeID, maxResults, nextTo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSnapshots", reflect.TypeOf((*MockCloud)(nil).ListSnapshots), ctx, volumeID, maxResults, nextToken) } +// LockSnapshot mocks base method. +func (m *MockCloud) LockSnapshot(ctx context.Context, lockOptions ec2.LockSnapshotInput) (*ec2.LockSnapshotOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LockSnapshot", ctx, lockOptions) + ret0, _ := ret[0].(*ec2.LockSnapshotOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LockSnapshot indicates an expected call of LockSnapshot. +func (mr *MockCloudMockRecorder) LockSnapshot(ctx, lockOptions interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockSnapshot", reflect.TypeOf((*MockCloud)(nil).LockSnapshot), ctx, lockOptions) +} + // ModifyTags mocks base method. 
func (m *MockCloud) ModifyTags(ctx context.Context, volumeID string, tagOptions ModifyTagsOptions) error { m.ctrl.T.Helper() diff --git a/pkg/cloud/mock_ec2.go b/pkg/cloud/mock_ec2.go index dc293e42b3..cac1f8be4b 100644 --- a/pkg/cloud/mock_ec2.go +++ b/pkg/cloud/mock_ec2.go @@ -375,6 +375,26 @@ func (mr *MockEC2APIMockRecorder) EnableFastSnapshotRestores(ctx, params interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableFastSnapshotRestores", reflect.TypeOf((*MockEC2API)(nil).EnableFastSnapshotRestores), varargs...) } +// LockSnapshot mocks base method. +func (m *MockEC2API) LockSnapshot(ctx context.Context, params *ec2.LockSnapshotInput, optFns ...func(*ec2.Options)) (*ec2.LockSnapshotOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "LockSnapshot", varargs...) + ret0, _ := ret[0].(*ec2.LockSnapshotOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LockSnapshot indicates an expected call of LockSnapshot. +func (mr *MockEC2APIMockRecorder) LockSnapshot(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockSnapshot", reflect.TypeOf((*MockEC2API)(nil).LockSnapshot), varargs...) +} + // ModifyVolume mocks base method. func (m *MockEC2API) ModifyVolume(ctx context.Context, params *ec2.ModifyVolumeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyVolumeOutput, error) { m.ctrl.T.Helper() diff --git a/pkg/driver/constants.go b/pkg/driver/constants.go index 38d6a2ea69..ce9cc56fbc 100644 --- a/pkg/driver/constants.go +++ b/pkg/driver/constants.go @@ -126,6 +126,18 @@ const ( const ( // FastSnapshotRestoreAvailabilityZones represents key for fast snapshot restore availability zones. 
FastSnapshotRestoreAvailabilityZones = "fastsnapshotrestoreavailabilityzones" + + // SnapshotLockMode represents a key for indicating whether snapshots are locked in Governance or Compliance mode. + SnapshotLockMode = "snapshotlockmode" + + // SnapshotLockDuration is a key for the duration for which to lock the snapshots, specified in days. + SnapshotLockDuration = "snapshotlockduration" + + // SnapshotLockExpirationDate is a key for specifying the expiration date for the snapshot lock, specified in RFC 3339 format (for example "2030-12-31T23:59:59Z"). + SnapshotLockExpirationDate = "snapshotlockexpirationdate" + + // SnapshotLockCoolOffPeriod is a key specifying the cooling-off period for compliance mode, specified in hours. + SnapshotLockCoolOffPeriod = "snapshotlockcooloffperiod" ) // constants for volume tags and their values. diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 9fcf090d23..b6de01efb9 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -23,8 +23,11 @@ import ( "maps" "strconv" "strings" + "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/awslabs/volume-modifier-for-k8s/pkg/rpc" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" @@ -857,6 +860,7 @@ func (d *ControllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS var vscTags []string var fsrAvailabilityZones []string vsProps := new(template.VolumeSnapshotProps) + vsLock := new(cloud.SnapshotLockOptions) for key, value := range req.GetParameters() { switch strings.ToLower(key) { case VolumeSnapshotNameKey: @@ -874,6 +878,26 @@ func (d *ControllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS } else { return nil, status.Errorf(codes.InvalidArgument, "Invalid parameter value %s is not a valid arn", value) } + case SnapshotLockMode: + vsLock.LockSnapshotInput.LockMode = 
types.LockMode(value) + case SnapshotLockDuration: + lockDuration, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Could not parse SnapshotLockDuration: %q", value) + } + vsLock.LockSnapshotInput.LockDuration = aws.Int32(int32(lockDuration)) + case SnapshotLockExpirationDate: + expirationDate, err := time.Parse(time.RFC3339, value) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Could not parse SnapshotLockExpirationDate: %q", value) + } + vsLock.LockSnapshotInput.ExpirationDate = &expirationDate + case SnapshotLockCoolOffPeriod: + lockCoolOffPeriod, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Could not parse SnapshotLockCoolOffPeriod: %q", value) + } + vsLock.LockSnapshotInput.CoolOffPeriod = aws.Int32(int32(lockCoolOffPeriod)) default: if strings.HasPrefix(key, TagKeyPrefix) { vscTags = append(vscTags, value) @@ -934,12 +958,18 @@ func (d *ControllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS if len(fsrAvailabilityZones) > 0 { _, err := d.cloud.EnableFastSnapshotRestores(ctx, fsrAvailabilityZones, snapshot.SnapshotID) if err != nil { - if _, deleteErr := d.cloud.DeleteSnapshot(ctx, snapshot.SnapshotID); deleteErr != nil { - return nil, status.Errorf(codes.Internal, "Could not delete snapshot ID %q: %v", snapshotName, deleteErr) - } - return nil, status.Errorf(codes.Internal, "Failed to create Fast Snapshot Restores for snapshot ID %q: %v", snapshotName, err) + return nil, d.cleanupSnapshotOnError(ctx, snapshot.SnapshotID, snapshotName, err, "Failed to create Fast Snapshot Restores") } } + + if vsLock.LockSnapshotInput.LockMode != "" || vsLock.LockSnapshotInput.LockDuration != nil || vsLock.LockSnapshotInput.ExpirationDate != nil || vsLock.LockSnapshotInput.CoolOffPeriod != nil { + vsLock.LockSnapshotInput.SnapshotId = &snapshot.SnapshotID + _, err := d.cloud.LockSnapshot(ctx, 
vsLock.LockSnapshotInput) + if err != nil { + return nil, d.cleanupSnapshotOnError(ctx, snapshot.SnapshotID, snapshotName, err, "Failed to lock snapshot") + } + } + return newCreateSnapshotResponse(snapshot), nil } @@ -1297,3 +1327,10 @@ func validateFormattingOption(volumeCapabilities []*csi.VolumeCapability, paramN func isTrue(value string) bool { return value == trueStr } + +func (d *ControllerService) cleanupSnapshotOnError(ctx context.Context, snapshotID, snapshotName string, originalErr error, errorMsg string) error { + if _, deleteErr := d.cloud.DeleteSnapshot(ctx, snapshotID); deleteErr != nil { + return status.Errorf(codes.Internal, "Could not delete snapshot ID %q: %v", snapshotName, deleteErr) + } + return status.Errorf(codes.Internal, "%s for snapshot ID %q: %v", errorMsg, snapshotName, originalErr) +} diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index ba01b7dae8..e45d9a3241 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -3915,6 +3915,283 @@ func TestCreateSnapshot(t *testing.T) { } }, }, + { + name: "success with snapshot lock governance mode", + testFunc: func(t *testing.T) { + t.Helper() + const ( + snapshotName = "test-snapshot" + ) + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + SnapshotLockMode: "governance", + SnapshotLockDuration: "1", + }, + SourceVolumeId: "vol-test", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + + ctx := t.Context() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.GetSourceVolumeId(), + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + expLockSnapshotInput := ec2.LockSnapshotInput{ + SnapshotId: &mockSnapshot.SnapshotID, + LockMode: types.LockModeGovernance, + LockDuration: aws.Int32(1), + } + + expOutput := &ec2.LockSnapshotOutput{ + 
SnapshotId: &mockSnapshot.SnapshotID, + LockState: types.LockStateGovernance, + LockDuration: aws.Int32(1), + } + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.GetSourceVolumeId()), gomock.Any()).Return(mockSnapshot, nil) + mockCloud.EXPECT().LockSnapshot(gomock.Eq(ctx), gomock.Eq(expLockSnapshotInput)).Return(expOutput, nil) + + awsDriver := ControllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + options: &Options{}, + } + resp, err := awsDriver.CreateSnapshot(t.Context(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if snap := resp.GetSnapshot(); snap == nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } + }, + }, + { + name: "success with snapshot lock compliance mode", + testFunc: func(t *testing.T) { + t.Helper() + const ( + snapshotName = "test-snapshot" + ) + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + SnapshotLockMode: "compliance", + SnapshotLockDuration: "7", + SnapshotLockCoolOffPeriod: "24", + }, + SourceVolumeId: "vol-test", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + + ctx := t.Context() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.GetSourceVolumeId(), + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + expLockSnapshotInput := ec2.LockSnapshotInput{ + SnapshotId: &mockSnapshot.SnapshotID, + LockMode: types.LockModeCompliance, + LockDuration: aws.Int32(7), + CoolOffPeriod: aws.Int32(24), + } + + expOutput := &ec2.LockSnapshotOutput{ + SnapshotId: &mockSnapshot.SnapshotID, + LockState: types.LockStateCompliance, + LockDuration: aws.Int32(7), + CoolOffPeriod: aws.Int32(24), + } + + 
mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.GetSourceVolumeId()), gomock.Any()).Return(mockSnapshot, nil) + mockCloud.EXPECT().LockSnapshot(gomock.Eq(ctx), gomock.Eq(expLockSnapshotInput)).Return(expOutput, nil) + + awsDriver := ControllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + options: &Options{}, + } + resp, err := awsDriver.CreateSnapshot(t.Context(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if snap := resp.GetSnapshot(); snap == nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } + }, + }, + { + name: "success with snapshot lock governance mode with expiration date", + testFunc: func(t *testing.T) { + t.Helper() + const ( + snapshotName = "test-snapshot" + ) + expirationDate := time.Now().Add(24 * time.Hour).Format(time.RFC3339) + expectedTime, _ := time.Parse(time.RFC3339, expirationDate) + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + SnapshotLockMode: "governance", + SnapshotLockExpirationDate: expirationDate, + }, + SourceVolumeId: "vol-test", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + + ctx := t.Context() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.GetSourceVolumeId(), + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + expLockSnapshotInput := ec2.LockSnapshotInput{ + SnapshotId: &mockSnapshot.SnapshotID, + LockMode: types.LockModeGovernance, + ExpirationDate: &expectedTime, + } + + expOutput := &ec2.LockSnapshotOutput{ + SnapshotId: &mockSnapshot.SnapshotID, + LockState: types.LockStateGovernance, + } + + mockCloud := cloud.NewMockCloud(mockCtl) + 
mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.GetSourceVolumeId()), gomock.Any()).Return(mockSnapshot, nil) + mockCloud.EXPECT().LockSnapshot(gomock.Eq(ctx), gomock.Eq(expLockSnapshotInput)).Return(expOutput, nil) + + awsDriver := ControllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + options: &Options{}, + } + resp, err := awsDriver.CreateSnapshot(t.Context(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if snap := resp.GetSnapshot(); snap == nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } + }, + }, + { + name: "fail with snapshot lock and cleanup snapshot", + testFunc: func(t *testing.T) { + t.Helper() + const ( + snapshotName = "test-snapshot" + ) + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + SnapshotLockMode: "governance", + SnapshotLockDuration: "1", + }, + SourceVolumeId: "vol-test", + } + + ctx := t.Context() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.GetSourceVolumeId(), + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.GetSourceVolumeId()), gomock.Any()).Return(mockSnapshot, nil) + mockCloud.EXPECT().LockSnapshot(gomock.Eq(ctx), gomock.Any()).Return(nil, errors.New("Failed to lock snapshot")) + mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq(mockSnapshot.SnapshotID)).Return(true, nil) + + awsDriver := ControllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + options: &Options{}, + } + _, err := 
awsDriver.CreateSnapshot(t.Context(), req) + if err == nil { + t.Fatalf("Expected error, got nil") + } + }, + }, + { + name: "should still call LockSnapshot without all required parameters", + testFunc: func(t *testing.T) { + t.Helper() + const ( + snapshotName = "test-snapshot" + ) + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + SnapshotLockCoolOffPeriod: "2", + }, + SourceVolumeId: "vol-test", + } + + ctx := t.Context() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.GetSourceVolumeId(), + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + expLockSnapshotInput := ec2.LockSnapshotInput{ + SnapshotId: &mockSnapshot.SnapshotID, + CoolOffPeriod: aws.Int32(2), + } + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.GetSourceVolumeId()), gomock.Any()).Return(mockSnapshot, nil) + mockCloud.EXPECT().LockSnapshot(gomock.Eq(ctx), gomock.Eq(expLockSnapshotInput)).Return(nil, errors.New("Failed to lock snapshot due to missing parameters")) + mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq(mockSnapshot.SnapshotID)).Return(true, nil) + + awsDriver := ControllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + options: &Options{}, + } + _, err := awsDriver.CreateSnapshot(t.Context(), req) + if err == nil { + t.Fatalf("Expected error, got nil") + } + }, + }, { name: "success with EnableFastSnapshotRestore - failed to get availability zones", testFunc: func(t *testing.T) { diff --git a/pkg/util/ec2_interface.go b/pkg/util/ec2_interface.go index 7f7fac8a7c..999e5a5727 100644 --- a/pkg/util/ec2_interface.go +++ b/pkg/util/ec2_interface.go @@ -42,4 +42,5 @@ type EC2API interface 
{ CreateTags(ctx context.Context, params *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) DeleteTags(ctx context.Context, params *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) EnableFastSnapshotRestores(ctx context.Context, params *ec2.EnableFastSnapshotRestoresInput, optFns ...func(*ec2.Options)) (*ec2.EnableFastSnapshotRestoresOutput, error) + LockSnapshot(ctx context.Context, params *ec2.LockSnapshotInput, optFns ...func(*ec2.Options)) (*ec2.LockSnapshotOutput, error) } diff --git a/tests/e2e/requires_aws_api.go b/tests/e2e/requires_aws_api.go index 12a7154331..18251c1e45 100644 --- a/tests/e2e/requires_aws_api.go +++ b/tests/e2e/requires_aws_api.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "strconv" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -351,6 +352,355 @@ var _ = Describe("[ebs-csi-e2e] [single-az] [requires-aws-api] Dynamic Provision } test.Run(cs, snapshotrcs, ns) }) + + It("should create a snapshot with governance mode lock for 1 day", func() { + pod := testsuites.PodDetails{ + Cmd: testsuites.PodCmdWriteToVolume("/mnt/test-1"), + Volumes: []testsuites.VolumeDetails{ + { + CreateVolumeParameters: map[string]string{ + ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3, + ebscsidriver.FSTypeKey: ebscsidriver.FSTypeExt4, + }, + ClaimSize: driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3), + VolumeMount: testsuites.DefaultGeneratedVolumeMount, + }, + }, + } + restoredPod := testsuites.PodDetails{ + Cmd: testsuites.PodCmdGrepVolumeData("/mnt/test-1"), + Volumes: []testsuites.VolumeDetails{ + { + CreateVolumeParameters: map[string]string{ + ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3, + ebscsidriver.FSTypeKey: ebscsidriver.FSTypeExt4, + }, + ClaimSize: driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3), + VolumeMount: testsuites.DefaultGeneratedVolumeMount, + }, + }, + } + test := 
testsuites.DynamicallyProvisionedVolumeSnapshotTest{ + CSIDriver: ebsDriver, + Pod: pod, + RestoredPod: restoredPod, + Parameters: map[string]string{ + ebscsidriver.SnapshotLockMode: "governance", + ebscsidriver.SnapshotLockDuration: "1", + }, + ValidateFunc: func(snapshot *volumesnapshotv1.VolumeSnapshot) { + describeResult := validateEc2Snapshot(context.Background(), ec2Client, &ec2.DescribeSnapshotsInput{ + Filters: []types.Filter{ + { + Name: aws.String("tag:" + awscloud.SnapshotNameTagKey), + Values: []string{"snapshot-" + string(snapshot.UID)}, + }, + }, + }) + + snapshotId := *describeResult.Snapshots[0].SnapshotId + + result, err := ec2Client.DescribeLockedSnapshots(context.Background(), &ec2.DescribeLockedSnapshotsInput{ + SnapshotIds: []string{snapshotId}, + }) + if err != nil { + Fail(fmt.Sprintf("failed to describe locked snapshots: %v", err)) + } + + if len(result.Snapshots) != 1 { + Fail(fmt.Sprintf("expected 1 locked snapshot, got %d", len(result.Snapshots))) + } + + lockedSnapshot := result.Snapshots[0] + if types.LockMode(lockedSnapshot.LockState) != types.LockModeGovernance { + Fail(fmt.Sprintf("expected lock mode governance, got %s", lockedSnapshot.LockState)) + } + + if lockedSnapshot.LockDuration == nil || *lockedSnapshot.LockDuration != 1 { + Fail(fmt.Sprintf("expected lock duration 1 day, got %v", lockedSnapshot.LockDuration)) + } + + _, err = ec2Client.UnlockSnapshot(context.Background(), &ec2.UnlockSnapshotInput{ + SnapshotId: aws.String(snapshotId), + }) + if err != nil { + Fail(fmt.Sprintf("failed to unlock snapshot: %v", err)) + } + }, + } + test.Run(cs, snapshotrcs, ns) + }) + + It("should create a snapshot with governance mode lock using expiration date", func() { + // Set expiration date to 1 day from now + expirationDate := time.Now().Add(24 * time.Hour).Format(time.RFC3339) + + pod := testsuites.PodDetails{ + Cmd: testsuites.PodCmdWriteToVolume("/mnt/test-1"), + Volumes: []testsuites.VolumeDetails{ + { + CreateVolumeParameters: 
map[string]string{ + ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3, + ebscsidriver.FSTypeKey: ebscsidriver.FSTypeExt4, + }, + ClaimSize: driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3), + VolumeMount: testsuites.DefaultGeneratedVolumeMount, + }, + }, + } + restoredPod := testsuites.PodDetails{ + Cmd: testsuites.PodCmdGrepVolumeData("/mnt/test-1"), + Volumes: []testsuites.VolumeDetails{ + { + CreateVolumeParameters: map[string]string{ + ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3, + ebscsidriver.FSTypeKey: ebscsidriver.FSTypeExt4, + }, + ClaimSize: driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3), + VolumeMount: testsuites.DefaultGeneratedVolumeMount, + }, + }, + } + test := testsuites.DynamicallyProvisionedVolumeSnapshotTest{ + CSIDriver: ebsDriver, + Pod: pod, + RestoredPod: restoredPod, + Parameters: map[string]string{ + ebscsidriver.SnapshotLockMode: "governance", + ebscsidriver.SnapshotLockExpirationDate: expirationDate, + }, + ValidateFunc: func(snapshot *volumesnapshotv1.VolumeSnapshot) { + describeResult := validateEc2Snapshot(context.Background(), ec2Client, &ec2.DescribeSnapshotsInput{ + Filters: []types.Filter{ + { + Name: aws.String("tag:" + awscloud.SnapshotNameTagKey), + Values: []string{"snapshot-" + string(snapshot.UID)}, + }, + }, + }) + + snapshotId := *describeResult.Snapshots[0].SnapshotId + + result, err := ec2Client.DescribeLockedSnapshots(context.Background(), &ec2.DescribeLockedSnapshotsInput{ + SnapshotIds: []string{snapshotId}, + }) + if err != nil { + Fail(fmt.Sprintf("failed to describe locked snapshots: %v", err)) + } + + if len(result.Snapshots) != 1 { + Fail(fmt.Sprintf("expected 1 locked snapshot, got %d", len(result.Snapshots))) + } + + lockedSnapshot := result.Snapshots[0] + if types.LockMode(lockedSnapshot.LockState) != types.LockModeGovernance { + Fail(fmt.Sprintf("expected lock mode governance, got %s", lockedSnapshot.LockState)) + } + + if lockedSnapshot.LockCreatedOn == nil { + Fail("expected lock 
creation date to be set")
+				}
+
+				_, err = ec2Client.UnlockSnapshot(context.Background(), &ec2.UnlockSnapshotInput{
+					SnapshotId: aws.String(snapshotId),
+				})
+				if err != nil {
+					Fail(fmt.Sprintf("failed to unlock snapshot: %v", err))
+				}
+			},
+		}
+		test.Run(cs, snapshotrcs, ns)
+	})
+
+	It("should create a snapshot with compliance mode lock for 2 days with 12 hour cooloff", func() {
+		pod := testsuites.PodDetails{
+			Cmd: testsuites.PodCmdWriteToVolume("/mnt/test-1"),
+			Volumes: []testsuites.VolumeDetails{
+				{
+					CreateVolumeParameters: map[string]string{
+						ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3,
+						ebscsidriver.FSTypeKey:     ebscsidriver.FSTypeExt4,
+					},
+					ClaimSize:   driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3),
+					VolumeMount: testsuites.DefaultGeneratedVolumeMount,
+				},
+			},
+		}
+		restoredPod := testsuites.PodDetails{
+			Cmd: testsuites.PodCmdGrepVolumeData("/mnt/test-1"),
+			Volumes: []testsuites.VolumeDetails{
+				{
+					CreateVolumeParameters: map[string]string{
+						ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3,
+						ebscsidriver.FSTypeKey:     ebscsidriver.FSTypeExt4,
+					},
+					ClaimSize:   driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3),
+					VolumeMount: testsuites.DefaultGeneratedVolumeMount,
+				},
+			},
+		}
+		test := testsuites.DynamicallyProvisionedVolumeSnapshotTest{
+			CSIDriver:   ebsDriver,
+			Pod:         pod,
+			RestoredPod: restoredPod,
+			Parameters: map[string]string{
+				ebscsidriver.SnapshotLockMode:     "compliance",
+				ebscsidriver.SnapshotLockDuration: "2",
+				// A cool-off period is required: without one the compliance lock cannot be removed, so the test could not unlock and clean up the snapshot.
+				ebscsidriver.SnapshotLockCoolOffPeriod: "12",
+			},
+			ValidateFunc: func(snapshot *volumesnapshotv1.VolumeSnapshot) {
+				describeResult := validateEc2Snapshot(context.Background(), ec2Client, &ec2.DescribeSnapshotsInput{
+					Filters: []types.Filter{
+						{
+							Name:   aws.String("tag:" + awscloud.SnapshotNameTagKey),
+							Values: []string{"snapshot-" + string(snapshot.UID)},
+						},
+					},
+				})
+
+				snapshotId := *describeResult.Snapshots[0].SnapshotId
+
+				result, err := ec2Client.DescribeLockedSnapshots(context.Background(), &ec2.DescribeLockedSnapshotsInput{
+					SnapshotIds: []string{snapshotId},
+				})
+				if err != nil {
+					Fail(fmt.Sprintf("failed to describe locked snapshots: %v", err))
+				}
+
+				if len(result.Snapshots) != 1 {
+					Fail(fmt.Sprintf("expected 1 locked snapshot, got %d", len(result.Snapshots)))
+				}
+
+				lockedSnapshot := result.Snapshots[0]
+				if lockedSnapshot.LockState != types.LockStateComplianceCooloff { // compare LockState against its own constant, no LockMode casts
+					Fail(fmt.Sprintf("expected lock state %s, got %s", types.LockStateComplianceCooloff, lockedSnapshot.LockState))
+				}
+
+				if got := aws.ToInt32(lockedSnapshot.LockDuration); got != 2 { // a nil duration reads as 0 and fails the check
+					Fail(fmt.Sprintf("expected lock duration 2 days, got %d", got))
+				}
+
+				if got := aws.ToInt32(lockedSnapshot.CoolOffPeriod); got != 12 { // a nil cool-off reads as 0 and fails the check
+					Fail(fmt.Sprintf("expected cooloff period 12 hours, got %d", got))
+				}
+
+				_, err = ec2Client.UnlockSnapshot(context.Background(), &ec2.UnlockSnapshotInput{ // unlock so the snapshot can be deleted afterwards
+					SnapshotId: aws.String(snapshotId),
+				})
+				if err != nil {
+					Fail(fmt.Sprintf("failed to unlock snapshot: %v", err))
+				}
+			},
+		}
+		test.Run(cs, snapshotrcs, ns)
+	})
+
+	It("should create a snapshot with FSR enabled and governance mode lock for 1 day", func() {
+		azList, err := ec2Client.DescribeAvailabilityZones(context.Background(), &ec2.DescribeAvailabilityZonesInput{})
+		if err != nil {
+			Fail(fmt.Sprintf("failed to list AZs: %v", err))
+		}
+		fsrAvailabilityZone := *azList.AvailabilityZones[0].ZoneName
+
+		pod := testsuites.PodDetails{
+			Cmd: testsuites.PodCmdWriteToVolume("/mnt/test-1"),
+			Volumes: []testsuites.VolumeDetails{
+				{
+					CreateVolumeParameters: map[string]string{
+						ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3,
+						ebscsidriver.FSTypeKey:     ebscsidriver.FSTypeExt4,
+					},
+					ClaimSize:   driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3),
+					VolumeMount: testsuites.DefaultGeneratedVolumeMount,
+				},
+			},
+		}
+		restoredPod := testsuites.PodDetails{
+			Cmd: testsuites.PodCmdGrepVolumeData("/mnt/test-1"),
+			Volumes: []testsuites.VolumeDetails{
+				{
+					CreateVolumeParameters: map[string]string{
+						ebscsidriver.VolumeTypeKey: awscloud.VolumeTypeGP3,
+						ebscsidriver.FSTypeKey:     ebscsidriver.FSTypeExt4,
+					},
+					ClaimSize:   driver.MinimumSizeForVolumeType(awscloud.VolumeTypeGP3),
+					VolumeMount: testsuites.DefaultGeneratedVolumeMount,
+				},
+			},
+		}
+		test := testsuites.DynamicallyProvisionedVolumeSnapshotTest{
+			CSIDriver:   ebsDriver,
+			Pod:         pod,
+			RestoredPod: restoredPod,
+			Parameters: map[string]string{
+				ebscsidriver.FastSnapshotRestoreAvailabilityZones: fsrAvailabilityZone,
+				ebscsidriver.SnapshotLockMode:                     "governance",
+				ebscsidriver.SnapshotLockDuration:                 "1",
+			},
+			ValidateFunc: func(snapshot *volumesnapshotv1.VolumeSnapshot) {
+				describeResult := validateEc2Snapshot(context.Background(), ec2Client, &ec2.DescribeSnapshotsInput{
+					Filters: []types.Filter{
+						{
+							Name:   aws.String("tag:" + awscloud.SnapshotNameTagKey),
+							Values: []string{"snapshot-" + string(snapshot.UID)},
+						},
+					},
+				})
+
+				snapshotId := *describeResult.Snapshots[0].SnapshotId
+
+				fsrResult, err := ec2Client.DescribeFastSnapshotRestores(context.Background(), &ec2.DescribeFastSnapshotRestoresInput{
+					Filters: []types.Filter{
+						{
+							Name:   aws.String("snapshot-id"),
+							Values: []string{snapshotId},
+						},
+					},
+				})
+				if err != nil {
+					Fail(fmt.Sprintf("failed to describe FSR: %v", err))
+				}
+
+				if len(fsrResult.FastSnapshotRestores) != 1 {
+					Fail(fmt.Sprintf("expected 1 FSR, got %d", len(fsrResult.FastSnapshotRestores)))
+				}
+
+				if got := aws.ToString(fsrResult.FastSnapshotRestores[0].AvailabilityZone); got != fsrAvailabilityZone { // nil-safe: a missing AZ fails the check instead of panicking
+					Fail(fmt.Sprintf("expected FSR for %s, got %s", fsrAvailabilityZone, got))
+				}
+
+				lockResult, err := ec2Client.DescribeLockedSnapshots(context.Background(), &ec2.DescribeLockedSnapshotsInput{
+					SnapshotIds: []string{snapshotId},
+				})
+				if err != nil {
+					Fail(fmt.Sprintf("failed to describe locked snapshots: %v", err))
+				}
+
+				if len(lockResult.Snapshots) != 1 {
+					Fail(fmt.Sprintf("expected 1 locked snapshot, got %d", len(lockResult.Snapshots)))
+				}
+
+				lockedSnapshot := lockResult.Snapshots[0]
+				if lockedSnapshot.LockState != types.LockStateGovernance { // compare LockState against its own constant, no LockMode cast
+					Fail(fmt.Sprintf("expected lock state %s, got %s", types.LockStateGovernance, lockedSnapshot.LockState))
+				}
+
+				if got := aws.ToInt32(lockedSnapshot.LockDuration); got != 1 { // a nil duration reads as 0 and fails the check
+					Fail(fmt.Sprintf("expected lock duration 1 day, got %d", got))
+				}
+
+				_, err = ec2Client.UnlockSnapshot(context.Background(), &ec2.UnlockSnapshotInput{ // unlock so the snapshot can be deleted afterwards
+					SnapshotId: aws.String(snapshotId),
+				})
+				if err != nil {
+					Fail(fmt.Sprintf("failed to unlock snapshot: %v", err))
+				}
+			},
+		}
+		test.Run(cs, snapshotrcs, ns)
+	})
+
 	It("should copy a volume with different volume parameters", func() {
 		pod := testsuites.PodDetails{
 			Cmd: testsuites.PodCmdWriteToVolume("/mnt/test-1"),
diff --git a/tests/sanity/fake_sanity_cloud.go b/tests/sanity/fake_sanity_cloud.go
index 9c2350887b..0f2519a693 100644
--- a/tests/sanity/fake_sanity_cloud.go
+++ b/tests/sanity/fake_sanity_cloud.go
@@ -210,6 +210,10 @@ func (d *fakeCloud) EnableFastSnapshotRestores(ctx context.Context, availability
 	return &ec2.EnableFastSnapshotRestoresOutput{}, nil
 }
 
+func (d *fakeCloud) LockSnapshot(ctx context.Context, lockOptions ec2.LockSnapshotInput) (*ec2.LockSnapshotOutput, error) {
+	return &ec2.LockSnapshotOutput{}, nil
+}
+
 func (d *fakeCloud) GetDiskByName(ctx context.Context, name string, capacityBytes int64) (*cloud.Disk, error) {
 	return &cloud.Disk{}, nil
 }