Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pkg/csi/service/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,18 @@ const (
// Guest cluster.
SupervisorVolumeSnapshotAnnotationKey = "csi.vsphere.guest-initiated-csi-snapshot"

// ConfigMapCSILimits is the ConfigMap name for CSI limits configuration
ConfigMapCSILimits = "cns-csi-limits"

// ConfigMapKeyMaxSnapshotsPerVolume is the ConfigMap key for snapshot limit per volume
ConfigMapKeyMaxSnapshotsPerVolume = "max-snapshots-per-volume"

// DefaultMaxSnapshotsPerVolume is the default maximum number of snapshots per block volume in WCP
DefaultMaxSnapshotsPerVolume = 4

// AbsoluteMaxSnapshotsPerVolume is the hard cap for maximum snapshots per block volume
AbsoluteMaxSnapshotsPerVolume = 32

// AttributeSupervisorVolumeSnapshotClass represents name of VolumeSnapshotClass
AttributeSupervisorVolumeSnapshotClass = "svvolumesnapshotclass"

Expand Down
112 changes: 106 additions & 6 deletions pkg/csi/service/wcp/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,23 @@ var (
vmMoidToHostMoid, volumeIDToVMMap map[string]string
)

// volumeLock represents a lock for a specific volume with reference counting
type volumeLock struct {
mutex sync.Mutex
refCount int
}

// snapshotLockManager manages per-volume locks for snapshot operations
type snapshotLockManager struct {
locks map[string]*volumeLock
mapMutex sync.RWMutex
}

type controller struct {
manager *common.Manager
authMgr common.AuthorizationService
topologyMgr commoncotypes.ControllerTopologyService
manager *common.Manager
authMgr common.AuthorizationService
topologyMgr commoncotypes.ControllerTopologyService
snapshotLockMgr *snapshotLockManager
csi.UnimplementedControllerServer
}

Expand Down Expand Up @@ -216,6 +229,12 @@ func (c *controller) Init(config *cnsconfig.Config, version string) error {
CryptoClient: cryptoClient,
}

// Initialize snapshot lock manager
c.snapshotLockMgr = &snapshotLockManager{
locks: make(map[string]*volumeLock),
}
log.Info("Initialized snapshot lock manager for per-volume serialization")

vc, err := common.GetVCenter(ctx, c.manager)
if err != nil {
log.Errorf("failed to get vcenter. err=%v", err)
Expand Down Expand Up @@ -452,6 +471,53 @@ func (c *controller) ReloadConfiguration(reconnectToVCFromNewConfig bool) error
return nil
}

// acquireSnapshotLock acquires a lock for the given volume ID.
// It creates a new lock if one doesn't exist and increments the reference count.
// The caller must call releaseSnapshotLock when done.
func (c *controller) acquireSnapshotLock(ctx context.Context, volumeID string) {
log := logger.GetLogger(ctx)
c.snapshotLockMgr.mapMutex.Lock()
defer c.snapshotLockMgr.mapMutex.Unlock()

vLock, exists := c.snapshotLockMgr.locks[volumeID]
if !exists {
vLock = &volumeLock{}
c.snapshotLockMgr.locks[volumeID] = vLock
log.Debugf("Created new lock for volume %q", volumeID)
}
vLock.refCount++
log.Debugf("Acquired lock for volume %q, refCount: %d", volumeID, vLock.refCount)

// Unlock the map before acquiring the volume lock to avoid deadlock
c.snapshotLockMgr.mapMutex.Unlock()
vLock.mutex.Lock()
c.snapshotLockMgr.mapMutex.Lock()
}

// releaseSnapshotLock releases the lock for the given volume ID.
// It decrements the reference count and removes the lock if count reaches zero.
func (c *controller) releaseSnapshotLock(ctx context.Context, volumeID string) {
log := logger.GetLogger(ctx)
c.snapshotLockMgr.mapMutex.Lock()
defer c.snapshotLockMgr.mapMutex.Unlock()

vLock, exists := c.snapshotLockMgr.locks[volumeID]
if !exists {
log.Warnf("Attempted to release non-existent lock for volume %q", volumeID)
return
}

vLock.mutex.Unlock()
vLock.refCount--
log.Debugf("Released lock for volume %q, refCount: %d", volumeID, vLock.refCount)

// Clean up the lock if reference count reaches zero
if vLock.refCount == 0 {
delete(c.snapshotLockMgr.locks, volumeID)
log.Debugf("Cleaned up lock for volume %q", volumeID)
}
}

// createBlockVolume creates a block volume based on the CreateVolumeRequest.
func (c *controller) createBlockVolume(ctx context.Context, req *csi.CreateVolumeRequest,
isWorkloadDomainIsolationEnabled bool, clusterMoIds []string) (
Expand Down Expand Up @@ -2455,8 +2521,43 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
"Queried VolumeType: %v", volumeType, cnsVolumeDetailsMap[volumeID].VolumeType)
}

