Skip to content

Commit ce0856c

Browse files
committed
fix: move internal/ to pkg/, verify checksum, minor fixes
1 parent 6f1f4e3 commit ce0856c

File tree

7 files changed

+287
-44
lines changed

7 files changed

+287
-44
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ ls export-contents/*.schema.txt
171171
- **Schema only** - Table schemas are exported, but **no actual table data** is included
172172
- **Read-only** - The tool only reads data and makes no modifications to your cluster
173173
- **Local export** - All data is written to a local zip file under your control
174+
- **Verified updates** - `workload-exporter update` verifies the SHA256 checksum of the downloaded binary against the checksums published with each GitHub release before installing
174175

175176
## Requirements
176177

cmd/update.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"os"
77
"time"
88

9-
"github.com/cockroachlabs/workload-exporter/internal/update"
9+
"github.com/cockroachlabs/workload-exporter/pkg/update"
1010
"github.com/spf13/cobra"
1111
)
1212

@@ -27,7 +27,9 @@ func newUpdateCmd() *cobra.Command {
2727
return runUpdateCheck(cmd.Context())
2828
}
2929

30-
return update.PerformUpdate(cmd.Context(), os.Stdout, Version)
30+
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Minute)
31+
defer cancel()
32+
return update.PerformUpdate(ctx, os.Stdout, Version)
3133
},
3234
}
3335

@@ -53,7 +55,7 @@ func runUpdateCheck(ctx context.Context) error {
5355
return fmt.Errorf("update check failed: %w", err)
5456
}
5557

56-
if result.TagName == Version {
58+
if !update.IsNewer(result.TagName, Version) {
5759
fmt.Println("up to date")
5860
} else {
5961
fmt.Printf("\nnew version available: %s (current: %s)\n", result.TagName, Version)

docs/DEVELOPMENT.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This guide is for developers who want to build, test, or contribute to the workl
44

55
## Prerequisites
66

7-
- Go 1.18 or later
7+
- Go 1.23 or later
88
- Git
99

1010
## Building from Source
@@ -64,11 +64,17 @@ workload-exporter/
6464
├── cmd/ # CLI commands
6565
│ ├── root.go # Root command
6666
│ ├── export.go # Export command
67+
│ ├── update.go # Update command
6768
│ └── version.go # Version command
6869
├── pkg/
69-
│ └── export/ # Core export functionality
70-
│ ├── exporter.go # Main exporter logic
71-
│ └── exporter_test.go # Unit tests
70+
│ ├── export/ # Core export functionality
71+
│ │ ├── exporter.go # Main exporter logic
72+
│ │ └── exporter_test.go # Unit tests
73+
│ └── update/ # Self-update functionality
74+
│ ├── update.go # Update check, caching, and semver comparison
75+
│ ├── selfupdate.go # Binary download, checksum verification, and replacement
76+
│ ├── update_test.go # Unit tests for update checking
77+
│ └── selfupdate_test.go # Unit tests for self-update
7278
├── docs/ # Documentation
7379
├── Makefile # Build automation
7480
└── go.mod # Go dependencies
Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66
"bytes"
77
"compress/gzip"
88
"context"
9+
"crypto/sha256"
10+
"encoding/hex"
11+
"errors"
912
"fmt"
1013
"io"
1114
"net/http"
@@ -17,6 +20,11 @@ import (
1720
"time"
1821
)
1922

23+
const (
24+
maxBinarySize = 100 * 1024 * 1024 // 100 MB guard against runaway downloads
25+
maxChecksumSize = 1 << 20 // 1 MB, well above any real checksums.txt
26+
)
27+
2028
// UpdateDeps holds the external dependencies for PerformUpdate.
2129
// Tests inject fakes; production uses defaultUpdateDeps().
2230
type UpdateDeps struct {
@@ -42,7 +50,9 @@ func defaultUpdateDeps(version string) UpdateDeps {
4250
CheckLatest: func(ctx context.Context) (*ReleaseInfo, error) {
4351
return Check(ctx, version)
4452
},
45-
Download: defaultDownload,
53+
Download: func(ctx context.Context, v string) ([]byte, error) {
54+
return defaultDownload(ctx, http.DefaultClient, v)
55+
},
4656
RunVersion: defaultRunVersion,
4757
CurrentVersion: version,
4858
}
@@ -74,30 +84,64 @@ func assetDownloadURL(version string) string {
7484
return fmt.Sprintf("https://github.com/%s/releases/download/%s/%s", repo, version, assetName(version))
7585
}
7686

77-
func defaultDownload(ctx context.Context, version string) ([]byte, error) {
78-
downloadURL := assetDownloadURL(version)
79-
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
87+
func defaultDownload(ctx context.Context, client HTTPDoer, version string) ([]byte, error) {
88+
// Download the release archive.
89+
archiveData, err := fetchURL(ctx, client, assetDownloadURL(version), maxBinarySize)
90+
if err != nil {
91+
return nil, fmt.Errorf("downloading asset: %w", err)
92+
}
93+
94+
// Download checksums.txt and verify the archive before extracting.
95+
checksumURL := fmt.Sprintf("https://github.com/%s/releases/download/%s/checksums.txt", repo, version)
96+
checksumData, err := fetchURL(ctx, client, checksumURL, maxChecksumSize)
8097
if err != nil {
98+
return nil, fmt.Errorf("downloading checksums: %w", err)
99+
}
100+
if err := verifyChecksum(archiveData, assetName(version), checksumData); err != nil {
81101
return nil, err
82102
}
83103

84-
resp, err := http.DefaultClient.Do(req)
104+
// Extract binary from archive.
105+
return extractBinary(archiveData, version)
106+
}
107+
108+
// fetchURL performs a GET request and returns the response body, bounded by limit bytes.
109+
func fetchURL(ctx context.Context, client HTTPDoer, url string, limit int64) ([]byte, error) {
110+
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
85111
if err != nil {
86-
return nil, fmt.Errorf("downloading asset: %w", err)
112+
return nil, err
113+
}
114+
resp, err := client.Do(req)
115+
if err != nil {
116+
return nil, err
87117
}
88118
defer func() { _ = resp.Body.Close() }()
89-
90119
if resp.StatusCode != http.StatusOK {
91-
return nil, fmt.Errorf("asset download returned %d", resp.StatusCode)
120+
return nil, fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
92121
}
122+
return io.ReadAll(io.LimitReader(resp.Body, limit))
123+
}
93124

94-
archiveData, err := io.ReadAll(resp.Body)
95-
if err != nil {
96-
return nil, fmt.Errorf("reading asset: %w", err)
125+
// verifyChecksum checks the SHA256 of data against the entry for filename in
126+
// a checksums.txt file (sha256sum format: "<hash> <filename>").
127+
func verifyChecksum(data []byte, filename string, checksumFile []byte) error {
128+
for _, line := range strings.Split(string(checksumFile), "\n") {
129+
line = strings.TrimSpace(line)
130+
if line == "" {
131+
continue
132+
}
133+
parts := strings.Fields(line)
134+
if len(parts) != 2 || parts[1] != filename {
135+
continue
136+
}
137+
sum := sha256.Sum256(data)
138+
actual := hex.EncodeToString(sum[:])
139+
if actual != parts[0] {
140+
return fmt.Errorf("checksum mismatch for %s: got %s, want %s", filename, actual, parts[0])
141+
}
142+
return nil
97143
}
98-
99-
// Extract binary from archive.
100-
return extractBinary(archiveData, version)
144+
return fmt.Errorf("no checksum entry found for %s in checksums.txt", filename)
101145
}
102146

103147
// extractBinary extracts the workload-exporter binary from a .tar.gz or .zip archive.
@@ -120,14 +164,14 @@ func extractFromTarGz(data []byte, wantName string) ([]byte, error) {
120164
tr := tar.NewReader(gz)
121165
for {
122166
hdr, err := tr.Next()
123-
if err == io.EOF {
167+
if errors.Is(err, io.EOF) {
124168
break
125169
}
126170
if err != nil {
127171
return nil, fmt.Errorf("reading tar: %w", err)
128172
}
129173
if hdr.Name == wantName || filepath.Base(hdr.Name) == wantName {
130-
return io.ReadAll(tr)
174+
return io.ReadAll(io.LimitReader(tr, maxBinarySize))
131175
}
132176
}
133177
return nil, fmt.Errorf("binary %s not found in archive", wantName)
@@ -145,7 +189,7 @@ func extractFromZip(data []byte, wantName string) ([]byte, error) {
145189
return nil, err
146190
}
147191
defer func() { _ = rc.Close() }()
148-
return io.ReadAll(rc)
192+
return io.ReadAll(io.LimitReader(rc, maxBinarySize))
149193
}
150194
}
151195
return nil, fmt.Errorf("binary %s not found in archive", wantName)
@@ -180,7 +224,7 @@ func performUpdate(ctx context.Context, w io.Writer, deps UpdateDeps) error {
180224
return nil
181225
}
182226

183-
if release.TagName == deps.CurrentVersion {
227+
if !semverGreater(release.TagName, deps.CurrentVersion) {
184228
_, _ = fmt.Fprintf(w, "already up to date (%s)\n", deps.CurrentVersion)
185229
return nil
186230
}
@@ -232,6 +276,9 @@ func performUpdate(ctx context.Context, w io.Writer, deps UpdateDeps) error {
232276
if !strings.Contains(newVersion, "workload-exporter") {
233277
return fmt.Errorf("sanity check failed: unexpected version output: %s", newVersion)
234278
}
279+
if !strings.Contains(newVersion, release.TagName) {
280+
return fmt.Errorf("sanity check failed: version output %q does not contain expected %s", newVersion, release.TagName)
281+
}
235282

236283
_, _ = fmt.Fprintf(w, "verified: %s\n", newVersion)
237284

@@ -265,9 +312,9 @@ func copyFile(src, dst string) error {
265312
if err != nil {
266313
return err
267314
}
268-
defer func() { _ = out.Close() }()
269315

270316
if _, err := io.Copy(out, in); err != nil {
317+
_ = out.Close()
271318
return err
272319
}
273320
return out.Close()
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package update
33
import (
44
"bytes"
55
"context"
6+
"crypto/sha256"
7+
"encoding/hex"
68
"fmt"
79
"os"
810
"path/filepath"
@@ -162,6 +164,106 @@ func TestPerformUpdate_BadVersionOutput(t *testing.T) {
162164
}
163165
}
164166

167+
func TestPerformUpdate_VersionMismatch(t *testing.T) {
168+
_, binaryPath := setupBinary(t)
169+
deps := fakeDeps(t, binaryPath)
170+
// RunVersion returns the wrong version tag.
171+
deps.RunVersion = func(_ context.Context, _ string) (string, error) {
172+
return "workload-exporter version v1.0.0", nil
173+
}
174+
175+
var buf bytes.Buffer
176+
err := performUpdate(context.Background(), &buf, deps)
177+
if err == nil {
178+
t.Fatal("expected error for version mismatch, got nil")
179+
}
180+
if !strings.Contains(err.Error(), "sanity check failed") {
181+
t.Errorf("error = %q, want it to contain 'sanity check failed'", err)
182+
}
183+
}
184+
185+
func TestPerformUpdate_Downgrade(t *testing.T) {
186+
_, binaryPath := setupBinary(t)
187+
deps := fakeDeps(t, binaryPath)
188+
// Simulate a case where the installed binary is already newer than "latest".
189+
deps.CurrentVersion = "v3.0.0"
190+
deps.CheckLatest = func(_ context.Context) (*ReleaseInfo, error) {
191+
return &ReleaseInfo{TagName: "v2.0.0"}, nil
192+
}
193+
194+
var buf bytes.Buffer
195+
err := performUpdate(context.Background(), &buf, deps)
196+
if err != nil {
197+
t.Fatalf("unexpected error: %v", err)
198+
}
199+
if !strings.Contains(buf.String(), "already up to date") {
200+
t.Errorf("expected 'already up to date' for downgrade scenario:\n%s", buf.String())
201+
}
202+
203+
// Binary should be untouched.
204+
data, _ := os.ReadFile(binaryPath)
205+
if string(data) != "old-binary" {
206+
t.Error("binary should not have been replaced in downgrade scenario")
207+
}
208+
}
209+
210+
// --- verifyChecksum tests ---
211+
212+
func checksumLine(data []byte, filename string) string {
213+
sum := sha256.Sum256(data)
214+
return hex.EncodeToString(sum[:]) + " " + filename + "\n"
215+
}
216+
217+
func TestVerifyChecksum_Valid(t *testing.T) {
218+
data := []byte("binary content")
219+
checksumFile := checksumLine(data, "myfile.tar.gz")
220+
221+
if err := verifyChecksum(data, "myfile.tar.gz", []byte(checksumFile)); err != nil {
222+
t.Errorf("unexpected error: %v", err)
223+
}
224+
}
225+
226+
func TestVerifyChecksum_ValidAmongMultipleEntries(t *testing.T) {
227+
data := []byte("binary content")
228+
checksumFile := "aaaa other-file.zip\n" + checksumLine(data, "myfile.tar.gz") + "bbbb another.tar.gz\n"
229+
230+
if err := verifyChecksum(data, "myfile.tar.gz", []byte(checksumFile)); err != nil {
231+
t.Errorf("unexpected error: %v", err)
232+
}
233+
}
234+
235+
func TestVerifyChecksum_Mismatch(t *testing.T) {
236+
data := []byte("binary content")
237+
wrong := strings.Repeat("a", 64) + " myfile.tar.gz\n"
238+
239+
err := verifyChecksum(data, "myfile.tar.gz", []byte(wrong))
240+
if err == nil {
241+
t.Fatal("expected checksum mismatch error")
242+
}
243+
if !strings.Contains(err.Error(), "checksum mismatch") {
244+
t.Errorf("error %q should mention 'checksum mismatch'", err)
245+
}
246+
}
247+
248+
func TestVerifyChecksum_NotFound(t *testing.T) {
249+
checksumFile := "abc123 other-file.tar.gz\n"
250+
251+
err := verifyChecksum([]byte("data"), "myfile.tar.gz", []byte(checksumFile))
252+
if err == nil {
253+
t.Fatal("expected not-found error")
254+
}
255+
if !strings.Contains(err.Error(), "no checksum entry") {
256+
t.Errorf("error %q should mention 'no checksum entry'", err)
257+
}
258+
}
259+
260+
func TestVerifyChecksum_EmptyFile(t *testing.T) {
261+
err := verifyChecksum([]byte("data"), "myfile.tar.gz", []byte(""))
262+
if err == nil {
263+
t.Fatal("expected error for empty checksums file")
264+
}
265+
}
266+
165267
func TestPerformUpdate_StagedFileCleanedUp(t *testing.T) {
166268
_, binaryPath := setupBinary(t)
167269
deps := fakeDeps(t, binaryPath)

0 commit comments

Comments
 (0)