Skip to content

Commit a5fe089

Browse files
authored
check if batch query is empty, otherwise skip (#2252)
Signed-off-by: pxp928 <[email protected]>
1 parent fad3dd5 commit a5fe089

File tree

1 file changed

+68
-32
lines changed

1 file changed

+68
-32
lines changed

pkg/assembler/backends/ent/backend/search.go

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ func (b *EntBackend) FindPackagesThatNeedScanning(ctx context.Context, queryType
199199
func (b *EntBackend) QueryPackagesListForScan(ctx context.Context, pkgIDs []string, after *string, first *int) (*model.PackageConnection, error) {
200200
var afterCursor *entgql.Cursor[uuid.UUID]
201201

202+
// if empty pkgIDs slice is passed in return nothing
203+
if len(pkgIDs) == 0 {
204+
return nil, nil
205+
}
206+
202207
if after != nil {
203208
globalID := fromGlobalID(*after)
204209
if globalID.nodeType != packageversion.Table {
@@ -244,14 +249,17 @@ func (b *EntBackend) QueryPackagesListForScan(ctx context.Context, pkgIDs []stri
244249
shortenedQueryList = append(shortenedQueryList, convertedID)
245250
}
246251
}
247-
var queryErr error
248-
pkgConn, queryErr = b.client.PackageVersion.Query().
249-
Where(packageversion.IDIn(shortenedQueryList...)).
250-
WithName(func(q *ent.PackageNameQuery) {}).
251-
Paginate(ctx, afterCursor, first, nil, nil)
252252

253-
if queryErr != nil {
254-
return nil, fmt.Errorf("failed package query based on package IDs that need scanning with error: %w", queryErr)
253+
if len(shortenedQueryList) > 0 {
254+
var queryErr error
255+
pkgConn, queryErr = b.client.PackageVersion.Query().
256+
Where(packageversion.IDIn(shortenedQueryList...)).
257+
WithName(func(q *ent.PackageNameQuery) {}).
258+
Paginate(ctx, afterCursor, first, nil, nil)
259+
260+
if queryErr != nil {
261+
return nil, fmt.Errorf("failed package query based on package IDs that need scanning with error: %w", queryErr)
262+
}
255263
}
256264

257265
// if not found return nil
@@ -295,6 +303,11 @@ func constructPkgConn(pkgConn *ent.PackageVersionConnection, totalCount int, has
295303

296304
func (b *EntBackend) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []string) ([]*model.CertifyVuln, error) {
297305

306+
// if empty pkgIDs slice is passed in return nothing
307+
if len(pkgIDs) == 0 {
308+
return nil, nil
309+
}
310+
298311
// static ID for noVuln that is generated from type = novuln and vulnid = ""
299312
// this is generated via:
300313
vulnIDs := helpers.GetKey[*model.VulnerabilityInputSpec, helpers.VulnIds](&model.VulnerabilityInputSpec{Type: NoVuln, VulnerabilityID: ""}, helpers.VulnServerKey)
@@ -343,25 +356,32 @@ func (b *EntBackend) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []st
343356
))
344357
}
345358

346-
certVulnConn, err := b.client.CertifyVuln.Query().
347-
Where(certifyvuln.Or(predicates...)).
348-
WithVulnerability(func(query *ent.VulnerabilityIDQuery) {}).
349-
WithPackage(func(q *ent.PackageVersionQuery) {
350-
q.WithName(func(q *ent.PackageNameQuery) {})
351-
}).All(ctx)
352-
353-
if err != nil {
354-
return nil, fmt.Errorf("failed certifyVuln query based on package IDs with error: %w", err)
355-
}
356359
var collectedCertVuln []*model.CertifyVuln
357-
for _, entCertVuln := range certVulnConn {
358-
collectedCertVuln = append(collectedCertVuln, toModelCertifyVulnerability(entCertVuln))
360+
if len(predicates) > 0 {
361+
certVulnConn, err := b.client.CertifyVuln.Query().
362+
Where(certifyvuln.Or(predicates...)).
363+
WithVulnerability(func(query *ent.VulnerabilityIDQuery) {}).
364+
WithPackage(func(q *ent.PackageVersionQuery) {
365+
q.WithName(func(q *ent.PackageNameQuery) {})
366+
}).All(ctx)
367+
368+
if err != nil {
369+
return nil, fmt.Errorf("failed certifyVuln query based on package IDs with error: %w", err)
370+
}
371+
for _, entCertVuln := range certVulnConn {
372+
collectedCertVuln = append(collectedCertVuln, toModelCertifyVulnerability(entCertVuln))
373+
}
359374
}
360375
return collectedCertVuln, nil
361376
}
362377

363378
func (b *EntBackend) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []string) ([]*model.CertifyLegal, error) {
364379

380+
// if empty pkgIDs slice is passed in return nothing
381+
if len(pkgIDs) == 0 {
382+
return nil, nil
383+
}
384+
365385
var queryList []uuid.UUID
366386

367387
for _, id := range pkgIDs {
@@ -408,26 +428,36 @@ func (b *EntBackend) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []s
408428
))
409429
}
410430

411-
certLegalConn, err := b.client.CertifyLegal.Query().
412-
Where(certifylegal.Or(predicates...)).
413-
WithPackage(func(q *ent.PackageVersionQuery) {
414-
q.WithName(func(q *ent.PackageNameQuery) {})
415-
}).
416-
WithDeclaredLicenses().
417-
WithDiscoveredLicenses().All(ctx)
431+
var collectedCertLegal []*model.CertifyLegal
418432

419-
if err != nil {
420-
return nil, fmt.Errorf("failed certifyLegal query based on package IDs with error: %w", err)
421-
}
433+
if len(predicates) > 0 {
434+
certLegalConn, err := b.client.CertifyLegal.Query().
435+
Where(certifylegal.Or(predicates...)).
436+
WithPackage(func(q *ent.PackageVersionQuery) {
437+
q.WithName(func(q *ent.PackageNameQuery) {})
438+
}).
439+
WithDeclaredLicenses().
440+
WithDiscoveredLicenses().All(ctx)
422441

423-
var collectedCertLegal []*model.CertifyLegal
424-
for _, entCertLegal := range certLegalConn {
425-
collectedCertLegal = append(collectedCertLegal, toModelCertifyLegal(entCertLegal))
442+
if err != nil {
443+
return nil, fmt.Errorf("failed certifyLegal query based on package IDs with error: %w", err)
444+
}
445+
446+
for _, entCertLegal := range certLegalConn {
447+
collectedCertLegal = append(collectedCertLegal, toModelCertifyLegal(entCertLegal))
448+
}
426449
}
450+
427451
return collectedCertLegal, nil
428452
}
429453

430454
func (b *EntBackend) BatchQuerySubjectPkgDependency(ctx context.Context, pkgIDs []string) ([]*model.IsDependency, error) {
455+
456+
// if empty pkgIDs slice is passed in return nothing
457+
if len(pkgIDs) == 0 {
458+
return nil, nil
459+
}
460+
431461
var queryList []uuid.UUID
432462

433463
for _, id := range pkgIDs {
@@ -456,6 +486,12 @@ func (b *EntBackend) BatchQuerySubjectPkgDependency(ctx context.Context, pkgIDs
456486
}
457487

458488
func (b *EntBackend) BatchQueryDepPkgDependency(ctx context.Context, pkgIDs []string) ([]*model.IsDependency, error) {
489+
490+
// if empty pkgIDs slice is passed in return nothing
491+
if len(pkgIDs) == 0 {
492+
return nil, nil
493+
}
494+
459495
var queryList []uuid.UUID
460496

461497
for _, id := range pkgIDs {

0 commit comments

Comments
 (0)