Skip to content

Commit 8bbb3d4

Browse files
Add map lock on Publish/Unpublish volume
1 parent ed23ee0 commit 8bbb3d4

File tree

3 files changed

+141
-12
lines changed

3 files changed

+141
-12
lines changed

pkg/csi_driver/node.go

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ type nodeServer struct {
5252
volumeLocks *util.VolumeLocks
5353
k8sClients clientset.Interface
5454
limiter rate.Limiter
55-
volumeStateStore map[string]*volumeState
56-
}
57-
58-
type volumeState struct {
59-
bucketAccessCheckPassed bool
55+
volumeStateStore *util.VolumeStateStore
6056
}
6157

6258
func newNodeServer(driver *GCSDriver, mounter mount.Interface) csi.NodeServer {
@@ -67,7 +63,7 @@ func newNodeServer(driver *GCSDriver, mounter mount.Interface) csi.NodeServer {
6763
volumeLocks: util.NewVolumeLocks(),
6864
k8sClients: driver.config.K8sClients,
6965
limiter: *rate.NewLimiter(rate.Every(time.Second), 10),
70-
volumeStateStore: make(map[string]*volumeState),
66+
volumeStateStore: util.NewVolumeStateStore(),
7167
}
7268
}
7369

@@ -113,13 +109,13 @@ func (s *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublish
113109
if bucketName != "_" && !skipBucketAccessCheck {
114110
// Use target path as an volume identifier because it corresponds to Pods and volumes.
115111
// Pods may belong to different namespaces and would need their own access check.
116-
vs, ok := s.volumeStateStore[targetPath]
112+
vs, ok := s.volumeStateStore.Load(targetPath)
117113
if !ok {
118-
s.volumeStateStore[targetPath] = &volumeState{}
119-
vs = s.volumeStateStore[targetPath]
114+
s.volumeStateStore.Store(targetPath, &util.VolumeState{})
115+
vs, _ = s.volumeStateStore.Load(targetPath)
120116
}
121117

122-
if !vs.bucketAccessCheckPassed {
118+
if !vs.BucketAccessCheckPassed {
123119
storageService, err := s.prepareStorageService(ctx, vc)
124120
if err != nil {
125121
return nil, status.Errorf(codes.Unauthenticated, "failed to prepare storage service: %v", err)
@@ -130,7 +126,7 @@ func (s *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublish
130126
return nil, status.Errorf(storage.ParseErrCode(err), "failed to get GCS bucket %q: %v", bucketName, err)
131127
}
132128

133-
vs.bucketAccessCheckPassed = true
129+
vs.BucketAccessCheckPassed = true
134130
}
135131
}
136132

@@ -239,7 +235,7 @@ func (s *nodeServer) NodeUnpublishVolume(_ context.Context, req *csi.NodeUnpubli
239235
s.driver.config.MetricsManager.UnregisterMetricsCollector(targetPath)
240236
}
241237

242-
delete(s.volumeStateStore, targetPath)
238+
s.volumeStateStore.Delete(targetPath)
243239

244240
// Check if the target path is already mounted
245241
if mounted, err := s.isDirMounted(targetPath); mounted || err != nil {

pkg/csi_driver/node_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ import (
2121
"errors"
2222
"os"
2323
"path/filepath"
24+
"sync"
2425
"testing"
2526

2627
csi "github.com/container-storage-interface/spec/lib/go/csi"
2728
"github.com/google/go-cmp/cmp"
2829
"github.com/google/go-cmp/cmp/cmpopts"
2930
"github.com/googlecloudplatform/gcs-fuse-csi-driver/pkg/cloud_provider/storage"
31+
"github.com/googlecloudplatform/gcs-fuse-csi-driver/pkg/util"
3032
"golang.org/x/net/context"
3133
"google.golang.org/grpc/codes"
3234
"google.golang.org/grpc/status"
@@ -247,6 +249,38 @@ func TestNodeUnpublishVolume(t *testing.T) {
247249
}
248250
}
249251

252+
// Attempts to modify a shared map, causes fatal error: concurrent map writes unless using alter with lock.
253+
func TestConcurrentMapWrites(t *testing.T) {
254+
t.Parallel()
255+
// Create a shared map for the test
256+
sharedVSS := make(map[string]*volumeState)
257+
ns := &nodeServer{volumeLocks: util.NewVolumeLocks()}
258+
// Number of concurrent writes we want to simulate
259+
numWrites := 200
260+
261+
// Use a WaitGroup to wait for all goroutines to finish
262+
var wg sync.WaitGroup
263+
264+
// Run concurrent tests manually using goroutines
265+
for i := range numWrites {
266+
wg.Add(1)
267+
go func() {
268+
defer wg.Done()
269+
// Simulate concurrent writes to the shared map
270+
// sharedVSS[string(rune(i))] = &volumeState{}
271+
ns.alterWithLock("vss", func() { sharedVSS[string(rune(i))] = &volumeState{} })
272+
}()
273+
}
274+
275+
// Wait for all goroutines to finish
276+
wg.Wait()
277+
278+
// validate correct number of writes occurred
279+
if len(sharedVSS) != numWrites {
280+
t.Errorf("expected %d entries in the map, got %d", numWrites, len(sharedVSS))
281+
}
282+
}
283+
250284
func validateMountPoint(t *testing.T, name string, fm *mount.FakeMounter, e *mount.MountPoint) {
251285
t.Helper()
252286
if e == nil {
@@ -279,3 +313,31 @@ func validateMountPoint(t *testing.T, name string, fm *mount.FakeMounter, e *mou
279313
t.Errorf("unexpected options args (-got, +want)\n%s", diff)
280314
}
281315
}
316+
func TestConcurrentMapWrites(t *testing.T) {
317+
t.Parallel()
318+
// Create a shared map for the test
319+
sharedVSS := util.NewVolumeStateStore()
320+
// Number of concurrent writes we want to simulate
321+
numWrites := 2000
322+
323+
// Use a WaitGroup to wait for all goroutines to finish
324+
var wg sync.WaitGroup
325+
326+
// Run concurrent tests manually using goroutines
327+
for i := range numWrites {
328+
wg.Add(1)
329+
go func() {
330+
defer wg.Done()
331+
// Simulate concurrent writes to the shared map
332+
sharedVSS.Store(string(rune(i)), &util.VolumeState{})
333+
}()
334+
}
335+
336+
// Wait for all goroutines to finish
337+
wg.Wait()
338+
339+
// validate correct number of writes occurred
340+
if int(sharedVSS.Size()) != numWrites {
341+
t.Errorf("expected %d entries in the map, got %d", numWrites, sharedVSS.Size())
342+
}
343+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
Copyright 2018 The Kubernetes Authors.
3+
Copyright 2022 Google LLC
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
https://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
*/
17+
18+
package util
19+
20+
import (
21+
"sync"
22+
"sync/atomic"
23+
)
24+
25+
// VolumeStateStore provides a thread-safe map from volume identifier to its
// per-volume state, with a size counter maintained alongside the map so
// callers can read the entry count in O(1).
type VolumeStateStore struct {
	store sync.Map // concurrent-safe map: volumeID (string) -> *VolumeState
	size  int32    // entry count; mutated atomically, read via Size()
}

// VolumeState records per-volume bookkeeping shared across CSI node calls.
type VolumeState struct {
	// BucketAccessCheckPassed is set once the GCS bucket access check has
	// succeeded for this volume so later publishes can skip the check.
	BucketAccessCheckPassed bool
}

// NewVolumeStateStore initializes an empty volume state store.
func NewVolumeStateStore() *VolumeStateStore {
	return &VolumeStateStore{}
}

// Size returns the current number of entries in the store.
func (vss *VolumeStateStore) Size() int32 {
	// Atomic load pairs with the atomic adds in Store/Delete; a plain read
	// here would be a data race under the race detector.
	return atomic.LoadInt32(&vss.size)
}

// Store adds or updates a volume state.
func (vss *VolumeStateStore) Store(volumeID string, state *VolumeState) {
	// Swap reports whether the key already existed, so the counter is only
	// incremented for genuinely new entries — overwrites must not inflate it.
	if _, loaded := vss.store.Swap(volumeID, state); !loaded {
		atomic.AddInt32(&vss.size, 1)
	}
}

// Load retrieves the state of a volume, reporting whether it was present.
func (vss *VolumeStateStore) Load(volumeID string) (*VolumeState, bool) {
	value, ok := vss.store.Load(volumeID)
	if !ok {
		return nil, false
	}
	return value.(*VolumeState), true // only *VolumeState is ever stored
}

// Delete removes a volume from the store. Deleting an absent key is a no-op.
func (vss *VolumeStateStore) Delete(volumeID string) {
	// LoadAndDelete reports whether the key was present, so the counter
	// never goes negative when Delete is called for a key never stored.
	if _, loaded := vss.store.LoadAndDelete(volumeID); loaded {
		atomic.AddInt32(&vss.size, -1)
	}
}

// Range iterates over all stored volumes, stopping early if f returns false.
func (vss *VolumeStateStore) Range(f func(volumeID string, state *VolumeState) bool) {
	vss.store.Range(func(key, value any) bool {
		return f(key.(string), value.(*VolumeState))
	})
}

0 commit comments

Comments
 (0)