@@ -1075,6 +1075,96 @@ func TestWaitAsync(t *testing.T) {
1075
1075
assert .Equal (t , replicate .Succeeded , lastStatus )
1076
1076
}
1077
1077
1078
+ func TestRun (t * testing.T ) {
1079
+ mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1080
+ switch r .URL .Path {
1081
+ case "/predictions" :
1082
+ assert .Equal (t , http .MethodPost , r .Method )
1083
+ prediction := replicate.Prediction {
1084
+ ID : "gtsllfynndufawqhdngldkdrkq" ,
1085
+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1086
+ Status : replicate .Starting ,
1087
+ }
1088
+ json .NewEncoder (w ).Encode (prediction )
1089
+ case "/predictions/gtsllfynndufawqhdngldkdrkq" :
1090
+ assert .Equal (t , http .MethodGet , r .Method )
1091
+ prediction := replicate.Prediction {
1092
+ ID : "gtsllfynndufawqhdngldkdrkq" ,
1093
+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1094
+ Status : replicate .Succeeded ,
1095
+ Output : "Hello, world!" ,
1096
+ }
1097
+ json .NewEncoder (w ).Encode (prediction )
1098
+ default :
1099
+ t .Fatalf ("Unexpected request to %s" , r .URL .Path )
1100
+ }
1101
+ }))
1102
+ defer mockServer .Close ()
1103
+
1104
+ client , err := replicate .NewClient (
1105
+ replicate .WithToken ("test-token" ),
1106
+ replicate .WithBaseURL (mockServer .URL ),
1107
+ )
1108
+ require .NoError (t , err )
1109
+
1110
+ ctx := context .Background ()
1111
+ input := replicate.PredictionInput {"prompt" : "Hello" }
1112
+ output , err := client .Run (ctx , "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" , input , nil )
1113
+
1114
+ require .NoError (t , err )
1115
+ assert .NotNil (t , output )
1116
+ assert .Equal (t , "Hello, world!" , output )
1117
+ }
1118
+
1119
+ func TestRunReturningModelError (t * testing.T ) {
1120
+ mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1121
+ switch r .URL .Path {
1122
+ case "/predictions" :
1123
+ assert .Equal (t , http .MethodPost , r .Method )
1124
+ prediction := replicate.Prediction {
1125
+ ID : "fynndufawqhdngldkgtslldrkq" ,
1126
+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1127
+ Status : replicate .Starting ,
1128
+ }
1129
+ json .NewEncoder (w ).Encode (prediction )
1130
+ case "/predictions/fynndufawqhdngldkgtslldrkq" :
1131
+ assert .Equal (t , http .MethodGet , r .Method )
1132
+
1133
+ logs := "Could not say hello"
1134
+ prediction := replicate.Prediction {
1135
+ ID : "fynndufawqhdngldkgtslldrkq" ,
1136
+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1137
+ Status : replicate .Failed ,
1138
+ Logs : & logs ,
1139
+ Error : "Model execution failed" ,
1140
+ }
1141
+ json .NewEncoder (w ).Encode (prediction )
1142
+ default :
1143
+ t .Fatalf ("Unexpected request to %s" , r .URL .Path )
1144
+ }
1145
+ }))
1146
+ defer mockServer .Close ()
1147
+
1148
+ client , err := replicate .NewClient (
1149
+ replicate .WithToken ("test-token" ),
1150
+ replicate .WithBaseURL (mockServer .URL ),
1151
+ )
1152
+ require .NoError (t , err )
1153
+
1154
+ ctx := context .Background ()
1155
+ input := replicate.PredictionInput {"prompt" : "Hello" }
1156
+ _ , err = client .Run (ctx , "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" , input , nil )
1157
+
1158
+ require .Error (t , err )
1159
+ modelErr , ok := err .(* replicate.ModelError )
1160
+ require .True (t , ok , "Expected error to be of type *replicate.ModelError" )
1161
+ assert .Equal (t , "model error: Model execution failed" , modelErr .Error ())
1162
+ assert .Equal (t , "fynndufawqhdngldkgtslldrkq" , modelErr .Prediction .ID )
1163
+ assert .Equal (t , replicate .Failed , modelErr .Prediction .Status )
1164
+ assert .Equal (t , "Model execution failed" , modelErr .Prediction .Error )
1165
+ assert .Equal (t , "Could not say hello" , * modelErr .Prediction .Logs )
1166
+ }
1167
+
1078
1168
func TestCreateTraining (t * testing.T ) {
1079
1169
mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1080
1170
assert .Equal (t , http .MethodPost , r .Method )
0 commit comments