@@ -23,8 +23,11 @@ import (
2323 "maps"
2424 "strconv"
2525 "strings"
26+ "time"
2627
28+ "github.com/aws/aws-sdk-go-v2/aws"
2729 "github.com/aws/aws-sdk-go-v2/aws/arn"
30+ "github.com/aws/aws-sdk-go-v2/service/ec2/types"
2831 "github.com/awslabs/volume-modifier-for-k8s/pkg/rpc"
2932 csi "github.com/container-storage-interface/spec/lib/go/csi"
3033 "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
@@ -856,9 +859,11 @@ func (d *ControllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS
856859 cloud .AwsEbsDriverTagKey : isManagedByDriver ,
857860 }
858861
862+ // Get Parameters Here
859863 var vscTags []string
860864 var fsrAvailabilityZones []string
861865 vsProps := new (template.VolumeSnapshotProps )
866+ vsLock := new (cloud.SnapshotLockOptions )
862867 for key , value := range req .GetParameters () {
863868 switch strings .ToLower (key ) {
864869 case VolumeSnapshotNameKey :
@@ -876,6 +881,28 @@ func (d *ControllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS
876881 } else {
877882 return nil , status .Errorf (codes .InvalidArgument , "Invalid parameter value %s is not a valid arn" , value )
878883 }
884+ case SnapshotLockEnabled :
885+ vsLock .SnapshotLockEnabled = isTrue (value )
886+ case SnapshotLockMode :
887+ vsLock .LockSnapshotInput .LockMode = types .LockMode (value )
888+ case SnapshotLockDuration :
889+ lockDuration , err := strconv .ParseInt (value , 10 , 32 )
890+ if err != nil {
891+ return nil , status .Errorf (codes .InvalidArgument , "Could not parse SnapshotLockDuration: %q" , value )
892+ }
893+ vsLock .LockSnapshotInput .LockDuration = aws .Int32 (int32 (lockDuration ))
894+ case SnapshotLockExpirationDate :
895+ expirationDate , err := time .Parse (time .RFC3339 , value )
896+ if err != nil {
897+ return nil , status .Errorf (codes .InvalidArgument , "Could not parse SnapshotLockExpirationDate: %q" , value )
898+ }
899+ vsLock .LockSnapshotInput .ExpirationDate = & expirationDate
900+ case SnapshotLockCoolOffPeriod :
901+ lockCoolOffPeriod , err := strconv .ParseInt (value , 10 , 32 )
902+ if err != nil {
903+ return nil , status .Errorf (codes .InvalidArgument , "Could not parse SnapshotLockCoolOffPeriod: %q" , value )
904+ }
905+ vsLock .LockSnapshotInput .CoolOffPeriod = aws .Int32 (int32 (lockCoolOffPeriod ))
879906 default :
880907 if strings .HasPrefix (key , TagKeyPrefix ) {
881908 vscTags = append (vscTags , value )
@@ -936,12 +963,18 @@ func (d *ControllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS
936963 if len (fsrAvailabilityZones ) > 0 {
937964 _ , err := d .cloud .EnableFastSnapshotRestores (ctx , fsrAvailabilityZones , snapshot .SnapshotID )
938965 if err != nil {
939- if _ , deleteErr := d .cloud .DeleteSnapshot (ctx , snapshot .SnapshotID ); deleteErr != nil {
940- return nil , status .Errorf (codes .Internal , "Could not delete snapshot ID %q: %v" , snapshotName , deleteErr )
941- }
942- return nil , status .Errorf (codes .Internal , "Failed to create Fast Snapshot Restores for snapshot ID %q: %v" , snapshotName , err )
966+ return nil , d .cleanupSnapshotOnError (ctx , snapshot .SnapshotID , snapshotName , err , "Failed to create Fast Snapshot Restores" )
943967 }
944968 }
969+
970+ if vsLock .SnapshotLockEnabled {
971+ vsLock .LockSnapshotInput .SnapshotId = & snapshot .SnapshotID
972+ _ , err := d .cloud .LockSnapshot (ctx , vsLock .LockSnapshotInput )
973+ if err != nil {
974+ return nil , d .cleanupSnapshotOnError (ctx , snapshot .SnapshotID , snapshotName , err , "Failed to lock snapshot" )
975+ }
976+ }
977+
945978 return newCreateSnapshotResponse (snapshot ), nil
946979}
947980
@@ -1299,3 +1332,10 @@ func validateFormattingOption(volumeCapabilities []*csi.VolumeCapability, paramN
12991332func isTrue (value string ) bool {
13001333 return value == trueStr
13011334}
1335+
1336+ func (d * ControllerService ) cleanupSnapshotOnError (ctx context.Context , snapshotID , snapshotName string , originalErr error , errorMsg string ) error {
1337+ if _ , deleteErr := d .cloud .DeleteSnapshot (ctx , snapshotID ); deleteErr != nil {
1338+ return status .Errorf (codes .Internal , "Could not delete snapshot ID %q: %v" , snapshotName , deleteErr )
1339+ }
1340+ return status .Errorf (codes .Internal , "%s for snapshot ID %q: %v" , errorMsg , snapshotName , originalErr )
1341+ }
0 commit comments