diff --git a/pkg/csi/service/common/constants.go b/pkg/csi/service/common/constants.go
index 90e0bf78da..95ac8a2f12 100644
--- a/pkg/csi/service/common/constants.go
+++ b/pkg/csi/service/common/constants.go
@@ -353,6 +353,18 @@ const (
 	// Guest cluster.
 	SupervisorVolumeSnapshotAnnotationKey = "csi.vsphere.guest-initiated-csi-snapshot"
 
+	// ConfigMapCSILimits is the name of the ConfigMap for the CSI limits configuration.
+	ConfigMapCSILimits = "cns-csi-limits"
+
+	// ConfigMapKeyMaxSnapshotsPerVolume is the ConfigMap key for the per-volume snapshot limit.
+	ConfigMapKeyMaxSnapshotsPerVolume = "max-snapshots-per-volume"
+
+	// DefaultMaxSnapshotsPerVolume is the default maximum number of snapshots per block volume in WCP.
+	DefaultMaxSnapshotsPerVolume = 4
+
+	// AbsoluteMaxSnapshotsPerVolume is the hard cap on the maximum number of snapshots per block volume.
+	AbsoluteMaxSnapshotsPerVolume = 32
+
 	// AttributeSupervisorVolumeSnapshotClass represents name of VolumeSnapshotClass
 	AttributeSupervisorVolumeSnapshotClass = "svvolumesnapshotclass"
 
diff --git a/pkg/csi/service/wcp/controller.go b/pkg/csi/service/wcp/controller.go
index 94dadfa984..aef78efc23 100644
--- a/pkg/csi/service/wcp/controller.go
+++ b/pkg/csi/service/wcp/controller.go
@@ -98,10 +98,23 @@ var (
 	vmMoidToHostMoid, volumeIDToVMMap map[string]string
 )
 
+// volumeLock represents a lock for a specific volume with reference counting.
+type volumeLock struct {
+	mutex    sync.Mutex
+	refCount int
+}
+
+// snapshotLockManager manages per-volume locks for snapshot operations.
+type snapshotLockManager struct {
+	locks    map[string]*volumeLock
+	mapMutex sync.RWMutex
+}
+
 type controller struct {
-	manager     *common.Manager
-	authMgr     common.AuthorizationService
-	topologyMgr commoncotypes.ControllerTopologyService
+	manager         *common.Manager
+	authMgr         common.AuthorizationService
+	topologyMgr     commoncotypes.ControllerTopologyService
+	snapshotLockMgr *snapshotLockManager
 	csi.UnimplementedControllerServer
 }
 
@@ -216,6 +229,12 @@ func (c *controller) Init(config *cnsconfig.Config, version string) error {
 		CryptoClient: cryptoClient,
 	}
 
+	// Initialize snapshot lock manager
+	c.snapshotLockMgr = &snapshotLockManager{
+		locks: make(map[string]*volumeLock),
+	}
+	log.Info("Initialized snapshot lock manager for per-volume serialization")
+
 	vc, err := common.GetVCenter(ctx, c.manager)
 	if err != nil {
 		log.Errorf("failed to get vcenter. err=%v", err)
@@ -452,6 +471,53 @@ func (c *controller) ReloadConfiguration(reconnectToVCFromNewConfig bool) error
 	return nil
 }
 
+// acquireSnapshotLock acquires a lock for the given volume ID.
+// It creates a new lock if one doesn't exist and increments the reference count.
+// The caller must call releaseSnapshotLock when done.
+func (c *controller) acquireSnapshotLock(ctx context.Context, volumeID string) {
+	log := logger.GetLogger(ctx)
+	c.snapshotLockMgr.mapMutex.Lock()
+	defer c.snapshotLockMgr.mapMutex.Unlock()
+
+	vLock, exists := c.snapshotLockMgr.locks[volumeID]
+	if !exists {
+		vLock = &volumeLock{}
+		c.snapshotLockMgr.locks[volumeID] = vLock
+		log.Debugf("Created new lock for volume %q", volumeID)
+	}
+	vLock.refCount++
+	log.Debugf("Acquired lock for volume %q, refCount: %d", volumeID, vLock.refCount)
+
+	// Unlock the map before acquiring the volume lock to avoid deadlock
+	c.snapshotLockMgr.mapMutex.Unlock()
+	vLock.mutex.Lock()
+	c.snapshotLockMgr.mapMutex.Lock()
+}
+
+// releaseSnapshotLock releases the lock for the given volume ID.
+// It decrements the reference count and removes the lock when the count reaches zero.
+func (c *controller) releaseSnapshotLock(ctx context.Context, volumeID string) {
+	log := logger.GetLogger(ctx)
+	c.snapshotLockMgr.mapMutex.Lock()
+	defer c.snapshotLockMgr.mapMutex.Unlock()
+
+	vLock, exists := c.snapshotLockMgr.locks[volumeID]
+	if !exists {
+		log.Warnf("Attempted to release non-existent lock for volume %q", volumeID)
+		return
+	}
+
+	vLock.mutex.Unlock()
+	vLock.refCount--
+	log.Debugf("Released lock for volume %q, refCount: %d", volumeID, vLock.refCount)
+
+	// Clean up the lock if reference count reaches zero
+	if vLock.refCount == 0 {
+		delete(c.snapshotLockMgr.locks, volumeID)
+		log.Debugf("Cleaned up lock for volume %q", volumeID)
+	}
+}
+
 // createBlockVolume creates a block volume based on the CreateVolumeRequest.
 func (c *controller) createBlockVolume(ctx context.Context,
 	req *csi.CreateVolumeRequest, isWorkloadDomainIsolationEnabled bool, clusterMoIds []string) (
@@ -2455,8 +2521,43 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
 			"Queried VolumeType: %v", volumeType, cnsVolumeDetailsMap[volumeID].VolumeType)
 	}
 
-	// TODO: We may need to add logic to check the limit of max number of snapshots by using
-	// GlobalMaxSnapshotsPerBlockVolume etc. variables in the future.
+	// Extract namespace from request parameters
+	volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
+	if volumeSnapshotNamespace == "" {
+		return nil, logger.LogNewErrorCodef(log, codes.Internal,
+			"volumesnapshot namespace is not set in the request parameters")
+	}
+
+	// Get snapshot limit from namespace ConfigMap
+	snapshotLimit, err := getSnapshotLimitForNamespace(ctx, volumeSnapshotNamespace)
+	if err != nil {
+		return nil, logger.LogNewErrorCodef(log, codes.Internal,
+			"failed to get snapshot limit for namespace %q: %v", volumeSnapshotNamespace, err)
+	}
+	log.Infof("Snapshot limit for namespace %q is set to %d", volumeSnapshotNamespace, snapshotLimit)
+
+	// Acquire lock for this volume to serialize snapshot operations
+	c.acquireSnapshotLock(ctx, volumeID)
+	defer c.releaseSnapshotLock(ctx, volumeID)
+
+	// Query existing snapshots for this volume
+	snapshotList, _, err := common.QueryVolumeSnapshotsByVolumeID(ctx, c.manager.VolumeManager, volumeID,
+		common.QuerySnapshotLimit)
+	if err != nil {
+		return nil, logger.LogNewErrorCodef(log, codes.Internal,
+			"failed to query snapshots for volume %q: %v", volumeID, err)
+	}
+
+	// Check if the limit is exceeded
+	currentSnapshotCount := len(snapshotList)
+	if currentSnapshotCount >= snapshotLimit {
+		return nil, logger.LogNewErrorCodef(log, codes.FailedPrecondition,
+			"the number of snapshots (%d) on the source volume %s has reached or exceeded "+
+				"the configured maximum (%d) for namespace %s",
+			currentSnapshotCount, volumeID, snapshotLimit, volumeSnapshotNamespace)
+	}
+	log.Infof("Current snapshot count for volume %q is %d, within limit of %d",
+		volumeID, currentSnapshotCount, snapshotLimit)
 
 	// the returned snapshotID below is a combination of CNS VolumeID and CNS SnapshotID concatenated by the "+"
 	// sign. That is, a string of "+". Because, all other CNS snapshot APIs still require both
@@ -2534,7 +2635,6 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
 		cnsSnapshotInfo.SnapshotLatestOperationCompleteTime, createSnapshotResponse)
 
 	volumeSnapshotName := req.Parameters[common.VolumeSnapshotNameKey]
-	volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
 	log.Infof("Attempting to annotate volumesnapshot %s/%s with annotation %s:%s",
 		volumeSnapshotNamespace, volumeSnapshotName, common.VolumeSnapshotInfoKey, snapshotID)
 	annotated, err := commonco.ContainerOrchestratorUtility.AnnotateVolumeSnapshot(ctx, volumeSnapshotName,
diff --git a/pkg/csi/service/wcp/controller_helper.go b/pkg/csi/service/wcp/controller_helper.go
index c1e248d6f6..a13f7f5e92 100644
--- a/pkg/csi/service/wcp/controller_helper.go
+++ b/pkg/csi/service/wcp/controller_helper.go
@@ -34,9 +34,12 @@ import (
 	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/grpc/status"
 	v1 "k8s.io/api/core/v1"
+	apierrors "k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/fields"
 	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/kubernetes"
+	restclient "k8s.io/client-go/rest"
 	api "k8s.io/kubernetes/pkg/apis/core"
 	"sigs.k8s.io/controller-runtime/pkg/client/config"
 	spv1alpha1 "sigs.k8s.io/vsphere-csi-driver/v3/pkg/apis/storagepool/cns/v1alpha1"
@@ -849,6 +852,15 @@ func validateControllerPublishVolumeRequesInWcp(ctx context.Context, req *csi.Co
 
 var newK8sClient = k8s.NewClient
 
+// getK8sConfig is a variable that can be overridden for testing.
+var getK8sConfig = config.GetConfig
+
+// newK8sClientFromConfig is a variable that can be overridden for testing.
+// It wraps kubernetes.NewForConfig and returns kubernetes.Interface for easier testing.
+var newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+	return kubernetes.NewForConfig(c)
+}
+
 // getPodVMUUID returns the UUID of the VM(running on the node) on which the pod that is trying to
 // use the volume is scheduled.
 func getPodVMUUID(ctx context.Context, volumeID, nodeName string) (string, error) {
@@ -950,3 +962,74 @@ func GetZonesFromAccessibilityRequirements(ctx context.Context,
 	}
 	return zones, nil
 }
+
+// getSnapshotLimitForNamespace reads the snapshot limit from the ConfigMap in the given namespace.
+// It returns the effective limit after applying defaults and absolute-max clamping.
+func getSnapshotLimitForNamespace(ctx context.Context, namespace string) (int, error) {
+	log := logger.GetLogger(ctx)
+
+	// Get Kubernetes config
+	cfg, err := getK8sConfig()
+	if err != nil {
+		// If k8s config is not available (e.g., in test environments), use default
+		log.Infof("getSnapshotLimitForNamespace: failed to get Kubernetes config, using default: %d. Error: %v",
+			common.DefaultMaxSnapshotsPerVolume, err)
+		return common.DefaultMaxSnapshotsPerVolume, nil
+	}
+
+	// Create Kubernetes clientset
+	k8sClient, err := newK8sClientFromConfig(cfg)
+	if err != nil {
+		// If k8s client creation fails, use default
+		log.Infof("getSnapshotLimitForNamespace: failed to create Kubernetes client, using default: %d. Error: %v",
Error: %v", + common.DefaultMaxSnapshotsPerVolume, err) + return common.DefaultMaxSnapshotsPerVolume, nil + } + + // Get ConfigMap from the namespace + cm, err := k8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, common.ConfigMapCSILimits, metav1.GetOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + // ConfigMap not found, use default + log.Infof("getSnapshotLimitForNamespace: ConfigMap %s not found in namespace %s, using default: %d", + common.ConfigMapCSILimits, namespace, common.DefaultMaxSnapshotsPerVolume) + return common.DefaultMaxSnapshotsPerVolume, nil + } + // Other error occurred + log.Errorf("getSnapshotLimitForNamespace: failed to get ConfigMap %s in namespace %s, err: %v", + common.ConfigMapCSILimits, namespace, err) + return 0, err + } + + // Check if the key exists in ConfigMap data + limitStr, exists := cm.Data[common.ConfigMapKeyMaxSnapshotsPerVolume] + if !exists { + // Key not found in ConfigMap, fail the request + errMsg := fmt.Sprintf("ConfigMap %s exists in namespace %s but missing required key '%s'", + common.ConfigMapCSILimits, namespace, common.ConfigMapKeyMaxSnapshotsPerVolume) + log.Errorf("getSnapshotLimitForNamespace: %s", errMsg) + return 0, errors.New(errMsg) + } + + // Parse the limit value + limit, err := strconv.Atoi(limitStr) + if err != nil || limit < 0 { + // Invalid value, fail the request + errMsg := fmt.Sprintf( + "ConfigMap %s in namespace %s has invalid value '%s' for key '%s': must be a non-negative integer", + common.ConfigMapCSILimits, namespace, limitStr, common.ConfigMapKeyMaxSnapshotsPerVolume) + log.Errorf("getSnapshotLimitForNamespace: %s", errMsg) + return 0, errors.New(errMsg) + } + + // Clamp to absolute max if exceeded + if limit > common.AbsoluteMaxSnapshotsPerVolume { + log.Warnf( + "getSnapshotLimitForNamespace: namespace %s ConfigMap limit %d exceeds absolute max %d, clamping to absolute max", + namespace, limit, common.AbsoluteMaxSnapshotsPerVolume) + return common.AbsoluteMaxSnapshotsPerVolume, nil + } + + log.Infof("getSnapshotLimitForNamespace: namespace %s snapshot limit from ConfigMap: %d", namespace, limit) + return limit, nil +} diff --git a/pkg/csi/service/wcp/controller_helper_test.go b/pkg/csi/service/wcp/controller_helper_test.go index 4eaa2dfcfd..99e073e310 100644 --- a/pkg/csi/service/wcp/controller_helper_test.go +++ b/pkg/csi/service/wcp/controller_helper_test.go @@ -10,8 +10,10 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/fake" + restclient "k8s.io/client-go/rest" k8stesting "k8s.io/client-go/testing" "sigs.k8s.io/vsphere-csi-driver/v3/pkg/common/unittestcommon" + "sigs.k8s.io/vsphere-csi-driver/v3/pkg/csi/service/common" "sigs.k8s.io/vsphere-csi-driver/v3/pkg/csi/service/common/commonco" ) @@ -158,3 +160,215 @@ func newMockPod(name, namespace, nodeName string, volumes []string, }, } } + +func TestGetSnapshotLimitForNamespace(t *testing.T) { + // Save original functions and restore after tests + originalGetConfig := getK8sConfig + originalNewK8sClientFromConfig := newK8sClientFromConfig + defer func() { + getK8sConfig = originalGetConfig + newK8sClientFromConfig = originalNewK8sClientFromConfig + }() + + // Mock getK8sConfig to return a fake config + getK8sConfig = func() (*restclient.Config, error) { + return &restclient.Config{}, nil + } + + t.Run("WhenConfigMapExists_ValidValue", func(t *testing.T) { + // Setup + cm := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: common.ConfigMapCSILimits, + Namespace: "test-namespace", + }, + 
+				common.ConfigMapKeyMaxSnapshotsPerVolume: "5",
+			},
+		}
+		fakeClient := fake.NewSimpleClientset(cm)
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		limit, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.Nil(t, err)
+		assert.Equal(t, 5, limit)
+	})
+
+	t.Run("WhenConfigMapExists_ValueEqualsMax", func(t *testing.T) {
+		// Setup
+		cm := &v1.ConfigMap{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      common.ConfigMapCSILimits,
+				Namespace: "test-namespace",
+			},
+			Data: map[string]string{
+				common.ConfigMapKeyMaxSnapshotsPerVolume: "32",
+			},
+		}
+		fakeClient := fake.NewSimpleClientset(cm)
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		limit, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.Nil(t, err)
+		assert.Equal(t, 32, limit)
+	})
+
+	t.Run("WhenConfigMapExists_ValueExceedsMax", func(t *testing.T) {
+		// Setup
+		cm := &v1.ConfigMap{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      common.ConfigMapCSILimits,
+				Namespace: "test-namespace",
+			},
+			Data: map[string]string{
+				common.ConfigMapKeyMaxSnapshotsPerVolume: "50",
+			},
+		}
+		fakeClient := fake.NewSimpleClientset(cm)
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		limit, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.Nil(t, err)
+		assert.Equal(t, 32, limit) // Should be capped to absolute max
+	})
+
+	t.Run("WhenConfigMapExists_ValueIsZero", func(t *testing.T) {
+		// Setup
+		cm := &v1.ConfigMap{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      common.ConfigMapCSILimits,
+				Namespace: "test-namespace",
+			},
+			Data: map[string]string{
+				common.ConfigMapKeyMaxSnapshotsPerVolume: "0",
+			},
+		}
+		fakeClient := fake.NewSimpleClientset(cm)
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		limit, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.Nil(t, err)
+		assert.Equal(t, 0, limit) // 0 means block all snapshots
+	})
+
+	t.Run("WhenConfigMapExists_ValueIsNegative", func(t *testing.T) {
+		// Setup
+		cm := &v1.ConfigMap{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      common.ConfigMapCSILimits,
+				Namespace: "test-namespace",
+			},
+			Data: map[string]string{
+				common.ConfigMapKeyMaxSnapshotsPerVolume: "-5",
+			},
+		}
+		fakeClient := fake.NewSimpleClientset(cm)
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		_, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.NotNil(t, err)
+		assert.Contains(t, err.Error(), "invalid value")
+		assert.Contains(t, err.Error(), "must be a non-negative integer")
+	})
+
+	t.Run("WhenConfigMapExists_InvalidFormat", func(t *testing.T) {
+		// Setup
+		cm := &v1.ConfigMap{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      common.ConfigMapCSILimits,
+				Namespace: "test-namespace",
+			},
+			Data: map[string]string{
+				common.ConfigMapKeyMaxSnapshotsPerVolume: "abc",
+			},
+		}
+		fakeClient := fake.NewSimpleClientset(cm)
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		_, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.NotNil(t, err)
+		assert.Contains(t, err.Error(), "invalid value")
+		assert.Contains(t, err.Error(), "must be a non-negative integer")
+	})
+
+	t.Run("WhenConfigMapExists_MissingKey", func(t *testing.T) {
+		// Setup
+		cm := &v1.ConfigMap{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      common.ConfigMapCSILimits,
+				Namespace: "test-namespace",
+			},
+			Data: map[string]string{}, // ConfigMap exists but key is missing
+		}
+		fakeClient := fake.NewSimpleClientset(cm)
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		_, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.NotNil(t, err)
+		assert.Contains(t, err.Error(), "missing required key")
+	})
+
+	t.Run("WhenConfigMapNotFound", func(t *testing.T) {
+		// Setup
+		fakeClient := fake.NewSimpleClientset() // Empty clientset
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return fakeClient, nil
+		}
+
+		// Execute
+		limit, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify
+		assert.Nil(t, err)
+		assert.Equal(t, common.DefaultMaxSnapshotsPerVolume, limit) // Should return default (4)
+	})
+
+	t.Run("WhenK8sClientCreationFails", func(t *testing.T) {
+		// Setup
+		newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+			return nil, assert.AnError
+		}
+
+		// Execute
+		limit, err := getSnapshotLimitForNamespace(context.Background(), "test-namespace")
+
+		// Verify - should return default instead of error
+		assert.Nil(t, err)
+		assert.Equal(t, common.DefaultMaxSnapshotsPerVolume, limit)
+	})
+}
diff --git a/pkg/csi/service/wcp/controller_test.go b/pkg/csi/service/wcp/controller_test.go
index 679ab5df90..d9b77d82c5 100644
--- a/pkg/csi/service/wcp/controller_test.go
+++ b/pkg/csi/service/wcp/controller_test.go
@@ -22,6 +22,7 @@ import (
 	"strings"
 	"sync"
 	"testing"
+	"time"
 
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -137,6 +138,9 @@ func getControllerTest(t *testing.T) *controllerTest {
 	c := &controller{
 		manager:     manager,
 		topologyMgr: topologyMgr,
+		snapshotLockMgr: &snapshotLockManager{
+			locks: make(map[string]*volumeLock),
+		},
 	}
 
 	controllerTestInstance = &controllerTest{
@@ -673,6 +677,9 @@ func TestWCPCreateDeleteSnapshot(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -748,6 +755,9 @@ func TestListSnapshots(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -867,6 +877,9 @@ func TestListSnapshotsOnSpecificVolume(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -987,6 +1000,9 @@ func TestListSnapshotsWithToken(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -1113,6 +1129,9 @@ func TestListSnapshotsOnSpecificVolumeAndSnapshot(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -1270,6 +1289,9 @@ func TestCreateVolumeFromSnapshot(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -1436,6 +1458,9 @@ func TestWCPDeleteVolumeWithSnapshots(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -1541,6 +1566,9 @@ func TestWCPExpandVolumeWithSnapshots(t *testing.T) {
 	reqCreateSnapshot := &csi.CreateSnapshotRequest{
 		SourceVolumeId: volID,
 		Name:           "snapshot-" + uuid.New().String(),
+		Parameters: map[string]string{
+			common.VolumeSnapshotNamespaceKey: "default",
+		},
 	}
 
 	respCreateSnapshot, err := ct.controller.CreateSnapshot(ctx, reqCreateSnapshot)
@@ -1858,3 +1886,210 @@ func TestControllerModifyVolume(t *testing.T) {
 		}
 	})
 }
+
+func TestSnapshotLockManager(t *testing.T) {
+	ct := getControllerTest(t)
+
+	t.Run("AcquireAndRelease_SingleVolume", func(t *testing.T) {
+		volumeID := "test-volume-1"
+
+		// Acquire lock
+		ct.controller.acquireSnapshotLock(ctx, volumeID)
+
+		// Verify lock exists and refCount = 1
+		ct.controller.snapshotLockMgr.mapMutex.RLock()
+		vLock, exists := ct.controller.snapshotLockMgr.locks[volumeID]
+		ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+		if !exists {
+			t.Fatal("Lock should exist after acquire")
+		}
+		if vLock.refCount != 1 {
+			t.Fatalf("Expected refCount=1, got %d", vLock.refCount)
+		}
+
+		// Release lock
+		ct.controller.releaseSnapshotLock(ctx, volumeID)
+
+		// Verify lock is removed
+		ct.controller.snapshotLockMgr.mapMutex.RLock()
+		_, exists = ct.controller.snapshotLockMgr.locks[volumeID]
+		ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+		if exists {
+			t.Fatal("Lock should be removed after release")
+		}
+	})
+
+	t.Run("AcquireMultipleTimes_SameVolume", func(t *testing.T) {
+		volumeID := "test-volume-2"
+
+		// Use two goroutines to acquire the lock
+		var wg sync.WaitGroup
+		acquired := make(chan bool, 2)
+
+		// First goroutine acquires and holds the lock
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			ct.controller.acquireSnapshotLock(ctx, volumeID)
+			acquired <- true
+			// Hold lock briefly
+			time.Sleep(100 * time.Millisecond)
+			ct.controller.releaseSnapshotLock(ctx, volumeID)
+		}()
+
+		// Wait for first goroutine to acquire
+		<-acquired
+
+		// Verify refCount = 1, lock exists
+		ct.controller.snapshotLockMgr.mapMutex.RLock()
+		vLock, exists := ct.controller.snapshotLockMgr.locks[volumeID]
+		refCount1 := vLock.refCount
+		ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+		if !exists {
+			t.Fatal("Lock should exist")
+		}
+		if refCount1 != 1 {
+			t.Fatalf("Expected refCount=1, got %d", refCount1)
t.Fatalf("Expected refCount=1, got %d", refCount1) + } + + // Second goroutine tries to acquire (will be blocked) + wg.Add(1) + go func() { + defer wg.Done() + ct.controller.acquireSnapshotLock(ctx, volumeID) + acquired <- true + ct.controller.releaseSnapshotLock(ctx, volumeID) + }() + + // Give second goroutine time to start waiting + time.Sleep(50 * time.Millisecond) + + // Verify refCount increased to 2 (second goroutine is waiting) + ct.controller.snapshotLockMgr.mapMutex.RLock() + vLock, exists = ct.controller.snapshotLockMgr.locks[volumeID] + refCount2 := vLock.refCount + ct.controller.snapshotLockMgr.mapMutex.RUnlock() + + if !exists { + t.Fatal("Lock should exist") + } + if refCount2 != 2 { + t.Fatalf("Expected refCount=2, got %d", refCount2) + } + + // Wait for both goroutines to complete + wg.Wait() + + // Verify lock is removed after both releases + ct.controller.snapshotLockMgr.mapMutex.RLock() + _, exists = ct.controller.snapshotLockMgr.locks[volumeID] + ct.controller.snapshotLockMgr.mapMutex.RUnlock() + + if exists { + t.Fatal("Lock should be removed after all releases") + } + }) + + t.Run("AcquireRelease_MultipleVolumes", func(t *testing.T) { + volume1 := "test-volume-3" + volume2 := "test-volume-4" + volume3 := "test-volume-5" + + // Acquire locks for all volumes + ct.controller.acquireSnapshotLock(ctx, volume1) + ct.controller.acquireSnapshotLock(ctx, volume2) + ct.controller.acquireSnapshotLock(ctx, volume3) + + // Verify all locks exist + ct.controller.snapshotLockMgr.mapMutex.RLock() + count := len(ct.controller.snapshotLockMgr.locks) + ct.controller.snapshotLockMgr.mapMutex.RUnlock() + + if count < 3 { + t.Fatalf("Expected at least 3 locks, got %d", count) + } + + // Release volume2 + ct.controller.releaseSnapshotLock(ctx, volume2) + + // Verify volume2 removed, others remain + ct.controller.snapshotLockMgr.mapMutex.RLock() + _, exists1 := ct.controller.snapshotLockMgr.locks[volume1] + _, exists2 := ct.controller.snapshotLockMgr.locks[volume2] + _, exists3 := ct.controller.snapshotLockMgr.locks[volume3] + ct.controller.snapshotLockMgr.mapMutex.RUnlock() + + if !exists1 { + t.Fatal("Volume1 lock should still exist") + } + if exists2 { + t.Fatal("Volume2 lock should be removed") + } + if !exists3 { + t.Fatal("Volume3 lock should still exist") + } + + // Cleanup + ct.controller.releaseSnapshotLock(ctx, volume1) + ct.controller.releaseSnapshotLock(ctx, volume3) + }) + + t.Run("ConcurrentAccess_SameVolume", func(t *testing.T) { + volumeID := "test-volume-concurrent" + counter := 0 + var wg sync.WaitGroup + goroutines := 5 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ct.controller.acquireSnapshotLock(ctx, volumeID) + defer ct.controller.releaseSnapshotLock(ctx, volumeID) + + // Critical section - increment counter + temp := counter + // Simulate some work + for j := 0; j < 100; j++ { + _ = j * 2 + } + counter = temp + 1 + }() + } + + wg.Wait() + + // Verify counter = goroutines (no race condition) + if counter != goroutines { + t.Fatalf("Expected counter=%d, got %d (race condition detected)", goroutines, counter) + } + + // Verify lock is cleaned up + ct.controller.snapshotLockMgr.mapMutex.RLock() + _, exists := ct.controller.snapshotLockMgr.locks[volumeID] + ct.controller.snapshotLockMgr.mapMutex.RUnlock() + + if exists { + t.Fatal("Lock should be cleaned up after all goroutines finish") + } + }) + + t.Run("ReleaseNonExistentLock", func(t *testing.T) { + volumeID := "non-existent-volume" + + // This should not panic + 
+
+		// Verify no lock was created
+		ct.controller.snapshotLockMgr.mapMutex.RLock()
+		_, exists := ct.controller.snapshotLockMgr.locks[volumeID]
+		ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+		if exists {
+			t.Fatal("Lock should not exist after releasing non-existent lock")
+		}
+	})
+}
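
Reviewer note (not part of the patch): the sketch below shows one way to seed the per-namespace cns-csi-limits ConfigMap that getSnapshotLimitForNamespace reads, so the new CreateSnapshot limit check can be exercised end to end. The package name, the helper name applySnapshotLimit, the kubeconfig handling, and the example limit of 8 are assumptions for illustration only; the snippet uses only standard client-go calls.

package snapshotlimits

import (
	"context"
	"strconv"

	v1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/tools/clientcmd"
)

// applySnapshotLimit creates the cns-csi-limits ConfigMap in the given namespace so that
// the WCP CSI controller enforces the requested per-volume snapshot limit. Per the patch,
// values above 32 (AbsoluteMaxSnapshotsPerVolume) are clamped. Hypothetical helper for
// illustration; it is not part of the driver code.
func applySnapshotLimit(ctx context.Context, kubeconfigPath, namespace string, limit int) error {
	cfg, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath)
	if err != nil {
		return err
	}
	client, err := kubernetes.NewForConfig(cfg)
	if err != nil {
		return err
	}
	cm := &v1.ConfigMap{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "cns-csi-limits", // common.ConfigMapCSILimits
			Namespace: namespace,
		},
		Data: map[string]string{
			// common.ConfigMapKeyMaxSnapshotsPerVolume
			"max-snapshots-per-volume": strconv.Itoa(limit),
		},
	}
	// Create the ConfigMap; switch to Update if it already exists.
	_, err = client.CoreV1().ConfigMaps(namespace).Create(ctx, cm, metav1.CreateOptions{})
	return err
}

For example, applySnapshotLimit(ctx, kubeconfig, "demo-namespace", 8) would allow at most 8 snapshots per block volume in that namespace before CreateSnapshot starts returning FailedPrecondition; with no ConfigMap present, the driver falls back to DefaultMaxSnapshotsPerVolume (4).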