Skip to content

Commit 18c1346

Browse files
committed
Add Access-Control-Allow-Origin CORS header to /v1/models endpoint
- match behavior of llama.cpp where the Origin in request is used - add test for listModelsHandler
1 parent da2326b commit 18c1346

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

proxy/proxymanager.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
9898
// Set the Content-Type header to application/json
9999
c.Header("Content-Type", "application/json")
100100

101+
if origin := c.Request.Header.Get("Origin"); origin != "" {
102+
c.Header("Access-Control-Allow-Origin", origin)
103+
}
104+
101105
// Encode the data as JSON and write it to the response writer
102106
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
103107
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))

proxy/proxymanager_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxy
22

33
import (
44
"bytes"
5+
"encoding/json"
56
"fmt"
67
"net/http"
78
"net/http/httptest"
@@ -141,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
141142
assert.Equal(t, key, result)
142143
}
143144
}
145+
146+
func TestProxyManager_ListModelsHandler(t *testing.T) {
147+
config := &Config{
148+
HealthCheckTimeout: 15,
149+
Models: map[string]ModelConfig{
150+
"model1": getTestSimpleResponderConfig("model1"),
151+
"model2": getTestSimpleResponderConfig("model2"),
152+
"model3": getTestSimpleResponderConfig("model3"),
153+
},
154+
}
155+
156+
proxy := New(config)
157+
158+
// Create a test request
159+
req := httptest.NewRequest("GET", "/v1/models", nil)
160+
req.Header.Add("Origin", "i-am-the-origin")
161+
w := httptest.NewRecorder()
162+
163+
// Call the listModelsHandler
164+
proxy.HandlerFunc(w, req)
165+
166+
// Check the response status code
167+
assert.Equal(t, http.StatusOK, w.Code)
168+
169+
// Check for Access-Control-Allow-Origin
170+
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
171+
172+
// Parse the JSON response
173+
var response struct {
174+
Data []map[string]interface{} `json:"data"`
175+
}
176+
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
177+
t.Fatalf("Failed to parse JSON response: %v", err)
178+
}
179+
180+
// Check the number of models returned
181+
assert.Len(t, response.Data, 3)
182+
183+
// Check the details of each model
184+
expectedModels := map[string]struct{}{
185+
"model1": {},
186+
"model2": {},
187+
"model3": {},
188+
}
189+
190+
for _, model := range response.Data {
191+
modelID, ok := model["id"].(string)
192+
assert.True(t, ok, "model ID should be a string")
193+
_, exists := expectedModels[modelID]
194+
assert.True(t, exists, "unexpected model ID: %s", modelID)
195+
delete(expectedModels, modelID)
196+
197+
object, ok := model["object"].(string)
198+
assert.True(t, ok, "object should be a string")
199+
assert.Equal(t, "model", object)
200+
201+
created, ok := model["created"].(float64)
202+
assert.True(t, ok, "created should be a number")
203+
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
204+
205+
ownedBy, ok := model["owned_by"].(string)
206+
assert.True(t, ok, "owned_by should be a string")
207+
assert.Equal(t, "llama-swap", ownedBy)
208+
}
209+
210+
// Ensure all expected models were returned
211+
assert.Empty(t, expectedModels, "not all expected models were returned")
212+
}

0 commit comments

Comments
 (0)