Skip to content

Commit 15b43e1

Browse files
authored
Merge pull request #18278 from hakman/stream-container-images
nodeup: stream verified image bytes into ctr import
2 parents 2ae8ea1 + 30d83e0 commit 15b43e1

6 files changed

Lines changed: 315 additions & 120 deletions

File tree

upup/pkg/fi/http.go

Lines changed: 93 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@ import (
2323
"net"
2424
"net/http"
2525
"os"
26-
"path"
26+
"path/filepath"
2727
"time"
2828

2929
"k8s.io/klog/v2"
3030
"k8s.io/kops/util/pkg/hashing"
3131
)
3232

33+
const downloadTimeout = 3 * time.Minute
34+
3335
// DownloadURL will download the file at the given url and store it as dest.
3436
// If hash is non-nil, it will also verify that it matches the hash of the downloaded file.
3537
func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, error) {
@@ -43,47 +45,103 @@ func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, er
4345
}
4446
}
4547

46-
dirMode := os.FileMode(0o755)
47-
err := downloadURLAlways(url, dest, dirMode)
48+
return downloadURLToFile(url, dest, hash)
49+
}
50+
51+
func downloadURLToFile(url string, destPath string, hash *hashing.Hash) (*hashing.Hash, error) {
52+
dir := filepath.Dir(destPath)
53+
if err := os.MkdirAll(dir, 0o755); err != nil {
54+
return nil, fmt.Errorf("error creating directories for destination file %q: %v", destPath, err)
55+
}
56+
57+
output, err := os.CreateTemp(dir, "."+filepath.Base(destPath)+".tmp")
58+
if err != nil {
59+
return nil, fmt.Errorf("error creating temporary file for download %q: %v", destPath, err)
60+
}
61+
tempPath := output.Name()
62+
defer os.Remove(tempPath)
63+
64+
actual, err := downloadURLToWriter(url, output, hash)
65+
if closeErr := output.Close(); closeErr != nil && err == nil {
66+
err = closeErr
67+
}
4868
if err != nil {
4969
return nil, err
5070
}
71+
if err := os.Chmod(tempPath, 0o644); err != nil {
72+
return nil, fmt.Errorf("error setting mode on downloaded file %q: %v", tempPath, err)
73+
}
74+
if err := os.Rename(tempPath, destPath); err != nil {
75+
return nil, fmt.Errorf("error moving downloaded file %q to %q: %v", tempPath, destPath, err)
76+
}
77+
return actual, nil
78+
}
5179

80+
// downloadURLToWriter streams the file at the given url to dest.
81+
// If hash is non-nil, it will also verify that it matches the downloaded bytes.
82+
func downloadURLToWriter(url string, dest io.Writer, hash *hashing.Hash) (*hashing.Hash, error) {
83+
responseBody, err := OpenURL(url)
84+
if err != nil {
85+
return nil, err
86+
}
87+
defer responseBody.Close()
88+
89+
start := time.Now()
90+
defer func() {
91+
klog.V(2).Infof("Downloading %q took %q", url, time.Since(start))
92+
}()
93+
klog.V(2).Infof("Downloading %q", url)
94+
95+
algorithm := hashing.HashAlgorithmSHA256
5296
if hash != nil {
53-
match, err := fileHasHash(dest, hash)
54-
if err != nil {
55-
return nil, err
56-
}
57-
if !match {
58-
return nil, fmt.Errorf("downloaded from %q but hash did not match expected %q", url, hash)
59-
}
60-
} else {
61-
hash, err = hashing.HashAlgorithmSHA256.HashFile(dest)
62-
if err != nil {
63-
return nil, err
64-
}
97+
algorithm = hash.Algorithm
6598
}
99+
hasher := algorithm.NewHasher()
100+
writer := io.MultiWriter(dest, hasher)
66101

67-
return hash, nil
102+
if _, err := io.Copy(writer, responseBody); err != nil {
103+
return nil, fmt.Errorf("error downloading HTTP content from %q: %v", url, err)
104+
}
105+
106+
actual := &hashing.Hash{
107+
Algorithm: algorithm,
108+
HashValue: hasher.Sum(nil),
109+
}
110+
if hash != nil && !actual.Equal(hash) {
111+
return nil, fmt.Errorf("downloaded from %q but hash did not match expected %q", url, hash)
112+
}
113+
return actual, nil
68114
}
69115

