Skip to content

Commit 952d1d4

Browse files
committed
Add snapshot limit enforcement for WCP with per-volume serialization
Signed-off-by: Deepak Kinni <[email protected]>
1 parent b28f50c commit 952d1d4

File tree

5 files changed

+648
-6
lines changed

5 files changed

+648
-6
lines changed

pkg/csi/service/common/constants.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,18 @@ const (
353353
// Guest cluster.
354354
SupervisorVolumeSnapshotAnnotationKey = "csi.vsphere.guest-initiated-csi-snapshot"
355355

356+
// ConfigMapCSILimits is the ConfigMap name for CSI limits configuration
357+
ConfigMapCSILimits = "cns-csi-limits"
358+
359+
// ConfigMapKeyMaxSnapshotsPerVolume is the ConfigMap key for snapshot limit per volume
360+
ConfigMapKeyMaxSnapshotsPerVolume = "csi.vsphere.max-snapshots-per-volume"
361+
362+
// DefaultMaxSnapshotsPerVolume is the default maximum number of snapshots per block volume in WCP
363+
DefaultMaxSnapshotsPerVolume = 4
364+
365+
// AbsoluteMaxSnapshotsPerVolume is the hard cap for maximum snapshots per block volume
366+
AbsoluteMaxSnapshotsPerVolume = 32
367+
356368
// AttributeSupervisorVolumeSnapshotClass represents name of VolumeSnapshotClass
357369
AttributeSupervisorVolumeSnapshotClass = "svvolumesnapshotclass"
358370

pkg/csi/service/wcp/controller.go

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,23 @@ var (
9898
vmMoidToHostMoid, volumeIDToVMMap map[string]string
9999
)
100100

101+
// volumeLock is the per-volume entry stored in snapshotLockManager.locks.
// It pairs the mutex that serializes snapshot operations on one volume with a
// count of the callers currently interested in that mutex.
type volumeLock struct {
	// mutex is held for the duration of a snapshot operation on the volume.
	mutex sync.Mutex

	// refCount is incremented by acquireSnapshotLock before it blocks on mutex
	// and decremented by releaseSnapshotLock after it unlocks mutex; the entry
	// is removed from the map once it drops to zero.
	refCount int
}
106+
107+
// snapshotLockManager manages per-volume locks for snapshot operations
108+
type snapshotLockManager struct {
109+
locks map[string]*volumeLock
110+
mapMutex sync.RWMutex
111+
}
112+
101113
type controller struct {
102-
manager *common.Manager
103-
authMgr common.AuthorizationService
104-
topologyMgr commoncotypes.ControllerTopologyService
114+
manager *common.Manager
115+
authMgr common.AuthorizationService
116+
topologyMgr commoncotypes.ControllerTopologyService
117+
snapshotLockMgr *snapshotLockManager
105118
csi.UnimplementedControllerServer
106119
}
107120

@@ -216,6 +229,12 @@ func (c *controller) Init(config *cnsconfig.Config, version string) error {
216229
CryptoClient: cryptoClient,
217230
}
218231

232+
// Initialize snapshot lock manager
233+
c.snapshotLockMgr = &snapshotLockManager{
234+
locks: make(map[string]*volumeLock),
235+
}
236+
log.Info("Initialized snapshot lock manager for per-volume serialization")
237+
219238
vc, err := common.GetVCenter(ctx, c.manager)
220239
if err != nil {
221240
log.Errorf("failed to get vcenter. err=%v", err)
@@ -452,6 +471,53 @@ func (c *controller) ReloadConfiguration(reconnectToVCFromNewConfig bool) error
452471
return nil
453472
}
454473

474+
// acquireSnapshotLock acquires a lock for the given volume ID.
475+
// It creates a new lock if one doesn't exist and increments the reference count.
476+
// The caller must call releaseSnapshotLock when done.
477+
func (c *controller) acquireSnapshotLock(ctx context.Context, volumeID string) {
478+
log := logger.GetLogger(ctx)
479+
c.snapshotLockMgr.mapMutex.Lock()
480+
defer c.snapshotLockMgr.mapMutex.Unlock()
481+
482+
vLock, exists := c.snapshotLockMgr.locks[volumeID]
483+
if !exists {
484+
vLock = &volumeLock{}
485+
c.snapshotLockMgr.locks[volumeID] = vLock
486+
log.Debugf("Created new lock for volume %q", volumeID)
487+
}
488+
vLock.refCount++
489+
log.Debugf("Acquired lock for volume %q, refCount: %d", volumeID, vLock.refCount)
490+
491+
// Unlock the map before acquiring the volume lock to avoid deadlock
492+
c.snapshotLockMgr.mapMutex.Unlock()
493+
vLock.mutex.Lock()
494+
c.snapshotLockMgr.mapMutex.Lock()
495+
}
496+
497+
// releaseSnapshotLock releases the lock for the given volume ID.
498+
// It decrements the reference count and removes the lock if count reaches zero.
499+
func (c *controller) releaseSnapshotLock(ctx context.Context, volumeID string) {
500+
log := logger.GetLogger(ctx)
501+
c.snapshotLockMgr.mapMutex.Lock()
502+
defer c.snapshotLockMgr.mapMutex.Unlock()
503+
504+
vLock, exists := c.snapshotLockMgr.locks[volumeID]
505+
if !exists {
506+
log.Warnf("Attempted to release non-existent lock for volume %q", volumeID)
507+
return
508+
}
509+
510+
vLock.mutex.Unlock()
511+
vLock.refCount--
512+
log.Debugf("Released lock for volume %q, refCount: %d", volumeID, vLock.refCount)
513+
514+
// Clean up the lock if reference count reaches zero
515+
if vLock.refCount == 0 {
516+
delete(c.snapshotLockMgr.locks, volumeID)
517+
log.Debugf("Cleaned up lock for volume %q", volumeID)
518+
}
519+
}
520+
455521
// createBlockVolume creates a block volume based on the CreateVolumeRequest.
456522
func (c *controller) createBlockVolume(ctx context.Context, req *csi.CreateVolumeRequest,
457523
isWorkloadDomainIsolationEnabled bool, clusterMoIds []string) (
@@ -2451,8 +2517,43 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
24512517
"Queried VolumeType: %v", volumeType, cnsVolumeDetailsMap[volumeID].VolumeType)
24522518
}
24532519

