diff --git a/explorer/scan/local_scanner.go b/explorer/scan/local_scanner.go index eda4c6fe20..a023d26c13 100644 --- a/explorer/scan/local_scanner.go +++ b/explorer/scan/local_scanner.go @@ -316,18 +316,23 @@ func (s *LocalScanner) distributeJob(ctx context.Context, job *Job, upstream *up // attach the asset details to the assets list for i := range batch { asset := batch[i].Asset + runtime := batch[i].Runtime log.Debug().Str("asset", asset.Name).Strs("platform-ids", asset.PlatformIds).Msg("update asset") for _, platformId := range asset.PlatformIds { if details, ok := platformAssetMapping[platformId]; ok { asset.Mrn = details.AssetMrn asset.Url = details.Url } + if runtime != nil { + runtime.AssetUpdated(asset) + } } } } else { // ensure we have non-empty asset MRNs for i := range batch { asset := batch[i].Asset + runtime := batch[i].Runtime if asset.Mrn == "" { randID := "//" + explorer.SERVICE_NAME + "/" + explorer.MRN_RESOURCE_ASSET + "/" + ksuid.New().String() x, err := mrn.NewMRN(randID) @@ -335,6 +340,10 @@ func (s *LocalScanner) distributeJob(ctx context.Context, job *Job, upstream *up return nil, multierr.Wrap(err, "failed to generate a random asset MRN") } asset.Mrn = x.String() + // update the asset in the runtime as well + if runtime != nil { + runtime.AssetUpdated(asset) + } } } } diff --git a/providers-sdk/v1/recording/asset_recording.go b/providers-sdk/v1/recording/asset_recording.go index 1a19c2f7f2..55bfb56156 100644 --- a/providers-sdk/v1/recording/asset_recording.go +++ b/providers-sdk/v1/recording/asset_recording.go @@ -10,7 +10,7 @@ import ( ) type Asset struct { - Asset assetInfo `json:"asset"` + Asset *assetInfo `json:"asset"` Connections []connection `json:"connections"` Resources []Resource `json:"resources"` @@ -38,7 +38,7 @@ type connection struct { ProviderID string `json:"provider"` Connector string `json:"connector"` Version string `json:"version"` - id uint32 `json:"-"` + Id uint32 `json:"id"` } type Resource struct { diff --git a/providers-sdk/v1/recording/asset_recording_test.go b/providers-sdk/v1/recording/asset_recording_test.go new file mode 100644 index 0000000000..9c1e1423b2 --- /dev/null +++ b/providers-sdk/v1/recording/asset_recording_test.go @@ -0,0 +1,87 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package recording + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.mondoo.com/cnquery/v12/providers-sdk/v1/inventory" +) + +func TestAssetRecording(t *testing.T) { + t.Run("add asset by id only", func(t *testing.T) { + rec := &recording{ + assets: map[uint32]*Asset{}, + Assets: []*Asset{}, + } + + asset := &inventory.Asset{ + Id: "asset-id", + PlatformIds: []string{}, + Platform: &inventory.Platform{}, + } + conf := &inventory.Config{ + Type: "local", + } + rec.EnsureAsset(asset, "provider", 1, conf) + + require.Len(t, rec.assets, 1) + require.Len(t, rec.Assets, 1) + require.Len(t, rec.Assets[0].connections, 1) + require.Len(t, rec.Assets[0].Resources, 0) + a := rec.Assets[0].Asset + require.Equal(t, "asset-id", a.ID) + require.Equal(t, []string{}, a.PlatformIDs) + + // re-add again, should be idempotent + asset.PlatformIds = []string{"platform-id"} + rec.EnsureAsset(asset, "provider", 1, conf) + require.Len(t, rec.assets, 1) + require.Len(t, rec.Assets, 1) + require.Len(t, rec.Assets[0].connections, 1) + require.Len(t, rec.Assets[0].Resources, 0) + a = rec.Assets[0].Asset + require.Equal(t, "asset-id", a.ID) + require.Equal(t, []string{"platform-id"}, a.PlatformIDs) + }) + + t.Run("add asset by id and mrn", func(t *testing.T) { + rec := &recording{ + assets: map[uint32]*Asset{}, + Assets: []*Asset{}, + } + + asset := &inventory.Asset{ + Id: "asset-id", + PlatformIds: []string{"platform-id"}, + Platform: &inventory.Platform{}, + } + conf := &inventory.Config{ + Type: "local", + } + rec.EnsureAsset(asset, "provider", 1, conf) + + require.Len(t, rec.assets, 1) + require.Len(t, rec.Assets, 1) + require.Len(t, rec.Assets[0].connections, 1) + require.Len(t, rec.Assets[0].Resources, 0) + a := rec.Assets[0].Asset + require.Equal(t, "asset-id", a.ID) + require.Equal(t, []string{"platform-id"}, a.PlatformIDs) + + // re-add again by MRN, ensure nothing gets duplicated + asset.Mrn = "asset-mrn" + asset.PlatformIds = []string{"platform-id", "asset-mrn"} + rec.EnsureAsset(asset, "provider", 1, conf) + require.Len(t, rec.assets, 1) + require.Len(t, rec.Assets, 1) + require.Len(t, rec.Assets[0].connections, 1) + require.Len(t, rec.Assets[0].Resources, 0) + a = rec.Assets[0].Asset + + require.Equal(t, "asset-mrn", a.ID) + require.Equal(t, []string{"platform-id", "asset-mrn"}, a.PlatformIDs) + }) +} diff --git a/providers-sdk/v1/recording/recording.go b/providers-sdk/v1/recording/recording.go index 8db499dc3b..12f071adfd 100644 --- a/providers-sdk/v1/recording/recording.go +++ b/providers-sdk/v1/recording/recording.go @@ -6,7 +6,9 @@ package recording import ( "encoding/json" "errors" + "fmt" "os" + "slices" "github.com/rs/zerolog/log" "go.mondoo.com/cnquery/v12/llx" @@ -34,7 +36,7 @@ func NewAssetRecording(asset *inventory.Asset) *Asset { if id == "" && asset.Platform != nil { id = asset.Platform.Title } - ai := assetInfo{ + ai := &assetInfo{ ID: id, Name: asset.Name, PlatformIDs: asset.PlatformIds, @@ -90,9 +92,9 @@ func (n *readOnly) Save() error { func (n *readOnly) EnsureAsset(asset *inventory.Asset, provider string, connectionID uint32, conf *inventory.Config) { // For read-only recordings we are still loading from file, so that means // we are severely lacking connection IDs. - found, _ := n.findAssetConnID(asset) - if found != -1 { - n.assets[connectionID] = n.Assets[found] + existing := n.getExistingAsset(asset) + if existing != nil { + n.assets[connectionID] = existing } } @@ -206,13 +208,12 @@ func (r *recording) refreshCache() { asset := r.Assets[i] asset.RefreshCache() - for i := range asset.Connections { - conn := asset.Connections[i] + for _, conn := range asset.Connections { // only connection ID's != 0 are valid IDs. We get lots of 0 when we // initially load this object, so we won't know yet which asset belongs // to which connection. - if conn.id != 0 { - r.assets[conn.id] = asset + if conn.Id != 0 { + r.assets[conn.Id] = asset } } } @@ -318,51 +319,44 @@ func (r *recording) finalize() { } } -func (r *recording) findAssetConnID(asset *inventory.Asset) (int, string) { +func (r *recording) getExistingAsset(asset *inventory.Asset) *Asset { if asset.Mrn != "" || asset.Id != "" { - for i := range r.Assets { - id := r.Assets[i].Asset.ID + for _, existing := range r.Assets { + id := existing.Asset.ID if id == "" { continue } if id == asset.Mrn { - return i, asset.Mrn + return existing } if id == asset.Id { - return i, asset.Id + return existing } - for _, pidExisting := range r.Assets[i].Asset.PlatformIDs { - for _, pid := range asset.PlatformIds { - if pidExisting == pid { - return i, asset.Mrn - } + for _, pidExisting := range existing.Asset.PlatformIDs { + if slices.Contains(asset.PlatformIds, pidExisting) { + return existing } } } } - return -1, "" + return nil } func (r *recording) EnsureAsset(asset *inventory.Asset, providerID string, connectionID uint32, conf *inventory.Config) { - found, _ := r.findAssetConnID(asset) - if asset.Platform == nil { log.Warn().Msg("cannot store asset in recording, asset has no platform") return } - if found == -1 { - id := asset.Mrn - if id == "" { - id = asset.Id - } - if id == "" { - id = asset.Platform.Title - } - - r.Assets = append(r.Assets, &Asset{ - Asset: assetInfo{ - ID: id, + id := getAssetIdForRecording(asset) + if id == "" { + log.Debug().Msg("cannot store asset in recording, asset has no id or mrn") + return + } + recordingAsset := r.getExistingAsset(asset) + if recordingAsset == nil { + recordingAsset = &Asset{ + Asset: &assetInfo{ PlatformIDs: asset.PlatformIds, Name: asset.Platform.Name, Arch: asset.Platform.Arch, @@ -376,30 +370,23 @@ func (r *recording) EnsureAsset(asset *inventory.Asset, providerID string, conne }, connections: map[string]*connection{}, resources: map[string]*Resource{}, - }) - found = len(r.Assets) - 1 - } - - // An asset is sometimes added to the recording, before it has its MRN assigned. - // This method may be called again, after the MRN has been assigned. In that - // case we make sure that the asset ID matches the MRN. - // TODO: figure out a better position to do this, both for the MRN and IDs - assetObj := r.Assets[found] - if asset.Mrn != "" { - assetObj.Asset.ID = asset.Mrn + } + r.Assets = append(r.Assets, recordingAsset) } + // always update the id for the asset, sometimes we get assets by id and then + // they get updated with an MRN attached + recordingAsset.Asset.ID = getAssetIdForRecording(asset) if len(asset.PlatformIds) != 0 { - assetObj.Asset.PlatformIDs = asset.PlatformIds + recordingAsset.Asset.PlatformIDs = asset.PlatformIds } - url := conf.ToUrl() - assetObj.connections[url] = &connection{ - Url: url, + recordingAsset.connections[fmt.Sprintf("%d", conf.Id)] = &connection{ + Url: conf.ToUrl(), ProviderID: providerID, Connector: conf.Type, - id: conf.Id, + Id: conf.Id, } - r.assets[connectionID] = assetObj + r.assets[connectionID] = recordingAsset } func (r *recording) AddData(connectionID uint32, resource string, id string, field string, data *llx.RawData) { @@ -501,7 +488,7 @@ func (r *recording) SetAssetRecording(id uint32, reco *Asset) { // This method makes sure the asset metadata is always included in the data // dump of a recording -func ensureAssetMetadata(resources map[string]*Resource, asset assetInfo) { +func ensureAssetMetadata(resources map[string]*Resource, asset *assetInfo) { id := "asset\x00" existing, ok := resources[id] if !ok { @@ -576,3 +563,11 @@ func RawDataArgsToResultArgs(args map[string]*llx.RawData) (map[string]*llx.Resu return all, err.Deduplicate() } + +func getAssetIdForRecording(asset *inventory.Asset) string { + id := asset.Mrn + if id == "" { + id = asset.Id + } + return id +} diff --git a/providers-sdk/v1/recording/upstream_recording.go b/providers-sdk/v1/recording/upstream_recording.go index 506dac42e8..32d30bdf11 100644 --- a/providers-sdk/v1/recording/upstream_recording.go +++ b/providers-sdk/v1/recording/upstream_recording.go @@ -106,7 +106,7 @@ func (n *Upstream) GetData(connectionID uint32, resource string, id string, fiel return res, ok } -func (n *Upstream) GetResource(connectionID uint32, resource string, id string) (map[string]*llx.RawData, bool) { +func (n *Upstream) GetResource(_ uint32, resource string, id string) (map[string]*llx.RawData, bool) { n.lock.Lock() defer n.lock.Unlock() diff --git a/providers/runtime.go b/providers/runtime.go index 05f4669787..aeecbbbab2 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -290,7 +290,12 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error { } func (r *Runtime) AssetUpdated(asset *inventory.Asset) { - r.Recording().EnsureAsset(r.Provider.Connection.Asset, r.Provider.Instance.ID, r.Provider.Connection.Id, asset.Connections[0]) + rec := r.Recording() + rec.EnsureAsset( + r.Provider.Connection.Asset, + r.Provider.Instance.ID, + r.Provider.Connection.Id, + asset.Connections[0]) } func (r *Runtime) CreateResource(name string, args map[string]*llx.Primitive) (llx.Resource, error) {