70-
func downloadURLAlways(url string, destPath string, dirMode os.FileMode) error {
71-
err := os.MkdirAll(path.Dir(destPath), dirMode)
116+
// OpenURL opens a hardened HTTP GET stream for url.
117+
func OpenURL(url string) (io.ReadCloser, error) {
118+
httpClient := newDownloadHTTPClient()
119+
120+
ctx, cancel := context.WithTimeout(context.Background(), downloadTimeout)
121+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
72122
if err != nil {
73-
return fmt.Errorf("error creating directories for destination file %q: %v", destPath, err)
123+
cancel()
124+
return nil, fmt.Errorf("cannot create request: %v", err)
74125
}
75126

76-
output, err := os.Create(destPath)
127+
response, err := httpClient.Do(req)
77128
if err != nil {
78-
return fmt.Errorf("error creating file for download %q: %v", destPath, err)
129+
cancel()
130+
return nil, fmt.Errorf("error doing HTTP fetch of %q: %v", url, err)
79131
}
80-
defer output.Close()
81132

82-
klog.V(2).Infof("Downloading %q", url)
133+
// http.Client follows 3xx automatically, so anything outside 2xx that reaches us is a bug or a missing Location.
134+
if response.StatusCode < 200 || response.StatusCode > 299 {
135+
response.Body.Close()
136+
cancel()
137+
return nil, fmt.Errorf("unexpected response from %q: HTTP %s", url, response.Status)
138+
}
139+
140+
return &cancelOnCloseReadCloser{ReadCloser: response.Body, cancel: cancel}, nil
141+
}
83142

84-
// Create a client with custom timeouts
85-
// to avoid idle downloads to hang the program
86-
httpClient := &http.Client{
143+
func newDownloadHTTPClient() *http.Client {
144+
return &http.Client{
87145
Transport: &http.Transport{
88146
Proxy: http.ProxyFromEnvironment,
89147
DialContext: (&net.Dialer{
@@ -95,35 +153,15 @@ func downloadURLAlways(url string, destPath string, dirMode os.FileMode) error {
95153
IdleConnTimeout: 30 * time.Second,
96154
},
97155
}
156+
}
98157

99-
// this will stop slow downloads after 3 minutes
100-
// and interrupt reading of the Response.Body
101-
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
102-
defer cancel()
103-
104-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
105-
if err != nil {
106-
return fmt.Errorf("Cannot create request: %v", err)
107-
}
108-
109-
response, err := httpClient.Do(req)
110-
if err != nil {
111-
return fmt.Errorf("error doing HTTP fetch of %q: %v", url, err)
112-
}
113-
defer response.Body.Close()
114-
115-
if response.StatusCode >= 400 {
116-
return fmt.Errorf("error response from %q: HTTP %v", url, response.StatusCode)
117-
}
118-
119-
start := time.Now()
120-
defer func() {
121-
klog.V(2).Infof("Copying %q to %q took %q", url, destPath, time.Since(start))
122-
}()
158+
type cancelOnCloseReadCloser struct {
159+
io.ReadCloser
160+
cancel context.CancelFunc
161+
}
123162

124-
_, err = io.Copy(output, response.Body)
125-
if err != nil {
126-
return fmt.Errorf("error downloading HTTP content from %q: %v", url, err)
127-
}
128-
return nil
163+
func (r *cancelOnCloseReadCloser) Close() error {
164+
err := r.ReadCloser.Close()
165+
r.cancel()
166+
return err
129167
}

upup/pkg/fi/http_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
Copyright 2026 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package fi
18+
19+
import (
20+
"bytes"
21+
"net/http"
22+
"net/http/httptest"
23+
"os"
24+
"path/filepath"
25+
"testing"
26+
27+
"k8s.io/kops/util/pkg/hashing"
28+
)
29+
30+
func TestDownloadURLRejectsNon2xxAndPreservesDestination(t *testing.T) {
31+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32+
w.WriteHeader(http.StatusFound)
33+
_, _ = w.Write([]byte("redirect body"))
34+
}))
35+
defer server.Close()
36+
37+
dest := filepath.Join(t.TempDir(), "download")
38+
if err := os.WriteFile(dest, []byte("original"), 0o644); err != nil {
39+
t.Fatalf("WriteFile() error = %v", err)
40+
}
41+
42+
if _, err := DownloadURL(server.URL, dest, nil); err == nil {
43+
t.Fatalf("DownloadURL() expected error")
44+
}
45+
46+
actual, err := os.ReadFile(dest)
47+
if err != nil {
48+
t.Fatalf("ReadFile() error = %v", err)
49+
}
50+
if string(actual) != "original" {
51+
t.Fatalf("download destination = %q, expected original contents", actual)
52+
}
53+
}
54+
55+
func TestDownloadURLToWriterVerifiesHash(t *testing.T) {
56+
body := []byte("payload")
57+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58+
_, _ = w.Write(body)
59+
}))
60+
defer server.Close()
61+
62+
expectedHash, err := hashing.HashAlgorithmSHA256.Hash(bytes.NewReader(body))
63+
if err != nil {
64+
t.Fatalf("Hash() error = %v", err)
65+
}
66+
67+
var output bytes.Buffer
68+
actualHash, err := downloadURLToWriter(server.URL, &output, expectedHash)
69+
if err != nil {
70+
t.Fatalf("downloadURLToWriter() error = %v", err)
71+
}
72+
if !actualHash.Equal(expectedHash) {
73+
t.Fatalf("downloadURLToWriter() hash = %v, expected %v", actualHash, expectedHash)
74+
}
75+
if !bytes.Equal(output.Bytes(), body) {
76+
t.Fatalf("downloadURLToWriter() body = %q, expected %q", output.Bytes(), body)
77+
}
78+
}
79+
80+
func TestDownloadURLToWriterRejectsHashMismatch(t *testing.T) {
81+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82+
_, _ = w.Write([]byte("payload"))
83+
}))
84+
defer server.Close()
85+
86+
wrongHash, err := hashing.HashAlgorithmSHA256.Hash(bytes.NewReader([]byte("different")))
87+
if err != nil {
88+
t.Fatalf("Hash() error = %v", err)
89+
}
90+
91+
var output bytes.Buffer
92+
if _, err := downloadURLToWriter(server.URL, &output, wrongHash); err == nil {
93+
t.Fatalf("downloadURLToWriter() expected hash mismatch error")
94+
}
95+
}

