Skip to content

Commit 389d8fa

Browse files
committed
Add snapshot limit enforcement for WCP with per-volume serialization
1 parent 7a6acd9 commit 389d8fa

File tree

5 files changed

+623
-5
lines changed

5 files changed

+623
-5
lines changed

pkg/csi/service/common/constants.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,16 @@ const (
353353
// Guest cluster.
354354
SupervisorVolumeSnapshotAnnotationKey = "csi.vsphere.guest-initiated-csi-snapshot"
355355

356+
// MaxSnapshotsPerVolumeAnnotationKey represents the annotation key on Namespace CR
357+
// in Supervisor cluster to specify the maximum number of snapshots per volume
358+
MaxSnapshotsPerVolumeAnnotationKey = "csi.vsphere.max-snapshots-per-volume"
359+
360+
// DefaultMaxSnapshotsPerBlockVolumeInWCP is the default maximum number of snapshots per block volume in WCP
361+
DefaultMaxSnapshotsPerBlockVolumeInWCP = 4
362+
363+
// MaxAllowedSnapshotsPerBlockVolume is the hard cap for maximum snapshots per block volume
364+
MaxAllowedSnapshotsPerBlockVolume = 32
365+
356366
// AttributeSupervisorVolumeSnapshotClass represents name of VolumeSnapshotClass
357367
AttributeSupervisorVolumeSnapshotClass = "svvolumesnapshotclass"
358368

@@ -467,6 +477,8 @@ const (
467477
// WCPVMServiceVMSnapshots is a supervisor capability indicating
468478
// if supports_VM_service_VM_snapshots FSS is enabled
469479
WCPVMServiceVMSnapshots = "supports_VM_service_VM_snapshots"
480+
// SnapshotLimitWCP is an internal FSS that enables snapshot limit enforcement in WCP
481+
SnapshotLimitWCP = "snapshot-limit-wcp"
470482
)
471483

472484
var WCPFeatureStates = map[string]struct{}{

pkg/csi/service/wcp/controller.go

Lines changed: 110 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,23 @@ var (
9898
vmMoidToHostMoid, volumeIDToVMMap map[string]string
9999
)
100100

101+
// Per-volume locking primitives used to serialize snapshot operations.
type (
	// volumeLock is the lock entry kept for a single volume.
	volumeLock struct {
		// mutex is the lock callers hold while operating on the volume.
		mutex sync.Mutex
		// refCount is a reference count for this entry.
		refCount int
	}

	// snapshotLockManager owns the set of per-volume lock entries.
	snapshotLockManager struct {
		// locks maps a volume ID to its lock entry.
		locks map[string]*volumeLock
		// mapMutex guards the locks map.
		mapMutex sync.RWMutex
	}
)
112+
101113
type controller struct {
102-
manager *common.Manager
103-
authMgr common.AuthorizationService
104-
topologyMgr commoncotypes.ControllerTopologyService
114+
manager *common.Manager
115+
authMgr common.AuthorizationService
116+
topologyMgr commoncotypes.ControllerTopologyService
117+
snapshotLockMgr *snapshotLockManager
105118
csi.UnimplementedControllerServer
106119
}
107120

@@ -211,6 +224,12 @@ func (c *controller) Init(config *cnsconfig.Config, version string) error {
211224
CryptoClient: cryptoClient,
212225
}
213226

227+
// Initialize snapshot lock manager
228+
c.snapshotLockMgr = &snapshotLockManager{
229+
locks: make(map[string]*volumeLock),
230+
}
231+
log.Info("Initialized snapshot lock manager for per-volume serialization")
232+
214233
vc, err := common.GetVCenter(ctx, c.manager)
215234
if err != nil {
216235
log.Errorf("failed to get vcenter. err=%v", err)
@@ -447,6 +466,53 @@ func (c *controller) ReloadConfiguration(reconnectToVCFromNewConfig bool) error
447466
return nil
448467
}
449468

469+
// acquireSnapshotLock acquires a lock for the given volume ID.
470+
// It creates a new lock if one doesn't exist and increments the reference count.
471+
// The caller must call releaseSnapshotLock when done.
472+
func (c *controller) acquireSnapshotLock(ctx context.Context, volumeID string) {
473+
log := logger.GetLogger(ctx)
474+
c.snapshotLockMgr.mapMutex.Lock()
475+
defer c.snapshotLockMgr.mapMutex.Unlock()
476+
477+
vLock, exists := c.snapshotLockMgr.locks[volumeID]
478+
if !exists {
479+
vLock = &volumeLock{}
480+
c.snapshotLockMgr.locks[volumeID] = vLock
481+
log.Debugf("Created new lock for volume %q", volumeID)
482+
}
483+
vLock.refCount++
484+
log.Debugf("Acquired lock for volume %q, refCount: %d", volumeID, vLock.refCount)
485+
486+
// Unlock the map before acquiring the volume lock to avoid deadlock
487+
c.snapshotLockMgr.mapMutex.Unlock()
488+
vLock.mutex.Lock()
489+
c.snapshotLockMgr.mapMutex.Lock()
490+
}
491+
492+
// releaseSnapshotLock releases the lock for the given volume ID.
493+
// It decrements the reference count and removes the lock if count reaches zero.
494+
func (c *controller) releaseSnapshotLock(ctx context.Context, volumeID string) {
495+
log := logger.GetLogger(ctx)
496+
c.snapshotLockMgr.mapMutex.Lock()
497+
defer c.snapshotLockMgr.mapMutex.Unlock()
498+
499+
vLock, exists := c.snapshotLockMgr.locks[volumeID]
500+
if !exists {
501+
log.Warnf("Attempted to release non-existent lock for volume %q", volumeID)
502+
return
503+
}
504+
505+
vLock.mutex.Unlock()
506+
vLock.refCount--
507+
log.Debugf("Released lock for volume %q, refCount: %d", volumeID, vLock.refCount)
508+
509+
// Clean up the lock if reference count reaches zero
510+
if vLock.refCount == 0 {
511+
delete(c.snapshotLockMgr.locks, volumeID)
512+
log.Debugf("Cleaned up lock for volume %q", volumeID)
513+
}
514+
}
515+
450516
// createBlockVolume creates a block volume based on the CreateVolumeRequest.
451517
func (c *controller) createBlockVolume(ctx context.Context, req *csi.CreateVolumeRequest,
452518
isWorkloadDomainIsolationEnabled bool, clusterMoIds []string) (
@@ -2446,8 +2512,47 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
24462512
"Queried VolumeType: %v", volumeType, cnsVolumeDetailsMap[volumeID].VolumeType)
24472513
}
24482514

2449-
// TODO: We may need to add logic to check the limit of max number of snapshots by using
2450-
// GlobalMaxSnapshotsPerBlockVolume etc. variables in the future.
2515+
// When snapshot limit enforcement is enabled, serialize snapshot operations on this
2516+
// volume and reject the request if the per-namespace snapshot limit is already reached.
2517+
isSnapshotLimitWCPEnabled := commonco.ContainerOrchestratorUtility.IsFSSEnabled(ctx, common.SnapshotLimitWCP)
2518+
if isSnapshotLimitWCPEnabled {
2519+
c.acquireSnapshotLock(ctx, volumeID)
2520+
defer c.releaseSnapshotLock(ctx, volumeID)
2521+
2522+
// Extract namespace from request parameters
2523+
volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
2524+
if volumeSnapshotNamespace == "" {
2525+
return nil, logger.LogNewErrorCodef(log, codes.Internal,
2526+
"volumesnapshot namespace is not set in the request parameters")
2527+
}
2528+
2529+
// Get snapshot limit from namespace annotation
2530+
snapshotLimit, err := getSnapshotLimitFromNamespace(ctx, volumeSnapshotNamespace)
2531+
if err != nil {
2532+
return nil, logger.LogNewErrorCodef(log, codes.Internal,
2533+
"failed to get snapshot limit for namespace %q: %v", volumeSnapshotNamespace, err)
2534+
}
2535+
log.Infof("Snapshot limit for namespace %q is set to %d", volumeSnapshotNamespace, snapshotLimit)
2536+
2537+
// Query existing snapshots for this volume
2538+
snapshotList, _, err := common.QueryVolumeSnapshotsByVolumeID(ctx, c.manager.VolumeManager, volumeID,
2539+
common.QuerySnapshotLimit)
2540+
if err != nil {
2541+
return nil, logger.LogNewErrorCodef(log, codes.Internal,
2542+
"failed to query snapshots for volume %q: %v", volumeID, err)
2543+
}
2544+
2545+
// Check if the limit is exceeded
2546+
currentSnapshotCount := len(snapshotList)
2547+
if currentSnapshotCount >= snapshotLimit {
2548+
return nil, logger.LogNewErrorCodef(log, codes.FailedPrecondition,
2549+
"the number of snapshots (%d) on the source volume %s has reached or exceeded "+
2550+
"the configured maximum (%d) for namespace %s",
2551+
currentSnapshotCount, volumeID, snapshotLimit, volumeSnapshotNamespace)
2552+
}
2553+
log.Infof("Current snapshot count for volume %q is %d, within limit of %d",
2554+
volumeID, currentSnapshotCount, snapshotLimit)
2555+
}
24512556

24522557
// the returned snapshotID below is a combination of CNS VolumeID and CNS SnapshotID concatenated by the "+"
24532558
// sign. That is, a string of "<UUID>+<UUID>". Because, all other CNS snapshot APIs still require both

pkg/csi/service/wcp/controller_helper.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ import (
3737
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3838
"k8s.io/apimachinery/pkg/fields"
3939
"k8s.io/apimachinery/pkg/types"
40+
"k8s.io/client-go/kubernetes"
41+
restclient "k8s.io/client-go/rest"
4042
api "k8s.io/kubernetes/pkg/apis/core"
4143
"sigs.k8s.io/controller-runtime/pkg/client/config"
4244
spv1alpha1 "sigs.k8s.io/vsphere-csi-driver/v3/pkg/apis/storagepool/cns/v1alpha1"
@@ -849,6 +851,15 @@ func validateControllerPublishVolumeRequesInWcp(ctx context.Context, req *csi.Co
849851

850852
var newK8sClient = k8s.NewClient
851853

854+
// getK8sConfig is a variable that can be overridden for testing
855+
var getK8sConfig = config.GetConfig
856+
857+
// newK8sClientFromConfig is a variable that can be overridden for testing
858+
// It wraps kubernetes.NewForConfig and returns Interface for easier testing
859+
var newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
860+
return kubernetes.NewForConfig(c)
861+
}
862+
852863
// getPodVMUUID returns the UUID of the VM(running on the node) on which the pod that is trying to
853864
// use the volume is scheduled.
854865
func getPodVMUUID(ctx context.Context, volumeID, nodeName string) (string, error) {
@@ -950,3 +961,63 @@ func GetZonesFromAccessibilityRequirements(ctx context.Context,
950961
}
951962
return zones, nil
952963
}
964+
965+
// getSnapshotLimitFromNamespace retrieves the snapshot limit from the namespace annotation.
966+
// If the annotation is not present, it returns the default value.
967+
// If the annotation value exceeds the maximum allowed, it caps the value and logs a warning.
968+
func getSnapshotLimitFromNamespace(ctx context.Context, namespace string) (int, error) {
969+
log := logger.GetLogger(ctx)
970+
971+
// Get Kubernetes config
972+
cfg, err := getK8sConfig()
973+
if err != nil {
974+
return 0, logger.LogNewErrorCodef(log, codes.Internal,
975+
"failed to get Kubernetes config: %v", err)
976+
}
977+
978+
// Create Kubernetes clientset
979+
k8sClient, err := newK8sClientFromConfig(cfg)
980+
if err != nil {
981+
return 0, logger.LogNewErrorCodef(log, codes.Internal,
982+
"failed to create Kubernetes client: %v", err)
983+
}
984+
985+
// Get namespace object
986+
ns, err := k8sClient.CoreV1().Namespaces().Get(ctx, namespace, metav1.GetOptions{})
987+
if err != nil {
988+
return 0, logger.LogNewErrorCodef(log, codes.Internal,
989+
"failed to get namespace %q: %v", namespace, err)
990+
}
991+
992+
// Check if annotation exists
993+
annotationValue, exists := ns.Annotations[common.MaxSnapshotsPerVolumeAnnotationKey]
994+
if !exists {
995+
log.Infof("Annotation %q not found in namespace %q, using default value %d",
996+
common.MaxSnapshotsPerVolumeAnnotationKey, namespace, common.DefaultMaxSnapshotsPerBlockVolumeInWCP)
997+
return common.DefaultMaxSnapshotsPerBlockVolumeInWCP, nil
998+
}
999+
1000+
// Parse annotation value
1001+
limit, err := strconv.Atoi(annotationValue)
1002+
if err != nil {
1003+
return 0, logger.LogNewErrorCodef(log, codes.Internal,
1004+
"failed to parse annotation %q value %q in namespace %q: %v",
1005+
common.MaxSnapshotsPerVolumeAnnotationKey, annotationValue, namespace, err)
1006+
}
1007+
1008+
// Validate limit
1009+
if limit < 0 {
1010+
return 0, logger.LogNewErrorCodef(log, codes.InvalidArgument,
1011+
"invalid snapshot limit %d in namespace %q: must be >= 0", limit, namespace)
1012+
}
1013+
1014+
// Cap to maximum allowed value
1015+
if limit > common.MaxAllowedSnapshotsPerBlockVolume {
1016+
log.Warnf("Snapshot limit %d in namespace %q exceeds maximum allowed %d, capping to %d",
1017+
limit, namespace, common.MaxAllowedSnapshotsPerBlockVolume, common.MaxAllowedSnapshotsPerBlockVolume)
1018+
return common.MaxAllowedSnapshotsPerBlockVolume, nil
1019+
}
1020+
1021+
log.Infof("Snapshot limit for namespace %q is set to %d", namespace, limit)
1022+
return limit, nil
1023+
}

0 commit comments

Comments
 (0)