Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cmd/cli/commands/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/inference/backends/vllm"
"github.com/moby/term"
Expand Down Expand Up @@ -53,6 +54,13 @@ func handleClientError(err error, message string) error {
var buf bytes.Buffer
printNextSteps(&buf, []string{enableVLLM})
return fmt.Errorf("%w\n%s", err, strings.TrimRight(buf.String(), "\n"))
} else if errors.Is(err, distribution.ErrUnsupportedMediaType) {
// The model uses a newer config format than this client supports.
var buf bytes.Buffer
printNextSteps(&buf, []string{
"Upgrade Docker Model Runner to the latest version to support this model",
})
return fmt.Errorf("%s: %w\n%s", message, err, strings.TrimRight(buf.String(), "\n"))
}
return fmt.Errorf("%s: %w", message, err)
}
Expand Down
46 changes: 37 additions & 9 deletions cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,29 @@ func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, b
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
err := fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
// Only retry on server errors (5xx), not client errors (4xx)
shouldRetry := resp.StatusCode >= 500 && resp.StatusCode < 600
body, readErr := io.ReadAll(resp.Body)
var bodyStr string
if readErr != nil {
bodyStr = fmt.Sprintf("failed to read response body: %v", readErr)
} else {
bodyStr = strings.TrimSpace(string(body))
}
var err error
if resp.StatusCode == http.StatusUnprocessableEntity {
// 422 means the model uses a config type this client does not
// support. Reattach the sentinel so callers can use errors.Is.
err = fmt.Errorf("pulling %s failed with status %s: %w: %s",
model, resp.Status, distribution.ErrUnsupportedMediaType, bodyStr)
} else {
err = fmt.Errorf("pulling %s failed with status %s: %s",
model, resp.Status, bodyStr)
}
// Only retry on gateway/proxy errors (502, 503, 504).
// Do not retry 500 (usually deterministic server errors) or
// 4xx (client errors including 422 for unsupported media type).
shouldRetry := resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable ||
resp.StatusCode == http.StatusGatewayTimeout
return "", false, err, shouldRetry
}

Expand Down Expand Up @@ -235,7 +254,7 @@ func (c *Client) withRetries(
}
}

return "", progressShown, fmt.Errorf("failed to %s after %d retries: %w", operationName, maxRetries, lastErr)
return "", progressShown, fmt.Errorf("%s failed after %d retries: %w", operationName, maxRetries, lastErr)
}

func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
Expand Down Expand Up @@ -272,10 +291,19 @@ func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, b
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
err := fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body))
// Only retry on server errors (5xx), not client errors (4xx)
shouldRetry := resp.StatusCode >= 500 && resp.StatusCode < 600
body, readErr := io.ReadAll(resp.Body)
var bodyStr string
if readErr != nil {
bodyStr = fmt.Sprintf("(failed to read response body: %v)", readErr)
} else {
bodyStr = strings.TrimSpace(string(body))
}
err := fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, bodyStr)
// Only retry on gateway/proxy errors. Do not retry plain 500
// (usually deterministic server errors) or 4xx (client errors).
shouldRetry := resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable ||
resp.StatusCode == http.StatusGatewayTimeout
return "", false, err, shouldRetry
}

Expand Down
140 changes: 126 additions & 14 deletions cmd/cli/desktop/desktop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,22 @@ import (
"errors"
"io"
"net/http"
"strings"
"testing"

mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

// errorReadCloser is an io.ReadCloser stub for simulating a body that cannot
// be read: every Read fails with the configured error.
type errorReadCloser struct {
	// err is handed back verbatim by every call to Read.
	err error
}

// Read never produces data; it always reports zero bytes read together with
// the configured error.
func (rc *errorReadCloser) Read([]byte) (int, error) {
	return 0, rc.err
}

// Close is a no-op that always succeeds.
func (rc *errorReadCloser) Close() error {
	return nil
}

func TestPullRetryOnNetworkError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down Expand Up @@ -59,7 +67,7 @@ func TestPullNoRetryOn4xxError(t *testing.T) {
assert.Contains(t, err.Error(), "Model not found")
}

func TestPullRetryOn5xxError(t *testing.T) {
func TestPullNoRetryOn500Error(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -68,21 +76,83 @@ func TestPullRetryOn5xxError(t *testing.T) {
mockContext := NewContextForMock(mockClient)
client := New(mockContext)

// First attempt fails with 500, second succeeds
gomock.InOrder(
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("Internal server error")),
}, nil),
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
}, nil),
)
// 500 is not retried (deterministic server error), so only 1 call.
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("Internal server error")),
}, nil).Times(1)

printer := NewSimplePrinter(func(s string) {})
_, _, err := client.Pull(modelName, printer)
assert.NoError(t, err)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Internal server error")
}