// TODO: We may need to add logic to check the limit of max number of snapshots by using
// GlobalMaxSnapshotsPerBlockVolume etc. variables in the future.
// Extract namespace from request parameters
volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
if volumeSnapshotNamespace == "" {
return nil, logger.LogNewErrorCodef(log, codes.Internal,
"volumesnapshot namespace is not set in the request parameters")
}

// Get snapshot limit from namespace ConfigMap
snapshotLimit, err := getSnapshotLimitForNamespace(ctx, volumeSnapshotNamespace)
if err != nil {
return nil, logger.LogNewErrorCodef(log, codes.Internal,
"failed to get snapshot limit for namespace %q: %v", volumeSnapshotNamespace, err)
}
log.Infof("Snapshot limit for namespace %q is set to %d", volumeSnapshotNamespace, snapshotLimit)

// Acquire lock for this volume to serialize snapshot operations
c.acquireSnapshotLock(ctx, volumeID)
defer c.releaseSnapshotLock(ctx, volumeID)

// Query existing snapshots for this volume
snapshotList, _, err := common.QueryVolumeSnapshotsByVolumeID(ctx, c.manager.VolumeManager, volumeID,
common.QuerySnapshotLimit)
if err != nil {
return nil, logger.LogNewErrorCodef(log, codes.Internal,
"failed to query snapshots for volume %q: %v", volumeID, err)
}

// Check if the limit is exceeded
currentSnapshotCount := len(snapshotList)
if currentSnapshotCount >= snapshotLimit {
return nil, logger.LogNewErrorCodef(log, codes.FailedPrecondition,
"the number of snapshots (%d) on the source volume %s has reached or exceeded "+
"the configured maximum (%d) for namespace %s",
currentSnapshotCount, volumeID, snapshotLimit, volumeSnapshotNamespace)
}
log.Infof("Current snapshot count for volume %q is %d, within limit of %d",
volumeID, currentSnapshotCount, snapshotLimit)

// the returned snapshotID below is a combination of CNS VolumeID and CNS SnapshotID concatenated by the "+"
// sign. That is, a string of "<UUID>+<UUID>". Because, all other CNS snapshot APIs still require both
Expand Down Expand Up @@ -2534,7 +2635,6 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
cnsSnapshotInfo.SnapshotLatestOperationCompleteTime, createSnapshotResponse)

