Skip to content

Commit 258bdb4

Browse files
committed
fix: enhance error handling for unsupported media types in pull operations
1 parent 9f6b0d2 commit 258bdb4

File tree

4 files changed

+119
-55
lines changed

4 files changed

+119
-55
lines changed

cmd/cli/commands/utils.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/docker/model-runner/cmd/cli/desktop"
1212
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
13+
"github.com/docker/model-runner/pkg/distribution/distribution"
1314
"github.com/docker/model-runner/pkg/distribution/oci/reference"
1415
"github.com/docker/model-runner/pkg/inference/backends/vllm"
1516
"github.com/moby/term"
@@ -53,7 +54,7 @@ func handleClientError(err error, message string) error {
5354
var buf bytes.Buffer
5455
printNextSteps(&buf, []string{enableVLLM})
5556
return fmt.Errorf("%w\n%s", err, strings.TrimRight(buf.String(), "\n"))
56-
} else if strings.Contains(err.Error(), "try upgrading") {
57+
} else if errors.Is(err, distribution.ErrUnsupportedMediaType) {
5758
// The model uses a newer config format than this client supports.
5859
var buf bytes.Buffer
5960
printNextSteps(&buf, []string{

cmd/cli/desktop/desktop.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,16 @@ func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, b
151151
} else {
152152
bodyStr = strings.TrimSpace(string(body))
153153
}
154-
err := fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, bodyStr)
154+
var err error
155+
if resp.StatusCode == http.StatusUnprocessableEntity {
156+
// 422 means the model uses a config type this client does not
157+
// support. Reattach the sentinel so callers can use errors.Is.
158+
err = fmt.Errorf("pulling %s failed with status %s: %w: %s",
159+
model, resp.Status, distribution.ErrUnsupportedMediaType, bodyStr)
160+
} else {
161+
err = fmt.Errorf("pulling %s failed with status %s: %s",
162+
model, resp.Status, bodyStr)
163+
}
155164
// Only retry on gateway/proxy errors (502, 503, 504).
156165
// Do not retry 500 (usually deterministic server errors) or
157166
// 4xx (client errors including 422 for unsupported media type).
@@ -245,7 +254,7 @@ func (c *Client) withRetries(
245254
}
246255
}
247256

248-
return "", progressShown, fmt.Errorf("%w (failed after %d retries)", lastErr, maxRetries)
257+
return "", progressShown, fmt.Errorf("%s failed after %d retries: %w", operationName, maxRetries, lastErr)
249258
}
250259

251260
func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {

cmd/cli/desktop/desktop_test.go

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,18 @@ import (
1010
"testing"
1111

1212
mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
13+
"github.com/docker/model-runner/pkg/distribution/distribution"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
1516
"go.uber.org/mock/gomock"
1617
)
1718

19+
// errorReadCloser is an io.ReadCloser whose Read always returns an error.
20+
type errorReadCloser struct{ err error }
21+
22+
func (e *errorReadCloser) Read(_ []byte) (int, error) { return 0, e.err }
23+
func (e *errorReadCloser) Close() error { return nil }
24+
1825
func TestPullRetryOnNetworkError(t *testing.T) {
1926
ctrl := gomock.NewController(t)
2027
defer ctrl.Finish()
@@ -100,34 +107,52 @@ func TestPullNoRetryOn422Error(t *testing.T) {
100107

101108
printer := NewSimplePrinter(func(s string) {})
102109
_, _, err := client.Pull(modelName, printer)
103-
assert.Error(t, err)
104-
assert.Contains(t, err.Error(), "try upgrading")
110+
require.Error(t, err)
111+
// The sentinel must be preserved so callers can use errors.Is.
112+
assert.True(t, errors.Is(err, distribution.ErrUnsupportedMediaType))
105113
}
106114

107-
func TestPullRetryOn502Error(t *testing.T) {
108-
ctrl := gomock.NewController(t)
109-
defer ctrl.Finish()
110-
111-
modelName := "test-model"
112-
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
113-
mockContext := NewContextForMock(mockClient)
114-
client := New(mockContext)
115-
116-
// 502 Bad Gateway is a transient proxy error and should be retried.
117-
gomock.InOrder(
118-
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
119-
StatusCode: http.StatusBadGateway,
120-
Body: io.NopCloser(bytes.NewBufferString("Bad Gateway")),
121-
}, nil),
122-
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
123-
StatusCode: http.StatusOK,
124-
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
125-
}, nil),
126-
)
115+
func TestPullRetriesOnTransientGatewayErrors(t *testing.T) {
116+
// 502 and 504 are transient gateway/proxy errors and should be retried.
117+
// Note: 503 is intercepted by doRequestWithAuthContext as ErrServiceUnavailable
118+
// and is covered separately by TestPullRetryOnServiceUnavailable.
119+
transientCodes := []struct {
120+
code int
121+
name string
122+
body string
123+
}{
124+
{http.StatusBadGateway, "502 Bad Gateway", "Bad Gateway"},
125+
{http.StatusGatewayTimeout, "504 Gateway Timeout", "Gateway Timeout"},
126+
}
127127

128-
printer := NewSimplePrinter(func(s string) {})
129-
_, _, err := client.Pull(modelName, printer)
130-
assert.NoError(t, err)
128+
for _, tc := range transientCodes {
129+
t.Run(tc.name, func(t *testing.T) {
130+
ctrl := gomock.NewController(t)
131+
defer ctrl.Finish()
132+
133+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
134+
mockContext := NewContextForMock(mockClient)
135+
client := New(mockContext)
136+
137+
// First attempt fails with the transient error, second succeeds.
138+
gomock.InOrder(
139+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
140+
StatusCode: tc.code,
141+
Body: io.NopCloser(bytes.NewBufferString(tc.body)),
142+
}, nil),
143+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
144+
StatusCode: http.StatusOK,
145+
Body: io.NopCloser(bytes.NewBufferString(
146+
`{"type":"success","message":"Model pulled successfully"}`,
147+
)),
148+
}, nil),
149+
)
150+
151+
printer := NewSimplePrinter(func(s string) {})
152+
_, _, err := client.Pull("test-model", printer)
153+
assert.NoError(t, err)
154+
})
155+
}
131156
}
132157