2454-
// TODO: We may need to add logic to check the limit of max number of snapshots by using
2455-
// GlobalMaxSnapshotsPerBlockVolume etc. variables in the future.
2520+
// Extract namespace from request parameters
2521+
volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
2522+
if volumeSnapshotNamespace == "" {
2523+
return nil, logger.LogNewErrorCodef(log, codes.Internal,
2524+
"volumesnapshot namespace is not set in the request parameters")
2525+
}
2526+
2527+
// Get snapshot limit from namespace ConfigMap
2528+
snapshotLimit, err := getSnapshotLimitForNamespace(ctx, volumeSnapshotNamespace)
2529+
if err != nil {
2530+
return nil, logger.LogNewErrorCodef(log, codes.Internal,
2531+
"failed to get snapshot limit for namespace %q: %v", volumeSnapshotNamespace, err)
2532+
}
2533+
log.Infof("Snapshot limit for namespace %q is set to %d", volumeSnapshotNamespace, snapshotLimit)
2534+
2535+
// Acquire lock for this volume to serialize snapshot operations
2536+
c.acquireSnapshotLock(ctx, volumeID)
2537+
defer c.releaseSnapshotLock(ctx, volumeID)
2538+
2539+
// Query existing snapshots for this volume
2540+
snapshotList, _, err := common.QueryVolumeSnapshotsByVolumeID(ctx, c.manager.VolumeManager, volumeID,
2541+
common.QuerySnapshotLimit)
2542+
if err != nil {
2543+
return nil, logger.LogNewErrorCodef(log, codes.Internal,
2544+
"failed to query snapshots for volume %q: %v", volumeID, err)
2545+
}
2546+
2547+
// Check if the limit is exceeded
2548+
currentSnapshotCount := len(snapshotList)
2549+
if currentSnapshotCount >= snapshotLimit {
2550+
return nil, logger.LogNewErrorCodef(log, codes.FailedPrecondition,
2551+
"the number of snapshots (%d) on the source volume %s has reached or exceeded "+
2552+
"the configured maximum (%d) for namespace %s",
2553+
currentSnapshotCount, volumeID, snapshotLimit, volumeSnapshotNamespace)
2554+
}
2555+
log.Infof("Current snapshot count for volume %q is %d, within limit of %d",
2556+
volumeID, currentSnapshotCount, snapshotLimit)
24562557

24572558
// the returned snapshotID below is a combination of CNS VolumeID and CNS SnapshotID concatenated by the "+"
24582559
// sign. That is, a string of "<UUID>+<UUID>". Because, all other CNS snapshot APIs still require both
@@ -2530,7 +2631,6 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
25302631
cnsSnapshotInfo.SnapshotLatestOperationCompleteTime, createSnapshotResponse)
25312632