volumeSnapshotName := req.Parameters[common.VolumeSnapshotNameKey]
volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
log.Infof("Attempting to annotate volumesnapshot %s/%s with annotation %s:%s",
volumeSnapshotNamespace, volumeSnapshotName, common.VolumeSnapshotInfoKey, snapshotID)
annotated, err := commonco.ContainerOrchestratorUtility.AnnotateVolumeSnapshot(ctx, volumeSnapshotName,
Expand Down
83 changes: 83 additions & 0 deletions pkg/csi/service/wcp/controller_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes"
restclient "k8s.io/client-go/rest"
api "k8s.io/kubernetes/pkg/apis/core"
"sigs.k8s.io/controller-runtime/pkg/client/config"
spv1alpha1 "sigs.k8s.io/vsphere-csi-driver/v3/pkg/apis/storagepool/cns/v1alpha1"
Expand Down Expand Up @@ -849,6 +852,15 @@ func validateControllerPublishVolumeRequesInWcp(ctx context.Context, req *csi.Co

var newK8sClient = k8s.NewClient

// getK8sConfig is a variable that can be overridden for testing
var getK8sConfig = config.GetConfig

// newK8sClientFromConfig is a variable that can be overridden for testing
// It wraps kubernetes.NewForConfig and returns Interface for easier testing
var newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
return kubernetes.NewForConfig(c)
}

// getPodVMUUID returns the UUID of the VM(running on the node) on which the pod that is trying to
// use the volume is scheduled.
func getPodVMUUID(ctx context.Context, volumeID, nodeName string) (string, error) {
Expand Down Expand Up @@ -950,3 +962,74 @@ func GetZonesFromAccessibilityRequirements(ctx context.Context,
}
return zones, nil
}

// getSnapshotLimitForNamespace reads the snapshot limit from ConfigMap in the namespace.
// Returns the effective limit after applying defaults and absolute max clamping.
func getSnapshotLimitForNamespace(ctx context.Context, namespace string) (int, error) {
log := logger.GetLogger(ctx)

// Get Kubernetes config
cfg, err := getK8sConfig()
if err != nil {
// If k8s config is not available (e.g., in test environments), use default
log.Infof("getSnapshotLimitForNamespace: failed to get Kubernetes config, using default: %d. Error: %v",
common.DefaultMaxSnapshotsPerVolume, err)
return common.DefaultMaxSnapshotsPerVolume, nil
}

// Create Kubernetes clientset
k8sClient, err := newK8sClientFromConfig(cfg)
if err != nil {
// If k8s client creation fails, use default
log.Infof("getSnapshotLimitForNamespace: failed to create Kubernetes client, using default: %d. Error: %v",
common.DefaultMaxSnapshotsPerVolume, err)
return common.DefaultMaxSnapshotsPerVolume, nil
}

// Get ConfigMap from the namespace
cm, err := k8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, common.ConfigMapCSILimits, metav1.GetOptions{})
if err != nil {
if apierrors.IsNotFound(err) {
// ConfigMap not found, use default
log.Infof("getSnapshotLimitForNamespace: ConfigMap %s not found in namespace %s, using default: %d",
common.ConfigMapCSILimits, namespace, common.DefaultMaxSnapshotsPerVolume)
return common.DefaultMaxSnapshotsPerVolume, nil
}
// Other error occurred
log.Errorf("getSnapshotLimitForNamespace: failed to get ConfigMap %s in namespace %s, err: %v",
common.ConfigMapCSILimits, namespace, err)
return 0, err
}

// Check if the key exists in ConfigMap data
limitStr, exists := cm.Data[common.ConfigMapKeyMaxSnapshotsPerVolume]
if !exists {
// Key not found in ConfigMap, fail the request
errMsg := fmt.Sprintf("ConfigMap %s exists in namespace %s but missing required key '%s'",
common.ConfigMapCSILimits, namespace, common.ConfigMapKeyMaxSnapshotsPerVolume)
log.Errorf("getSnapshotLimitForNamespace: %s", errMsg)
return 0, errors.New(errMsg)
}

// Parse the limit value
limit, err := strconv.Atoi(limitStr)
if err != nil || limit < 0 {
// Invalid value, fail the request
errMsg := fmt.Sprintf(
"ConfigMap %s in namespace %s has invalid value '%s' for key '%s': must be a non-negative integer",
common.ConfigMapCSILimits, namespace, limitStr, common.ConfigMapKeyMaxSnapshotsPerVolume)
log.Errorf("getSnapshotLimitForNamespace: %s", errMsg)
return 0, errors.New(errMsg)
}

// Clamp to absolute max if exceeded
if limit > common.AbsoluteMaxSnapshotsPerVolume {
log.Warnf(
"getSnapshotLimitForNamespace: namespace %s ConfigMap limit %d exceeds absolute max %d, clamping to absolute max",
namespace, limit, common.AbsoluteMaxSnapshotsPerVolume)
return common.AbsoluteMaxSnapshotsPerVolume, nil
}

log.Infof("getSnapshotLimitForNamespace: namespace %s snapshot limit from ConfigMap: %d", namespace, limit)
return limit, nil
}
Loading