Skip to content

Commit 60408cb

Browse files
committed
internal/symbols, cmd/vulnreport: move logic to add symbols to reports
Move the logic (but don't modify it) to populate symbols to its own file in internal/symbols. Add some basic tests that confirm the current behavior (which will likely be tweaked in follow up CLs). Change-Id: I10593154c343adb680733ebd66a4dd97abed2c43 Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/560778 Reviewed-by: Maceo Thompson <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 2584928 commit 60408cb

File tree

3 files changed

+204
-62
lines changed

3 files changed

+204
-62
lines changed

cmd/vulnreport/symbols.go

+2-62
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@ package main
66

77
import (
88
"context"
9-
"fmt"
10-
"path/filepath"
11-
"strings"
129

1310
"golang.org/x/vulndb/cmd/vulnreport/log"
14-
"golang.org/x/vulndb/internal/osv"
1511
"golang.org/x/vulndb/internal/report"
1612
"golang.org/x/vulndb/internal/symbols"
1713
)
@@ -34,66 +30,10 @@ func (s *symbolsCmd) run(ctx context.Context, filename string) (err error) {
3430
if err != nil {
3531
return err
3632
}
37-
var defaultFixes []string
3833

39-
for _, ref := range r.References {
40-
if ref.Type == osv.ReferenceTypeFix {
41-
if filepath.Base(filepath.Dir(ref.URL)) == "commit" {
42-
defaultFixes = append(defaultFixes, ref.URL)
43-
}
44-
}
45-
}
46-
if len(defaultFixes) == 0 {
47-
return fmt.Errorf("no commit fix links found")
34+
if err = symbols.Populate(r, log.Err); err != nil {
35+
return err
4836
}
4937

50-
for _, mod := range r.Modules {
51-
hasFixLink := mod.FixLink != ""
52-
if hasFixLink {
53-
defaultFixes = append(defaultFixes, mod.FixLink)
54-
}
55-
numFixedSymbols := make([]int, len(defaultFixes))
56-
for i, fixLink := range defaultFixes {
57-
fixHash := filepath.Base(fixLink)
58-
fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
59-
pkgsToSymbols, err := symbols.Patched(mod.Module, fixRepo, fixHash)
60-
if err != nil {
61-
log.Err(err)
62-
continue
63-
}
64-
packages := mod.AllPackages()
65-
for pkg, symbols := range pkgsToSymbols {
66-
if _, exists := packages[pkg]; exists {
67-
packages[pkg].Symbols = append(packages[pkg].Symbols, symbols...)
68-
} else {
69-
mod.Packages = append(mod.Packages, &report.Package{
70-
Package: pkg,
71-
Symbols: symbols,
72-
})
73-
}
74-
numFixedSymbols[i] += len(symbols)
75-
}
76-
}
77-
// if the module's link field wasn't already populated, populate it with
78-
// the link that results in the most symbols
79-
if hasFixLink {
80-
defaultFixes = defaultFixes[:len(defaultFixes)-1]
81-
} else {
82-
mod.FixLink = defaultFixes[indexMax(numFixedSymbols)]
83-
}
84-
}
8538
return r.Write(filename)
8639
}
87-
88-
// indexMax takes a slice of nonempty ints and returns the index of the maximum value
89-
func indexMax(s []int) (index int) {
90-
maxVal := s[0]
91-
index = 0
92-
for i, val := range s {
93-
if val > maxVal {
94-
maxVal = val
95-
index = i
96-
}
97-
}
98-
return index
99-
}

internal/symbols/populate.go

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package symbols
6+
7+
import (
8+
"fmt"
9+
"path/filepath"
10+
"strings"
11+
12+
"golang.org/x/vulndb/internal/osv"
13+
"golang.org/x/vulndb/internal/report"
14+
)
15+
16+
// Populate attempts to populate the report with symbols derived
17+
// from the patch link(s) in the report.
18+
func Populate(r *report.Report, errln logln) error {
19+
return populate(r, Patched, errln)
20+
}
21+
22+
func populate(r *report.Report, patched func(string, string, string) (map[string][]string, error), errln logln) error {
23+
var defaultFixes []string
24+
25+
for _, ref := range r.References {
26+
if ref.Type == osv.ReferenceTypeFix {
27+
if filepath.Base(filepath.Dir(ref.URL)) == "commit" {
28+
defaultFixes = append(defaultFixes, ref.URL)
29+
}
30+
}
31+
}
32+
if len(defaultFixes) == 0 {
33+
return fmt.Errorf("no commit fix links found")
34+
}
35+
36+
for _, mod := range r.Modules {
37+
hasFixLink := mod.FixLink != ""
38+
if hasFixLink {
39+
defaultFixes = append(defaultFixes, mod.FixLink)
40+
}
41+
numFixedSymbols := make([]int, len(defaultFixes))
42+
for i, fixLink := range defaultFixes {
43+
fixHash := filepath.Base(fixLink)
44+
fixRepo := strings.TrimSuffix(fixLink, "/commit/"+fixHash)
45+
pkgsToSymbols, err := patched(mod.Module, fixRepo, fixHash)
46+
if err != nil {
47+
errln(err)
48+
continue
49+
}
50+
packages := mod.AllPackages()
51+
for pkg, symbols := range pkgsToSymbols {
52+
if _, exists := packages[pkg]; exists {
53+
packages[pkg].Symbols = append(packages[pkg].Symbols, symbols...)
54+
} else {
55+
mod.Packages = append(mod.Packages, &report.Package{
56+
Package: pkg,
57+
Symbols: symbols,
58+
})
59+
}
60+
numFixedSymbols[i] += len(symbols)
61+
}
62+
}
63+
// if the module's link field wasn't already populated, populate it with
64+
// the link that results in the most symbols
65+
if hasFixLink {
66+
defaultFixes = defaultFixes[:len(defaultFixes)-1]
67+
} else {
68+
mod.FixLink = defaultFixes[indexMax(numFixedSymbols)]
69+
}
70+
}
71+
72+
return nil
73+
}
74+
75+
// indexMax takes a slice of nonempty ints and returns the index of the maximum value
76+
func indexMax(s []int) (index int) {
77+
maxVal := s[0]
78+
index = 0
79+
for i, val := range s {
80+
if val > maxVal {
81+
maxVal = val
82+
index = i
83+
}
84+
}
85+
return index
86+
}

internal/symbols/populate_test.go

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package symbols
6+
7+
import (
8+
"fmt"
9+
"testing"
10+
11+
"github.com/google/go-cmp/cmp"
12+
"golang.org/x/vulndb/internal/osv"
13+
"golang.org/x/vulndb/internal/report"
14+
)
15+
16+
func TestPopulate(t *testing.T) {
17+
for _, tc := range []struct {
18+
name string
19+
input *report.Report
20+
want *report.Report
21+
}{
22+
{
23+
name: "basic",
24+
input: &report.Report{
25+
Modules: []*report.Module{{
26+
Module: "example.com/module",
27+
}},
28+
References: []*report.Reference{{
29+
Type: osv.ReferenceTypeFix,
30+
URL: "https://example.com/module/commit/1234",
31+
}},
32+
},
33+
want: &report.Report{
34+
Modules: []*report.Module{{
35+
Module: "example.com/module",
36+
Packages: []*report.Package{{
37+
Package: "example.com/module/package",
38+
Symbols: []string{"symbol1", "symbol2"},
39+
}},
40+
FixLink: "https://example.com/module/commit/1234",
41+
}},
42+
References: []*report.Reference{
43+
{
44+
Type: osv.ReferenceTypeFix,
45+
URL: "https://example.com/module/commit/1234",
46+
},
47+
},
48+
},
49+
},
50+
{
51+
name: "multiple_fixes",
52+
input: &report.Report{
53+
Modules: []*report.Module{{
54+
Module: "example.com/module",
55+
}},
56+
References: []*report.Reference{
57+
{
58+
Type: osv.ReferenceTypeFix,
59+
URL: "https://example.com/module/commit/1234",
60+
},
61+
{
62+
Type: osv.ReferenceTypeFix,
63+
URL: "https://example.com/module/commit/5678",
64+
},
65+
},
66+
},
67+
want: &report.Report{
68+
Modules: []*report.Module{{
69+
Module: "example.com/module",
70+
Packages: []*report.Package{{
71+
Package: "example.com/module/package",
72+
// We don't yet dedupe the symbols.
73+
Symbols: []string{"symbol1", "symbol2", "symbol1", "symbol2", "symbol3"},
74+
}},
75+
// This commit is picked because it results in the most symbols.
76+
FixLink: "https://example.com/module/commit/5678",
77+
}},
78+
References: []*report.Reference{
79+
{
80+
Type: osv.ReferenceTypeFix,
81+
URL: "https://example.com/module/commit/1234",
82+
},
83+
{
84+
Type: osv.ReferenceTypeFix,
85+
URL: "https://example.com/module/commit/5678",
86+
},
87+
},
88+
},
89+
},
90+
} {
91+
t.Run(tc.name, func(t *testing.T) {
92+
discardLog := func(...any) {}
93+
if err := populate(tc.input, patchedFake, discardLog); err != nil {
94+
t.Fatal(err)
95+
}
96+
got := tc.input
97+
if diff := cmp.Diff(tc.want, got); diff != "" {
98+
t.Errorf("populate mismatch (-want, +got):\n%s", diff)
99+
}
100+
})
101+
}
102+
}
103+
104+
func patchedFake(module string, repo string, hash string) (map[string][]string, error) {
105+
if module == "example.com/module" && repo == "https://example.com/module" && hash == "1234" {
106+
return map[string][]string{
107+
"example.com/module/package": {"symbol1", "symbol2"},
108+
}, nil
109+
}
110+
if module == "example.com/module" && repo == "https://example.com/module" && hash == "5678" {
111+
return map[string][]string{
112+
"example.com/module/package": {"symbol1", "symbol2", "symbol3"},
113+
}, nil
114+
}
115+
return nil, fmt.Errorf("unrecognized inputs: module=%s,repo=%s,hash=%s", module, repo, hash)
116+
}

0 commit comments

Comments
 (0)