133158
func TestPullRetryOnServiceUnavailable(t *testing.T) {
@@ -172,7 +197,7 @@ func TestPullMaxRetriesExhausted(t *testing.T) {
172197
printer := NewSimplePrinter(func(s string) {})
173198
_, _, err := client.Pull(modelName, printer)
174199
assert.Error(t, err)
175-
assert.Contains(t, err.Error(), "(failed after 3 retries)")
200+
assert.Contains(t, err.Error(), "download failed after 3 retries")
176201
}
177202

178203
func TestPushRetryOnNetworkError(t *testing.T) {
@@ -387,6 +412,27 @@ func TestIsTemplateIncompatibleError(t *testing.T) {
387412
}
388413
}
389414

415+
func TestPullBodyReadFailure(t *testing.T) {
416+
ctrl := gomock.NewController(t)
417+
defer ctrl.Finish()
418+
419+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
420+
mockContext := NewContextForMock(mockClient)
421+
client := New(mockContext)
422+
423+
// Response body read fails. Use a non-retryable 500 status so the test
424+
// completes in a single attempt.
425+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
426+
StatusCode: http.StatusInternalServerError,
427+
Body: &errorReadCloser{err: errors.New("connection reset")},
428+
}, nil).Times(1)
429+
430+
printer := NewSimplePrinter(func(s string) {})
431+
_, _, err := client.Pull("test-model", printer)
432+
require.Error(t, err)
433+
assert.Contains(t, err.Error(), "failed to read response body")
434+
}
435+
390436
func TestDisplayProgressNonJSONLines(t *testing.T) {
391437
// Simulate a proxy returning an HTML error page instead of a progress stream.
392438
htmlBody := "<html><body><h1>502 Bad Gateway</h1></body></html>\n"

cmd/cli/desktop/progress.go

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ import (
1919
// DisplayProgress displays progress messages from a model pull/push operation
2020
// using Docker-style multi-line progress bars.
2121
// Returns the final message, whether progress was actually shown, and any error.
22-
func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string, bool, error) {
22+
func DisplayProgress(
23+
body io.Reader, printer standalone.StatusPrinter,
24+
) (finalMessage string, progressShown bool, retErr error) {
2325
fd, isTerminal := printer.GetFdInfo()
2426

2527
// If not a terminal, fall back to simple line-by-line output
@@ -40,13 +42,22 @@ func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string,
4042
close(errCh)
4143
}()
4244

45+
// Ensure the pipe is always closed and the display goroutine is always
46+
// drained, even on early returns, to prevent goroutine leaks.
47+
defer func() {
48+
pw.Close()
49+
if displayErr := <-errCh; retErr == nil &&
50+
displayErr != nil && !errors.Is(displayErr, io.EOF) {
51+
retErr = displayErr
52+
}
53+
}()
54+
4355
// Convert progress messages to JSONMessage format
4456
scanner := bufio.NewScanner(body)
45-
var finalMessage string
46-
progressShown := false // Track if we actually showed any progress bars
4757
// nonJSONBytes collects raw unparseable lines for error reporting,
4858
// capped at maxNonJSONBytes to avoid large allocations.
4959
var nonJSONBytes []byte
60+
var nonJSONTruncated bool
5061

5162
for scanner.Scan() {
5263
progressLine := scanner.Text()
@@ -58,15 +69,14 @@ func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string,
5869
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
5970
// Collect unparseable lines (e.g. HTML error pages from proxies)
6071
// so we can surface them if no valid progress arrives.
61-
nonJSONBytes = appendNonJSONLine(nonJSONBytes, progressLine)
72+
nonJSONBytes, nonJSONTruncated = appendNonJSONLine(nonJSONBytes, progressLine)
6273
continue
6374
}
6475

6576
switch progressMsg.Type {
6677
case oci.TypeProgress:
6778
progressShown = true // We're showing actual progress
6879
if err := writeDockerProgress(pw, &progressMsg); err != nil {
69-
pw.Close()
7080
return "", false, err
7181
}
7282

@@ -80,33 +90,23 @@ func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string,
8090
printer.PrintErrf("Warning: %s\n", progressMsg.Message)
8191

8292
case oci.TypeError:
83-
pw.Close()
8493
return "", false, fmt.Errorf("%s", progressMsg.Message)
8594
}
8695
}
8796

8897
if err := scanner.Err(); err != nil {
89-
pw.Close()
9098
return "", false, err
9199
}
92100

93101
// If we received only unparseable lines and no valid progress or success,
94102
// surface the raw content as an error. This catches HTML error pages
95103
// returned by proxies or CDNs in place of a proper progress stream.
96104
if finalMessage == "" && !progressShown {
97-
if err := unexpectedProgressDataError(nonJSONBytes); err != nil {
98-
pw.Close()
105+
if err := unexpectedProgressDataError(nonJSONBytes, nonJSONTruncated); err != nil {
99106
return "", false, err
100107
}
101108
}
102109

103-
pw.Close()
104-
105-
// Wait for display to finish
106-
if err := <-errCh; err != nil && !errors.Is(err, io.EOF) {
107-
return finalMessage, progressShown, err
108-
}
109-
110110
return finalMessage, progressShown, nil
111111
}
112112

@@ -119,6 +119,7 @@ func displayProgressSimple(body io.Reader, printer standalone.StatusPrinter) (st
119119
progressShown := false // Track if we actually showed any progress
120120
// nonJSONBytes collects raw unparseable lines for error reporting.
121121
var nonJSONBytes []byte
122+
var nonJSONTruncated bool
122123

123124
for scanner.Scan() {
124125
progressLine := scanner.Text()
@@ -129,7 +130,7 @@ func displayProgressSimple(body io.Reader, printer standalone.StatusPrinter) (st
129130
var progressMsg oci.ProgressMessage
130131
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
131132
// Collect unparseable lines for error reporting.
132-
nonJSONBytes = appendNonJSONLine(nonJSONBytes, progressLine)
133+
nonJSONBytes, nonJSONTruncated = appendNonJSONLine(nonJSONBytes, progressLine)
133134
continue
134135
}
135136

@@ -167,7 +168,7 @@ func displayProgressSimple(body io.Reader, printer standalone.StatusPrinter) (st
167168

168169
// Surface unparseable content if no valid progress was received.
169170
if finalMessage == "" && !progressShown {
170-
if err := unexpectedProgressDataError(nonJSONBytes); err != nil {
171+
if err := unexpectedProgressDataError(nonJSONBytes, nonJSONTruncated); err != nil {
171172
return "", false, err
172173
}
173174
}
@@ -289,29 +290,36 @@ func NewSimplePrinter(printFunc func(string)) standalone.StatusPrinter {
289290
const maxNonJSONBytes = 4096
290291

291292
// appendNonJSONLine appends line (with a newline separator) to dst, enforcing
292-
// a hard cap of maxNonJSONBytes total to avoid large allocations.
293-
func appendNonJSONLine(dst []byte, line string) []byte {
293+
// a hard cap of maxNonJSONBytes total. Returns the updated slice and a boolean
294+
// indicating whether the line was truncated to fit within the cap.
295+
func appendNonJSONLine(dst []byte, line string) ([]byte, bool) {
294296
if len(dst) >= maxNonJSONBytes {
295-
return dst
297+
return dst, true
296298
}
297299
if len(dst) > 0 {
298300
dst = append(dst, '\n')
299301
}
300302
remaining := maxNonJSONBytes - len(dst)
301-
if len(line) > remaining {
303+
truncated := len(line) > remaining
304+
if truncated {
302305
line = line[:remaining]
303306
}
304-
return append(dst, line...)
307+
return append(dst, line...), truncated
305308
}
306309

307310
// unexpectedProgressDataError returns an error describing unexpected non-JSON
308-
// response data, or nil if nonJSONBytes is empty.
309-
func unexpectedProgressDataError(nonJSONBytes []byte) error {
311+
// response data, or nil if nonJSONBytes is empty. If truncated is true, a
312+
// marker is appended to indicate the response was cut off.
313+
func unexpectedProgressDataError(nonJSONBytes []byte, truncated bool) error {
310314
if len(nonJSONBytes) == 0 {
311315
return nil
312316
}
317+
msg := string(nonJSONBytes)
318+
if truncated {
319+
msg += "\n...[truncated]"
320+
}
313321
return fmt.Errorf(
314322
"unexpected response from server (not valid progress data): %s",
315-
string(nonJSONBytes),
323+
msg,
316324
)
317325
}

0 commit comments

Comments
 (0)