66 "bytes"
77 "compress/gzip"
88 "context"
9+ "crypto/sha256"
10+ "encoding/hex"
11+ "errors"
912 "fmt"
1013 "io"
1114 "net/http"
@@ -17,6 +20,11 @@ import (
1720 "time"
1821)
1922
23+ const (
24+ maxBinarySize = 100 * 1024 * 1024 // 100 MB guard against runaway downloads
25+ maxChecksumSize = 1 << 20 // 1 MB, well above any real checksums.txt
26+ )
27+
2028// UpdateDeps holds the external dependencies for PerformUpdate.
2129// Tests inject fakes; production uses defaultUpdateDeps().
2230type UpdateDeps struct {
@@ -42,7 +50,9 @@ func defaultUpdateDeps(version string) UpdateDeps {
4250 CheckLatest : func (ctx context.Context ) (* ReleaseInfo , error ) {
4351 return Check (ctx , version )
4452 },
45- Download : defaultDownload ,
53+ Download : func (ctx context.Context , v string ) ([]byte , error ) {
54+ return defaultDownload (ctx , http .DefaultClient , v )
55+ },
4656 RunVersion : defaultRunVersion ,
4757 CurrentVersion : version ,
4858 }
@@ -74,30 +84,64 @@ func assetDownloadURL(version string) string {
7484 return fmt .Sprintf ("https://github.com/%s/releases/download/%s/%s" , repo , version , assetName (version ))
7585}
7686
77- func defaultDownload (ctx context.Context , version string ) ([]byte , error ) {
78- downloadURL := assetDownloadURL (version )
79- req , err := http .NewRequestWithContext (ctx , "GET" , downloadURL , nil )
87+ func defaultDownload (ctx context.Context , client HTTPDoer , version string ) ([]byte , error ) {
88+ // Download the release archive.
89+ archiveData , err := fetchURL (ctx , client , assetDownloadURL (version ), maxBinarySize )
90+ if err != nil {
91+ return nil , fmt .Errorf ("downloading asset: %w" , err )
92+ }
93+
94+ // Download checksums.txt and verify the archive before extracting.
95+ checksumURL := fmt .Sprintf ("https://github.com/%s/releases/download/%s/checksums.txt" , repo , version )
96+ checksumData , err := fetchURL (ctx , client , checksumURL , maxChecksumSize )
8097 if err != nil {
98+ return nil , fmt .Errorf ("downloading checksums: %w" , err )
99+ }
100+ if err := verifyChecksum (archiveData , assetName (version ), checksumData ); err != nil {
81101 return nil , err
82102 }
83103
84- resp , err := http .DefaultClient .Do (req )
104+ // Extract binary from archive.
105+ return extractBinary (archiveData , version )
106+ }
107+
108+ // fetchURL performs a GET request and returns the response body, bounded by limit bytes.
109+ func fetchURL (ctx context.Context , client HTTPDoer , url string , limit int64 ) ([]byte , error ) {
110+ req , err := http .NewRequestWithContext (ctx , "GET" , url , nil )
85111 if err != nil {
86- return nil , fmt .Errorf ("downloading asset: %w" , err )
112+ return nil , err
113+ }
114+ resp , err := client .Do (req )
115+ if err != nil {
116+ return nil , err
87117 }
88118 defer func () { _ = resp .Body .Close () }()
89-
90119 if resp .StatusCode != http .StatusOK {
91- return nil , fmt .Errorf ("asset download returned %d " , resp .StatusCode )
120+ return nil , fmt .Errorf ("HTTP %d for %s " , resp .StatusCode , url )
92121 }
122+ return io .ReadAll (io .LimitReader (resp .Body , limit ))
123+ }
93124
94- archiveData , err := io .ReadAll (resp .Body )
95- if err != nil {
96- return nil , fmt .Errorf ("reading asset: %w" , err )
125+ // verifyChecksum checks the SHA256 of data against the entry for filename in
126+ // a checksums.txt file (sha256sum format: "<hash> <filename>").
127+ func verifyChecksum (data []byte , filename string , checksumFile []byte ) error {
128+ for _ , line := range strings .Split (string (checksumFile ), "\n " ) {
129+ line = strings .TrimSpace (line )
130+ if line == "" {
131+ continue
132+ }
133+ parts := strings .Fields (line )
134+ if len (parts ) != 2 || parts [1 ] != filename {
135+ continue
136+ }
137+ sum := sha256 .Sum256 (data )
138+ actual := hex .EncodeToString (sum [:])
139+ if actual != parts [0 ] {
140+ return fmt .Errorf ("checksum mismatch for %s: got %s, want %s" , filename , actual , parts [0 ])
141+ }
142+ return nil
97143 }
98-
99- // Extract binary from archive.
100- return extractBinary (archiveData , version )
144+ return fmt .Errorf ("no checksum entry found for %s in checksums.txt" , filename )
101145}
102146
103147// extractBinary extracts the workload-exporter binary from a .tar.gz or .zip archive.
@@ -120,14 +164,14 @@ func extractFromTarGz(data []byte, wantName string) ([]byte, error) {
120164 tr := tar .NewReader (gz )
121165 for {
122166 hdr , err := tr .Next ()
123- if err == io .EOF {
167+ if errors . Is ( err , io .EOF ) {
124168 break
125169 }
126170 if err != nil {
127171 return nil , fmt .Errorf ("reading tar: %w" , err )
128172 }
129173 if hdr .Name == wantName || filepath .Base (hdr .Name ) == wantName {
130- return io .ReadAll (tr )
174+ return io .ReadAll (io . LimitReader ( tr , maxBinarySize ) )
131175 }
132176 }
133177 return nil , fmt .Errorf ("binary %s not found in archive" , wantName )
@@ -145,7 +189,7 @@ func extractFromZip(data []byte, wantName string) ([]byte, error) {
145189 return nil , err
146190 }
147191 defer func () { _ = rc .Close () }()
148- return io .ReadAll (rc )
192+ return io .ReadAll (io . LimitReader ( rc , maxBinarySize ) )
149193 }
150194 }
151195 return nil , fmt .Errorf ("binary %s not found in archive" , wantName )
@@ -180,7 +224,7 @@ func performUpdate(ctx context.Context, w io.Writer, deps UpdateDeps) error {
180224 return nil
181225 }
182226
183- if release .TagName == deps .CurrentVersion {
227+ if ! semverGreater ( release .TagName , deps .CurrentVersion ) {
184228 _ , _ = fmt .Fprintf (w , "already up to date (%s)\n " , deps .CurrentVersion )
185229 return nil
186230 }
@@ -232,6 +276,9 @@ func performUpdate(ctx context.Context, w io.Writer, deps UpdateDeps) error {
232276 if ! strings .Contains (newVersion , "workload-exporter" ) {
233277 return fmt .Errorf ("sanity check failed: unexpected version output: %s" , newVersion )
234278 }
279+ if ! strings .Contains (newVersion , release .TagName ) {
280+ return fmt .Errorf ("sanity check failed: version output %q does not contain expected %s" , newVersion , release .TagName )
281+ }
235282
236283 _ , _ = fmt .Fprintf (w , "verified: %s\n " , newVersion )
237284
@@ -265,9 +312,9 @@ func copyFile(src, dst string) error {
265312 if err != nil {
266313 return err
267314 }
268- defer func () { _ = out .Close () }()
269315
270316 if _ , err := io .Copy (out , in ); err != nil {
317+ _ = out .Close ()
271318 return err
272319 }
273320 return out .Close ()
0 commit comments