25322633
volumeSnapshotName := req.Parameters[common.VolumeSnapshotNameKey]
2533-
volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
25342634
log.Infof("Attempting to annotate volumesnapshot %s/%s with annotation %s:%s",
25352635
volumeSnapshotNamespace, volumeSnapshotName, common.VolumeSnapshotInfoKey, snapshotID)
25362636
annotated, err := commonco.ContainerOrchestratorUtility.AnnotateVolumeSnapshot(ctx, volumeSnapshotName,

pkg/csi/service/wcp/controller_helper.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ import (
3434
"google.golang.org/grpc/credentials/insecure"
3535
"google.golang.org/grpc/status"
3636
v1 "k8s.io/api/core/v1"
37+
apierrors "k8s.io/apimachinery/pkg/api/errors"
3738
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3839
"k8s.io/apimachinery/pkg/fields"
3940
"k8s.io/apimachinery/pkg/types"
41+
"k8s.io/client-go/kubernetes"
42+
restclient "k8s.io/client-go/rest"
4043
api "k8s.io/kubernetes/pkg/apis/core"
4144
"sigs.k8s.io/controller-runtime/pkg/client/config"
4245
spv1alpha1 "sigs.k8s.io/vsphere-csi-driver/v3/pkg/apis/storagepool/cns/v1alpha1"
@@ -849,6 +852,15 @@ func validateControllerPublishVolumeRequesInWcp(ctx context.Context, req *csi.Co
849852

850853
var newK8sClient = k8s.NewClient
851854

855+
// getK8sConfig is a variable that can be overridden for testing
856+
var getK8sConfig = config.GetConfig
857+
858+
// newK8sClientFromConfig is a variable that can be overridden for testing
859+
// It wraps kubernetes.NewForConfig and returns Interface for easier testing
860+
var newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
861+
return kubernetes.NewForConfig(c)
862+
}
863+
852864
// getPodVMUUID returns the UUID of the VM(running on the node) on which the pod that is trying to
853865
// use the volume is scheduled.
854866
func getPodVMUUID(ctx context.Context, volumeID, nodeName string) (string, error) {
@@ -950,3 +962,72 @@ func GetZonesFromAccessibilityRequirements(ctx context.Context,
950962
}
951963
return zones, nil
952964
}
965+
966+
// getSnapshotLimitForNamespace reads the snapshot limit from ConfigMap in the namespace.
967+
// Returns the effective limit after applying defaults and absolute max clamping.
968+
func getSnapshotLimitForNamespace(ctx context.Context, namespace string) (int, error) {
969+
log := logger.GetLogger(ctx)
970+
971+
// Get Kubernetes config
972+
cfg, err := getK8sConfig()
973+
if err != nil {
974+
// If k8s config is not available (e.g., in test environments), use default
975+
log.Infof("getSnapshotLimitForNamespace: failed to get Kubernetes config, using default: %d. Error: %v",
976+
common.DefaultMaxSnapshotsPerVolume, err)
977+
return common.DefaultMaxSnapshotsPerVolume, nil
978+
}
979+
980+
// Create Kubernetes clientset
981+
k8sClient, err := newK8sClientFromConfig(cfg)
982+
if err != nil {
983+
// If k8s client creation fails, use default
984+
log.Infof("getSnapshotLimitForNamespace: failed to create Kubernetes client, using default: %d. Error: %v",
985+
common.DefaultMaxSnapshotsPerVolume, err)
986+
return common.DefaultMaxSnapshotsPerVolume, nil
987+
}
988+
989+
// Get ConfigMap from the namespace
990+
cm, err := k8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, common.ConfigMapCSILimits, metav1.GetOptions{})
991+
if err != nil {
992+
if apierrors.IsNotFound(err) {
993+
// ConfigMap not found, use default
994+
log.Infof("getSnapshotLimitForNamespace: ConfigMap %s not found in namespace %s, using default: %d",
995+
common.ConfigMapCSILimits, namespace, common.DefaultMaxSnapshotsPerVolume)
996+
return common.DefaultMaxSnapshotsPerVolume, nil
997+
}
998+
// Other error occurred
999+
log.Errorf("getSnapshotLimitForNamespace: failed to get ConfigMap %s in namespace %s, err: %v",
1000+
common.ConfigMapCSILimits, namespace, err)
1001+
return 0, err
1002+
}
1003+
1004+
// Check if the key exists in ConfigMap data
1005+
limitStr, exists := cm.Data[common.ConfigMapKeyMaxSnapshotsPerVolume]
1006+
if !exists {
1007+
// Key not found in ConfigMap, fail the request
1008+
errMsg := fmt.Sprintf("ConfigMap %s exists in namespace %s but missing required key '%s'",
1009+
common.ConfigMapCSILimits, namespace, common.ConfigMapKeyMaxSnapshotsPerVolume)
1010+
log.Errorf("getSnapshotLimitForNamespace: %s", errMsg)
1011+
return 0, errors.New(errMsg)
1012+
}
1013+
1014+
// Parse the limit value
1015+
limit, err := strconv.Atoi(limitStr)
1016+
if err != nil || limit < 0 {
1017+
// Invalid value, fail the request
1018+
errMsg := fmt.Sprintf("ConfigMap %s in namespace %s has invalid value '%s' for key '%s': must be a non-negative integer",
1019+
common.ConfigMapCSILimits, namespace, limitStr, common.ConfigMapKeyMaxSnapshotsPerVolume)
1020+
log.Errorf("getSnapshotLimitForNamespace: %s", errMsg)
1021+
return 0, errors.New(errMsg)
1022+
}
1023+
1024+
// Clamp to absolute max if exceeded
1025+
if limit > common.AbsoluteMaxSnapshotsPerVolume {
1026+
log.Warnf("getSnapshotLimitForNamespace: namespace %s ConfigMap limit %d exceeds absolute max %d, clamping to absolute max",
1027+
namespace, limit, common.AbsoluteMaxSnapshotsPerVolume)
1028+
return common.AbsoluteMaxSnapshotsPerVolume, nil
1029+
}
1030+
1031+
log.Infof("getSnapshotLimitForNamespace: namespace %s snapshot limit from ConfigMap: %d", namespace, limit)
1032+
return limit, nil
1033+
}

0 commit comments

Comments
 (0)