Skip to content

Commit 8ca633b

Browse files
committed
feat: add some handy middleware for visibility
1 parent 84a62ce commit 8ca633b

File tree

9 files changed

+1220
-411
lines changed

9 files changed

+1220
-411
lines changed

internal/commands/http.go

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,19 @@ func (c *HTTPCmd) Run(ctx context.Context, globals *Globals) error {
5353
mux := http.NewServeMux()
5454
srv := newServerWithTimeouts(mux)
5555

56+
// Build middleware chain
57+
chain := middleware.NewChain().
58+
Use(middleware.ClientIP(c.TrustProxy)).
59+
Use(middleware.RequestLog()).
60+
UseIf(c.AuthToken != "", middleware.Auth(c.AuthToken))
61+
5662
var handler http.Handler
5763
if c.UseSSE {
58-
handler = mcpserver.NewSSEServer(mcpServer)
59-
60-
// Apply middleware in order: ClientIP first (extracts IP into context), then Auth
61-
handler = middleware.ClientIP(c.TrustProxy)(handler)
62-
if c.AuthToken != "" {
63-
handler = middleware.Auth(c.AuthToken)(handler)
64-
}
65-
64+
handler = chain.Then(mcpserver.NewSSEServer(mcpServer))
6665
mux.Handle("/sse", handler)
6766
logEvent.Str("transport", "sse").Str("endpoint", fmt.Sprintf("http://%s/sse", listener.Addr())).Msg("Starting SSE HTTP server")
6867
} else {
69-
handler = mcpserver.NewStreamableHTTPServer(mcpServer)
70-
71-
// Apply middleware in order: ClientIP first (extracts IP into context), then Auth
72-
handler = middleware.ClientIP(c.TrustProxy)(handler)
73-
if c.AuthToken != "" {
74-
handler = middleware.Auth(c.AuthToken)(handler)
75-
}
76-
68+
handler = chain.Then(mcpserver.NewStreamableHTTPServer(mcpServer))
7769
mux.Handle("/mcp", handler)
7870
logEvent.Str("transport", "streamable-http").Str("endpoint", fmt.Sprintf("http://%s/mcp", listener.Addr())).Msg("Starting Streamable HTTP server")
7971
}

pkg/middleware/auth.go

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,13 @@
11
package middleware
22

33
import (
4-
"context"
54
"crypto/hmac"
65
"net/http"
76
"strings"
87

98
"github.com/rs/zerolog/log"
109
)
1110

12-
// contextKey is a private type for context keys to avoid collisions
13-
type contextKey string
14-
15-
const (
16-
// clientIPKey is the context key for storing the client IP address
17-
clientIPKey contextKey = "client_ip"
18-
)
19-
20-
// ClientIP creates an HTTP middleware that extracts the real client IP address
21-
// and injects it into the request context. This should be the first middleware
22-
// in the chain to ensure all subsequent middlewares and handlers can access it.
23-
//
24-
// When trustProxy is false, it uses r.RemoteAddr directly.
25-
// When trustProxy is true, it checks proxy headers in priority order.
26-
func ClientIP(trustProxy bool) func(http.Handler) http.Handler {
27-
return func(next http.Handler) http.Handler {
28-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29-
clientIP := getClientIP(r, trustProxy)
30-
ctx := context.WithValue(r.Context(), clientIPKey, clientIP)
31-
next.ServeHTTP(w, r.WithContext(ctx))
32-
})
33-
}
34-
}
35-
36-
// GetClientIPFromContext extracts the client IP from the request context.
37-
// Returns an empty string if the IP is not found in the context.
38-
// This should be used after the ClientIP middleware has run.
39-
func GetClientIPFromContext(ctx context.Context) string {
40-
if ip, ok := ctx.Value(clientIPKey).(string); ok {
41-
return ip
42-
}
43-
return ""
44-
}
45-
4611
// Auth creates an HTTP middleware that validates Bearer token authentication.
4712
// It uses constant-time comparison to prevent timing attacks.
4813
// The client IP for logging is read from the request context (set by ClientIP middleware).
@@ -70,49 +35,3 @@ func Auth(token string) func(http.Handler) http.Handler {
7035
})
7136
}
7237
}
73-
74-
// getClientIP extracts the real client IP from the request, checking multiple proxy headers.
75-
// This is an internal helper function. Use ClientIP middleware and GetClientIPFromContext instead.
76-
//
77-
// When trustProxy is false, it returns r.RemoteAddr directly without checking proxy headers.
78-
// When trustProxy is true, it checks headers in priority order:
79-
// - CF-Connecting-IP: Cloudflare
80-
// - True-Client-IP: Akamai and Cloudflare Enterprise
81-
// - X-Real-IP: Nginx proxy/FastCGI
82-
// - X-Forwarded-For: Standard proxy header (takes first IP from comma-separated list)
83-
// - X-Client-IP: Apache and others
84-
//
85-
// Security Warning: Only enable trustProxy when behind a trusted reverse proxy that
86-
// properly sets these headers. Proxy headers can be spoofed if the application is
87-
// directly exposed to the internet.
88-
func getClientIP(r *http.Request, trustProxy bool) string {
89-
if !trustProxy {
90-
return r.RemoteAddr
91-
}
92-
93-
// Priority order of headers to check
94-
headers := []string{
95-
"CF-Connecting-IP", // Cloudflare
96-
"True-Client-IP", // Akamai and Cloudflare Enterprise
97-
"X-Real-IP", // Nginx proxy/FastCGI
98-
"X-Forwarded-For", // Standard proxy header
99-
"X-Client-IP", // Apache, others
100-
}
101-
102-
for _, header := range headers {
103-
if ip := r.Header.Get(header); ip != "" {
104-
// For X-Forwarded-For, take the first IP (original client)
105-
// Format: X-Forwarded-For: client, proxy1, proxy2
106-
if header == "X-Forwarded-For" {
107-
ips := strings.Split(ip, ",")
108-
if len(ips) > 0 {
109-
return strings.TrimSpace(ips[0])
110-
}
111-
}
112-
return ip
113-
}
114-
}
115-
116-
// Fall back to RemoteAddr
117-
return r.RemoteAddr
118-
}

0 commit comments

Comments
 (0)