upup/pkg/fi/nodeup/command.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"net/url"
2929
"os"
3030
"os/exec"
31+
"path"
3132
"strconv"
3233
"strings"
3334
"time"
@@ -343,13 +344,23 @@ func (c *NodeUpCommand) Run(out io.Writer) error {
343344
return fmt.Errorf("error building loader: %v", err)
344345
}
345346

346-
for i, image := range nodeupConfig.Images[architecture] {
347-
taskMap["LoadImage."+strconv.Itoa(i)] = &nodetasks.LoadImageTask{
347+
for _, image := range nodeupConfig.Images[architecture] {
348+
if len(image.Sources) == 0 {
349+
return fmt.Errorf("image has no sources: %v", image)
350+
}
351+
u, err := url.Parse(image.Sources[0])
352+
if err != nil {
353+
return fmt.Errorf("invalid image source URL %q: %w", image.Sources[0], err)
354+
}
355+
key := "SideloadImage/" + path.Base(u.Path)
356+
if _, ok := taskMap[key]; ok {
357+
return fmt.Errorf("duplicate image task %q", key)
358+
}
359+
taskMap[key] = &nodetasks.LoadImageTask{
348360
Sources: image.Sources,
349361
Hash: image.Hash,
350362
}
351363
}
352-
// Protokube load image task is in ProtokubeBuilder
353364

354365
var target fi.NodeupTarget
355366

upup/pkg/fi/utils/gzip.go renamed to upup/pkg/fi/nodeup/nodetasks/compression.go

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,25 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17-
package utils
17+
package nodetasks
1818

1919
import (
20+
"bufio"
21+
"bytes"
2022
"compress/gzip"
2123
"io"
22-
"os"
2324
)
2425

25-
// UngzipFile extracts a .gzip file
26-
func UngzipFile(gzipPath string, destPath string) error {
27-
reader, err := os.Open(gzipPath)
28-
if err != nil {
29-
return err
30-
}
31-
defer reader.Close()
26+
var gzipMagic = []byte{0x1f, 0x8b}
3227

33-
writer, err := os.Create(destPath)
34-
if err != nil {
35-
return err
28+
func maybeGzipReader(r io.Reader) (io.ReadCloser, error) {
29+
buffered := bufio.NewReader(r)
30+
header, err := buffered.Peek(len(gzipMagic))
31+
if err != nil && err != io.EOF {
32+
return nil, err
3633
}
37-
defer writer.Close()
38-
39-
archive, err := gzip.NewReader(reader)
40-
if err != nil {
41-
return err
34+
if len(header) == len(gzipMagic) && bytes.Equal(header, gzipMagic) {
35+
return gzip.NewReader(buffered)
4236
}
43-
defer archive.Close()
44-
45-
_, err = io.Copy(writer, archive)
46-
return err
37+
return io.NopCloser(buffered), nil
4738
}

0 commit comments

Comments
 (0)