Skip to content

Commit bf16f10

Browse files
authored
fix: check request origin (#838)
1 parent ab0a10a commit bf16f10

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

internal/mcp/llm_binding.go

+36-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,13 @@ func (m *McpLLMBinding) HandleSseServer() error {
141141
m.mutex.Unlock()
142142
}()
143143

144-
err := m.sseServer.Start(m.baseURL.Host)
144+
srv := &http.Server{
145+
Addr: m.baseURL.Host,
146+
Handler: middleware(m.sseServer),
147+
}
148+
149+
err := srv.ListenAndServe()
150+
145151
if err != nil {
146152
// expect http.ErrServerClosed when shutting down
147153
if !errors.Is(err, http.ErrServerClosed) {
@@ -152,6 +158,35 @@ func (m *McpLLMBinding) HandleSseServer() error {
152158
return nil
153159
}
154160

161+
var allowedHostnames = map[string]bool{
162+
"localhost": true,
163+
"127.0.0.1": true,
164+
"::1": true,
165+
}
166+
167+
func middleware(sseServer *server.SSEServer) http.Handler {
168+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
169+
originHeader := r.Header.Get("Origin")
170+
isValidOrigin := originHeader == ""
171+
172+
if originHeader != "" {
173+
parsedOrigin, err := url.Parse(originHeader)
174+
if err == nil {
175+
requestHost := parsedOrigin.Hostname()
176+
if _, allowed := allowedHostnames[requestHost]; allowed {
177+
isValidOrigin = true
178+
}
179+
}
180+
}
181+
182+
if isValidOrigin {
183+
sseServer.ServeHTTP(w, r)
184+
} else {
185+
http.Error(w, "Forbidden: Access restricted to localhost origins", http.StatusForbidden)
186+
}
187+
})
188+
}
189+
155190
func (m *McpLLMBinding) Shutdown(ctx context.Context) {
156191
m.mutex.Lock()
157192
defer m.mutex.Unlock()

0 commit comments

Comments
 (0)