Skip to content

Commit aa114cc

Browse files
🐛 Use sync map to prevent concurrent writes. (#6053)
Signed-off-by: Preslav <preslav@mondoo.com>
1 parent b11d223 commit aa114cc

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

providers-sdk/v1/recording/asset_recording_test.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ import (
88

99
"github.com/stretchr/testify/require"
1010
"go.mondoo.com/cnquery/v12/providers-sdk/v1/inventory"
11+
"go.mondoo.com/cnquery/v12/utils/syncx"
1112
)
1213

1314
func TestAssetRecording(t *testing.T) {
1415
t.Run("add asset by id only", func(t *testing.T) {
1516
rec := &recording{
16-
assets: map[uint32]*Asset{},
17+
assets: syncx.Map[*Asset]{},
1718
Assets: []*Asset{},
1819
}
1920

@@ -27,13 +28,12 @@ func TestAssetRecording(t *testing.T) {
2728
}
2829
rec.EnsureAsset(asset, "provider", 1, conf)
2930

30-
require.Len(t, rec.assets, 0)
3131
require.Len(t, rec.Assets, 0)
3232
})
3333

3434
t.Run("add asset by mrn", func(t *testing.T) {
3535
rec := &recording{
36-
assets: map[uint32]*Asset{},
36+
assets: syncx.Map[*Asset]{},
3737
Assets: []*Asset{},
3838
}
3939

@@ -47,7 +47,6 @@ func TestAssetRecording(t *testing.T) {
4747
}
4848
rec.EnsureAsset(asset, "provider", 1, conf)
4949

50-
require.Len(t, rec.assets, 1)
5150
require.Len(t, rec.Assets, 1)
5251
require.Len(t, rec.Assets[0].connections, 1)
5352
require.Len(t, rec.Assets[0].Resources, 0)
@@ -59,7 +58,6 @@ func TestAssetRecording(t *testing.T) {
5958
asset.Mrn = "asset-mrn"
6059
asset.PlatformIds = []string{"platform-id", "asset-mrn"}
6160
rec.EnsureAsset(asset, "provider", 1, conf)
62-
require.Len(t, rec.assets, 1)
6361
require.Len(t, rec.Assets, 1)
6462
require.Len(t, rec.Assets[0].connections, 1)
6563
require.Len(t, rec.Assets[0].Resources, 0)
@@ -71,7 +69,7 @@ func TestAssetRecording(t *testing.T) {
7169

7270
t.Run("add asset by platform id and mrn", func(t *testing.T) {
7371
rec := &recording{
74-
assets: map[uint32]*Asset{},
72+
assets: syncx.Map[*Asset]{},
7573
Assets: []*Asset{},
7674
}
7775

@@ -85,7 +83,6 @@ func TestAssetRecording(t *testing.T) {
8583
}
8684
rec.EnsureAsset(asset, "provider", 1, conf)
8785

88-
require.Len(t, rec.assets, 1)
8986
require.Len(t, rec.Assets, 1)
9087
require.Len(t, rec.Assets[0].connections, 1)
9188
require.Len(t, rec.Assets[0].Resources, 0)
@@ -96,7 +93,6 @@ func TestAssetRecording(t *testing.T) {
9693
// re-add again by platform id, ensure nothing gets duplicated
9794
asset.Mrn = ""
9895
rec.EnsureAsset(asset, "provider", 1, conf)
99-
require.Len(t, rec.assets, 1)
10096
require.Len(t, rec.Assets, 1)
10197
require.Len(t, rec.Assets[0].connections, 1)
10298
require.Len(t, rec.Assets[0].Resources, 0)
@@ -108,7 +104,6 @@ func TestAssetRecording(t *testing.T) {
108104
// re-add again by mrn, ensure nothing gets duplicated
109105
asset.Mrn = "asset-mrn"
110106
rec.EnsureAsset(asset, "provider", 1, conf)
111-
require.Len(t, rec.assets, 1)
112107
require.Len(t, rec.Assets, 1)
113108
require.Len(t, rec.Assets[0].connections, 1)
114109
require.Len(t, rec.Assets[0].Resources, 0)

providers-sdk/v1/recording/recording.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ import (
1515
"go.mondoo.com/cnquery/v12/providers-sdk/v1/inventory"
1616
"go.mondoo.com/cnquery/v12/types"
1717
"go.mondoo.com/cnquery/v12/utils/multierr"
18+
"go.mondoo.com/cnquery/v12/utils/syncx"
1819
)
1920

2021
type recording struct {
2122
Assets []*Asset `json:"assets"`
2223
Path string `json:"-"`
2324
// assets is used for fast connection to asset lookup
24-
assets map[uint32]*Asset `json:"-"`
25+
assets syncx.Map[*Asset] `json:"-"`
2526
prettyPrintJSON bool `json:"-"`
2627
// this mode is used when we use the recording layer for data,
2728
// but not for storing it on disk
@@ -94,7 +95,7 @@ func (n *readOnly) EnsureAsset(asset *inventory.Asset, provider string, connecti
9495
// we are severely lacking connection IDs.
9596
existing := n.getExistingAsset(asset)
9697
if existing != nil {
97-
n.assets[connectionID] = existing
98+
n.assets.Set(fmt.Sprintf("%d", connectionID), existing)
9899
}
99100
}
100101

@@ -142,7 +143,7 @@ func NewWithFile(path string, opts RecordingOptions) (llx.Recording, error) {
142143
Path: path,
143144
prettyPrintJSON: opts.PrettyPrintJSON,
144145
doNotSave: opts.DoNotSave,
145-
assets: map[uint32]*Asset{},
146+
assets: syncx.Map[*Asset]{},
146147
}
147148
res.refreshCache() // only for initialization
148149
return res, nil
@@ -203,7 +204,7 @@ func (r *recording) Save() error {
203204
}
204205

205206
func (r *recording) refreshCache() {
206-
r.assets = make(map[uint32]*Asset, len(r.Assets))
207+
r.assets = syncx.Map[*Asset]{}
207208
for i := range r.Assets {
208209
asset := r.Assets[i]
209210
asset.RefreshCache()
@@ -213,7 +214,7 @@ func (r *recording) refreshCache() {
213214
// initially load this object, so we won't know yet which asset belongs
214215
// to which connection.
215216
if conn.Id != 0 {
216-
r.assets[conn.Id] = asset
217+
r.assets.Set(fmt.Sprintf("%d", conn.Id), asset)
217218
}
218219
}
219220
}
@@ -379,11 +380,11 @@ func (r *recording) EnsureAsset(asset *inventory.Asset, providerID string, conne
379380
Connector: conf.Type,
380381
Id: conf.Id,
381382
}
382-
r.assets[connectionID] = recordingAsset
383+
r.assets.Set(fmt.Sprintf("%d", conf.Id), recordingAsset)
383384
}
384385

385386
func (r *recording) AddData(connectionID uint32, resource string, id string, field string, data *llx.RawData) {
386-
asset, ok := r.assets[connectionID]
387+
asset, ok := r.assets.Get(fmt.Sprintf("%d", connectionID))
387388
if !ok {
388389
return
389390
}
@@ -404,7 +405,7 @@ func (r *recording) AddData(connectionID uint32, resource string, id string, fie
404405
}
405406

406407
func (r *recording) GetData(connectionID uint32, resource string, id string, field string) (*llx.RawData, bool) {
407-
asset, ok := r.assets[connectionID]
408+
asset, ok := r.assets.Get(fmt.Sprintf("%d", connectionID))
408409
if !ok {
409410
return nil, false
410411
}
@@ -427,7 +428,7 @@ func (r *recording) GetData(connectionID uint32, resource string, id string, fie
427428
}
428429

429430
func (r *recording) GetResource(connectionID uint32, resource string, id string) (map[string]*llx.RawData, bool) {
430-
asset, ok := r.assets[connectionID]
431+
asset, ok := r.assets.Get(fmt.Sprintf("%d", connectionID))
431432
if !ok {
432433
return nil, false
433434
}
@@ -475,7 +476,7 @@ func (r *recording) GetAssetRecordings() []*Asset {
475476
}
476477

477478
func (r *recording) SetAssetRecording(id uint32, reco *Asset) {
478-
r.assets[id] = reco
479+
r.assets.Set(fmt.Sprintf("%d", id), reco)
479480
}
480481

481482
// This method makes sure the asset metadata is always included in the data

0 commit comments

Comments
 (0)