// TestPullNoRetryOn422Error verifies that a 422 Unprocessable Entity response
// to a pull is reported after a single request (no retry), that the returned
// error wraps distribution.ErrUnsupportedMediaType, and that the server's
// explanation is kept in the error text.
func TestPullNoRetryOn422Error(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	modelName := "test-model"
	mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
	mockContext := NewContextForMock(mockClient)
	client := New(mockContext)

	// 422 (unsupported media type) must not be retried: Times(1) pins the
	// number of requests the mock expects to exactly one.
	unsupportedMsg := `error while pulling model: config type "v0.3" is not supported` +
		` - try upgrading`
	mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewBufferString(unsupportedMsg)),
	}, nil).Times(1)

	printer := NewSimplePrinter(func(s string) {})
	_, _, err := client.Pull(modelName, printer)
	require.Error(t, err)
	// The sentinel must be preserved so callers can use errors.Is.
	// assert.ErrorIs is used instead of assert.True(errors.Is(...)) so that a
	// failure prints the actual error chain rather than just "Should be true".
	assert.ErrorIs(t, err, distribution.ErrUnsupportedMediaType)
	// Like the other non-retried status tests, make sure the response body
	// (the server's explanation) is still surfaced to the caller.
	assert.Contains(t, err.Error(), unsupportedMsg)
}

// TestPullRetriesOnTransientGatewayErrors checks that Pull retries after a
// transient gateway/proxy failure (502 or 504) and succeeds once the next
// attempt returns 200.
//
// 503 is deliberately absent from the table: it is intercepted by
// doRequestWithAuthContext as ErrServiceUnavailable and is covered separately
// by TestPullRetryOnServiceUnavailable.
func TestPullRetriesOnTransientGatewayErrors(t *testing.T) {
	const successBody = `{"type":"success","message":"Model pulled successfully"}`

	// respond builds a canned HTTP response with the given status and body.
	respond := func(status int, payload string) *http.Response {
		return &http.Response{
			StatusCode: status,
			Body:       io.NopCloser(bytes.NewBufferString(payload)),
		}
	}

	cases := []struct {
		name string
		code int
		body string
	}{
		{name: "502 Bad Gateway", code: http.StatusBadGateway, body: "Bad Gateway"},
		{name: "504 Gateway Timeout", code: http.StatusGatewayTimeout, body: "Gateway Timeout"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
			client := New(NewContextForMock(httpMock))

			// The first request gets the transient error, the second one succeeds.
			failing := httpMock.EXPECT().Do(gomock.Any()).Return(respond(tc.code, tc.body), nil)
			succeeding := httpMock.EXPECT().Do(gomock.Any()).Return(respond(http.StatusOK, successBody), nil)
			gomock.InOrder(failing, succeeding)

			_, _, err := client.Pull("test-model", NewSimplePrinter(func(string) {}))
			assert.NoError(t, err)
		})
	}
}

func TestPullRetryOnServiceUnavailable(t *testing.T) {
Expand Down Expand Up @@ -127,7 +197,7 @@ func TestPullMaxRetriesExhausted(t *testing.T) {
printer := NewSimplePrinter(func(s string) {})
_, _, err := client.Pull(modelName, printer)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to download after 3 retries")
assert.Contains(t, err.Error(), "download failed after 3 retries")
}

func TestPushRetryOnNetworkError(t *testing.T) {
Expand Down Expand Up @@ -341,3 +411,45 @@ func TestIsTemplateIncompatibleError(t *testing.T) {
})
}
}

// TestPullBodyReadFailure checks that when the body of a failed pull response
// cannot be read, Pull still returns an error and that error mentions the
// read failure.
func TestPullBodyReadFailure(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	client := New(NewContextForMock(httpMock))

	// The body always fails to read. The status is a non-retryable 500 so
	// the test completes in a single attempt.
	brokenResp := &http.Response{
		StatusCode: http.StatusInternalServerError,
		Body:       &errorReadCloser{err: errors.New("connection reset")},
	}
	httpMock.EXPECT().Do(gomock.Any()).Return(brokenResp, nil).Times(1)

	_, _, err := client.Pull("test-model", NewSimplePrinter(func(string) {}))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to read response body")
}

func TestDisplayProgressNonJSONLines(t *testing.T) {
// Simulate a proxy returning an HTML error page instead of a progress stream.
htmlBody := "<html><body><h1>502 Bad Gateway</h1></body></html>\n"
printer := NewSimplePrinter(func(string) {})
_, _, err := DisplayProgress(strings.NewReader(htmlBody), printer)
require.Error(t, err)
assert.Contains(t, err.Error(), "unexpected response from server")
assert.Contains(t, err.Error(), "502 Bad Gateway")
}

func TestDisplayProgressMixedContent(t *testing.T) {
// Valid progress followed by some unparseable lines: the valid progress
// should be honoured and no error returned for the stray lines.
body := `{"type":"success","message":"Model pulled successfully"}` + "\n" +
"<html>some extra garbage</html>\n"
printer := NewSimplePrinter(func(string) {})
msg, _, err := DisplayProgress(strings.NewReader(body), printer)
require.NoError(t, err)
assert.Equal(t, "Model pulled successfully", msg)
}
Loading
Loading