Skip to content

Commit 3c5fd6b

Browse files
authored
Return ModelError when Run fails (#76)
* Rename apierror.go to error.go * Fix capitalization of error string * Define ModelError type * Return ModelError when Run fails to produce output * Add test coverage for Run method
1 parent 972c92e commit 3c5fd6b

File tree

3 files changed

+112
-2
lines changed

3 files changed

+112
-2
lines changed

client_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,96 @@ func TestWaitAsync(t *testing.T) {
10751075
assert.Equal(t, replicate.Succeeded, lastStatus)
10761076
}
10771077

1078+
func TestRun(t *testing.T) {
1079+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1080+
switch r.URL.Path {
1081+
case "/predictions":
1082+
assert.Equal(t, http.MethodPost, r.Method)
1083+
prediction := replicate.Prediction{
1084+
ID: "gtsllfynndufawqhdngldkdrkq",
1085+
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1086+
Status: replicate.Starting,
1087+
}
1088+
json.NewEncoder(w).Encode(prediction)
1089+
case "/predictions/gtsllfynndufawqhdngldkdrkq":
1090+
assert.Equal(t, http.MethodGet, r.Method)
1091+
prediction := replicate.Prediction{
1092+
ID: "gtsllfynndufawqhdngldkdrkq",
1093+
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1094+
Status: replicate.Succeeded,
1095+
Output: "Hello, world!",
1096+
}
1097+
json.NewEncoder(w).Encode(prediction)
1098+
default:
1099+
t.Fatalf("Unexpected request to %s", r.URL.Path)
1100+
}
1101+
}))
1102+
defer mockServer.Close()
1103+
1104+
client, err := replicate.NewClient(
1105+
replicate.WithToken("test-token"),
1106+
replicate.WithBaseURL(mockServer.URL),
1107+
)
1108+
require.NoError(t, err)
1109+
1110+
ctx := context.Background()
1111+
input := replicate.PredictionInput{"prompt": "Hello"}
1112+
output, err := client.Run(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil)
1113+
1114+
require.NoError(t, err)
1115+
assert.NotNil(t, output)
1116+
assert.Equal(t, "Hello, world!", output)
1117+
}
1118+
1119+
func TestRunReturningModelError(t *testing.T) {
1120+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1121+
switch r.URL.Path {
1122+
case "/predictions":
1123+
assert.Equal(t, http.MethodPost, r.Method)
1124+
prediction := replicate.Prediction{
1125+
ID: "fynndufawqhdngldkgtslldrkq",
1126+
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1127+
Status: replicate.Starting,
1128+
}
1129+
json.NewEncoder(w).Encode(prediction)
1130+
case "/predictions/fynndufawqhdngldkgtslldrkq":
1131+
assert.Equal(t, http.MethodGet, r.Method)
1132+
1133+
logs := "Could not say hello"
1134+
prediction := replicate.Prediction{
1135+
ID: "fynndufawqhdngldkgtslldrkq",
1136+
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1137+
Status: replicate.Failed,
1138+
Logs: &logs,
1139+
Error: "Model execution failed",
1140+
}
1141+
json.NewEncoder(w).Encode(prediction)
1142+
default:
1143+
t.Fatalf("Unexpected request to %s", r.URL.Path)
1144+
}
1145+
}))
1146+
defer mockServer.Close()
1147+
1148+
client, err := replicate.NewClient(
1149+
replicate.WithToken("test-token"),
1150+
replicate.WithBaseURL(mockServer.URL),
1151+
)
1152+
require.NoError(t, err)
1153+
1154+
ctx := context.Background()
1155+
input := replicate.PredictionInput{"prompt": "Hello"}
1156+
_, err = client.Run(ctx, "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil)
1157+
1158+
require.Error(t, err)
1159+
modelErr, ok := err.(*replicate.ModelError)
1160+
require.True(t, ok, "Expected error to be of type *replicate.ModelError")
1161+
assert.Equal(t, "model error: Model execution failed", modelErr.Error())
1162+
assert.Equal(t, "fynndufawqhdngldkgtslldrkq", modelErr.Prediction.ID)
1163+
assert.Equal(t, replicate.Failed, modelErr.Prediction.Status)
1164+
assert.Equal(t, "Model execution failed", modelErr.Prediction.Error)
1165+
assert.Equal(t, "Could not say hello", *modelErr.Prediction.Logs)
1166+
}
1167+
10781168
func TestCreateTraining(t *testing.T) {
10791169
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
10801170
assert.Equal(t, http.MethodPost, r.Method)

apierror.go renamed to error.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (e APIError) Error() string {
5555

5656
output := strings.Join(components, ": ")
5757
if output == "" {
58-
output = "Unknown error"
58+
output = "unknown error"
5959
}
6060

6161
if e.Instance != "" {
@@ -78,3 +78,16 @@ func (e *APIError) WriteHTTPResponse(w http.ResponseWriter) {
7878
http.Error(w, err.Error(), http.StatusInternalServerError)
7979
}
8080
}
81+
82+
// ModelError represents an error returned by a model for a failed prediction.
83+
type ModelError struct {
84+
Prediction *Prediction `json:"prediction"`
85+
}
86+
87+
func (e *ModelError) Error() string {
88+
if e.Prediction == nil {
89+
return "unknown model error"
90+
}
91+
92+
return fmt.Sprintf("model error: %s", e.Prediction.Error)
93+
}

run.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ func (r *Client) Run(ctx context.Context, identifier string, input PredictionInp
2121
}
2222

2323
err = r.Wait(ctx, prediction)
24+
if err != nil {
25+
return nil, err
26+
}
27+
28+
if prediction.Error != nil {
29+
return nil, &ModelError{Prediction: prediction}
30+
}
2431

25-
return prediction.Output, err
32+
return prediction.Output, nil
2633
}

0 commit comments

Comments
 (0)