Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 80 additions & 43 deletions nvd/nvd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nvd

import (
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
Expand All @@ -23,6 +24,11 @@ const (
maxResultsPerPage = 2000
retryAfter = 30 * time.Second
apiKeyEnvName = "NVD_API_KEY"

// statusOriginTimeout is a non-standard Cloudflare status code (524) returned
// when the origin server (NVD backend behind the proxy) does not respond in time.
// There is no constant for it in net/http.
statusOriginTimeout = 524
)

type Option func(*Updater)
Expand Down Expand Up @@ -135,68 +141,99 @@ func (u Updater) saveEntry(interval TimeInterval, startIndex int) (int, error) {

func (u Updater) fetchEntry(url string) (Entry, error) {
var entry Entry
r, err := u.fetchURL(url)
body, err := u.fetchURL(url)
if err != nil {
return Entry{}, xerrors.Errorf("unable to fetch: %w", err)
} else if r == nil {
} else if body == nil {
return Entry{}, xerrors.Errorf("unable to get entry from %q", url)
}
defer r.Close()

if err = json.NewDecoder(r).Decode(&entry); err != nil {
if err = json.Unmarshal(body, &entry); err != nil {
return Entry{}, xerrors.Errorf("unable to decode response for %q: %w", url, err)
}
return entry, nil
}

func (u Updater) fetchURL(url string) (io.ReadCloser, error) {
// errRetry signals fetchURL to retry. retryAfter is the server-mandated
// minimum wait (rate limit); zero means apply the caller's backoff.
type errRetry struct{ retryAfter time.Duration }

func (errRetry) Error() string { return "retryable" }

func (u Updater) fetchURL(url string) ([]byte, error) {
var c http.Client
for i := 0; i <= u.retry; i++ {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, xerrors.Errorf("unable to build request for %q: %w", url, err)
}
if u.apiKey != "" {
req.Header.Set("apiKey", u.apiKey)
}

resp, err := c.Do(req)
if err != nil {
slog.Error("Response error. Try to get the entry again.", slog.String("error", err.Error()))
continue
}
switch resp.StatusCode {
case http.StatusForbidden, http.StatusTooManyRequests:
slog.Error("NVD rate limit. Wait to gain access.")
ra := u.retryAfter
// NVD returns the `Retry-After` header as 0.
// But if they start setting a non-zero value, we can use that duration.
if headerRetry := resp.Header.Get("Retry-After"); headerRetry != "0" {
hRetry, err := time.ParseDuration(headerRetry)
if err == nil {
ra = hRetry
}
body, err := u.doRequest(&c, url)
var re errRetry
switch {
case err == nil:
return body, nil
case errors.As(err, &re):
wait := re.retryAfter
if wait == 0 {
wait = time.Duration(i) * time.Second
}
// NVD limits:
// Without API key: 5 requests / 30 seconds window
// With API key: 50 requests / 30 seconds window
time.Sleep(ra)
continue
case http.StatusServiceUnavailable, http.StatusRequestTimeout, http.StatusBadGateway, http.StatusGatewayTimeout:
slog.Error("NVD API is unstable. Try to fetch URL again.", slog.String("status_code", resp.Status))
// NVD API works unstable
time.Sleep(time.Duration(i) * time.Second)
continue
case http.StatusOK:
return resp.Body, nil
time.Sleep(wait)
default:
return nil, xerrors.Errorf("unexpected status code: %s", resp.Status)
return nil, err
}

}
return nil, xerrors.Errorf("unable to fetch url. Retry limit exceeded.")
}

// doRequest performs a single NVD request attempt and closes the response body
// before returning. It returns the response body on success, or errRetry to
// signal that fetchURL should retry the request.
func (u Updater) doRequest(c *http.Client, url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, xerrors.Errorf("unable to build request for %q: %w", url, err)
}
if u.apiKey != "" {
req.Header.Set("apiKey", u.apiKey)
}

resp, err := c.Do(req)
if err != nil {
slog.Error("Response error. Try to get the entry again.", slog.String("error", err.Error()))
return nil, errRetry{}
}
defer resp.Body.Close()

switch resp.StatusCode {
case http.StatusForbidden, http.StatusTooManyRequests:
slog.Error("NVD rate limit. Wait to gain access.")
ra := u.retryAfter
// NVD returns the `Retry-After` header as 0.
// But if they start setting a non-zero value, we can use that duration.
if headerRetry := resp.Header.Get("Retry-After"); headerRetry != "0" {
if hRetry, err := time.ParseDuration(headerRetry); err == nil {
ra = hRetry
}
}
// NVD limits:
// Without API key: 5 requests / 30 seconds window
// With API key: 50 requests / 30 seconds window
return nil, errRetry{retryAfter: ra}
case http.StatusServiceUnavailable, http.StatusRequestTimeout, http.StatusBadGateway, http.StatusGatewayTimeout, statusOriginTimeout:
slog.Error("NVD API is unstable. Try to fetch URL again.", slog.String("status_code", resp.Status))
// NVD API works unstable
return nil, errRetry{}
case http.StatusOK:
// Read the body here so that a transient error while reading the response
// (e.g. HTTP/2 `INTERNAL_ERROR` when NVD aborts the stream mid-body) is
// retried instead of failing the whole run.
body, err := io.ReadAll(resp.Body)
if err != nil {
slog.Error("Failed to read NVD response body. Try to fetch URL again.", slog.String("error", err.Error()))
return nil, errRetry{}
}
return body, nil
default:
return nil, xerrors.Errorf("unexpected status code: %s", resp.Status)
}
}

// TimeIntervals returns time intervals for NVD API
// NVD API doesn't allow to get more than 120 days per request.
// So we need to split the time range into intervals.
Expand Down
106 changes: 106 additions & 0 deletions nvd/nvd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@ func TestUpdate(t *testing.T) {
"last_updated.json",
},
},
{
name: "happy path 1 page after reconnect (524)",
maxResultsPerPage: 10,
wantApiKey: "test_api_key",
retry: 1,
lastUpdatedTime: time.Date(2023, 11, 26, 0, 0, 0, 0, time.UTC),
fakeTimeNow: time.Date(2023, 11, 28, 0, 0, 0, 0, time.UTC),
respFiles: map[string]string{
"resultsPerPage=1&startIndex=0": "testdata/fixtures/rootResp.json",
"resultsPerPage=10&startIndex=0": "testdata/fixtures/respPageFull.json",
},
respStatus: 524,
wantFiles: []string{
filepath.Join("api", "2020", "CVE-2020-8167.json"),
filepath.Join("api", "2021", "CVE-2021-22903.json"),
filepath.Join("api", "2021", "CVE-2021-3881.json"),
"last_updated.json",
},
},
{
name: "happy path 2 pages",
maxResultsPerPage: 2,
Expand Down Expand Up @@ -153,6 +172,15 @@ func TestUpdate(t *testing.T) {
respStatus: 504,
wantError: "unable to fetch url",
},
{
name: "524 response",
maxResultsPerPage: 10,
wantApiKey: "test_api_key",
lastUpdatedTime: time.Date(2023, 11, 26, 0, 0, 0, 0, time.UTC),
fakeTimeNow: time.Date(2023, 11, 28, 0, 0, 0, 0, time.UTC),
respStatus: 524,
wantError: "unable to fetch url",
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -229,6 +257,84 @@ func TestUpdate(t *testing.T) {
}
}

// TestUpdate_RetryOnBodyReadError verifies that a transient error while reading
// the response body (e.g. HTTP/2 `INTERNAL_ERROR` when NVD aborts the stream
// mid-body) is retried instead of failing the whole run.
func TestUpdate_RetryOnBodyReadError(t *testing.T) {
t.Setenv("NVD_API_KEY", "test_api_key")

tmpDir := t.TempDir()
savedVulnListDir := utils.VulnListDir()
utils.SetVulnListDir(tmpDir)
defer utils.SetVulnListDir(savedVulnListDir)

err := utils.SetLastUpdatedDate("api", time.Date(2023, 11, 26, 0, 0, 0, 0, time.UTC))
require.NoError(t, err)

respFiles := map[string]string{
"resultsPerPage=1&startIndex=0": "testdata/fixtures/rootResp.json",
"resultsPerPage=10&startIndex=0": "testdata/fixtures/respPageFull.json",
}

var bodyBrokenOnce bool
mux := http.NewServeMux()
mux.HandleFunc("/", func(resp http.ResponseWriter, req *http.Request) {
if !bodyBrokenOnce {
bodyBrokenOnce = true
// Declare a larger body than we actually send, then abort the
// connection so the client fails with an unexpected EOF on read.
hj, ok := resp.(http.Hijacker)
require.True(t, ok)
conn, _, err := hj.Hijack()
require.NoError(t, err)
_, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 1024\r\n\r\npartial"))
conn.Close()
return
}

var filePath string
for params, path := range respFiles {
if strings.Contains(req.URL.String(), params) {
filePath = path
break
}
}
if filePath == "" {
t.Errorf("response files doesn't exist for %q", req.URL.String())
}

b, err := os.ReadFile(filePath)
require.NoError(t, err)

_, err = resp.Write(b)
require.NoError(t, err)
})
ts := httptest.NewServer(mux)
defer ts.Close()

u := nvd.NewUpdater(nvd.WithBaseURL(ts.URL), nvd.WithMaxResultsPerPage(10),
nvd.WithRetry(1), nvd.WithLastModEndDate(time.Date(2023, 11, 28, 0, 0, 0, 0, time.UTC)),
nvd.WithRetryAfter(1*time.Second))
err = u.Update()
require.NoError(t, err)

wantFiles := []string{
filepath.Join("api", "2020", "CVE-2020-8167.json"),
filepath.Join("api", "2021", "CVE-2021-22903.json"),
filepath.Join("api", "2021", "CVE-2021-3881.json"),
"last_updated.json",
}
for _, wantFile := range wantFiles {
got, err := os.ReadFile(filepath.Join(tmpDir, wantFile))
require.NoError(t, err)

want, err := os.ReadFile(filepath.Join("testdata", "golden", wantFile))
require.NoError(t, err)

require.JSONEq(t, string(want), string(got))
}
}

func TestTimeIntervals(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading