Skip to content

Commit 8aadeca

Browse files
authored
MCP wrap/proxy refactor and improvements (#4399)
Proxy: - support --instance to specify target machine - if url resolves, don't create a wireguard proxy - move passthru support to a separate file (prep for supporting replay) Wrap: - produce an error when there is an attempt to connect a second get - ensure last byte is a newline on POST
1 parent 80530e1 commit 8aadeca

File tree

4 files changed

+314
-246
lines changed

4 files changed

+314
-246
lines changed

internal/command/mcp/proxy.go

Lines changed: 40 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,19 @@
11
package mcp
22

33
import (
4-
"bufio"
5-
"bytes"
64
"context"
75
"fmt"
8-
"io"
96
"log"
107
"net"
11-
"net/http"
128
"net/url"
139
"os"
1410
"os/exec"
15-
"os/signal"
1611
"strings"
17-
"sync"
18-
"syscall"
19-
"time"
2012

2113
"github.com/spf13/cobra"
2214
"github.com/superfly/flyctl/internal/appconfig"
2315
"github.com/superfly/flyctl/internal/command"
16+
mcpProxy "github.com/superfly/flyctl/internal/command/mcp/proxy"
2417
"github.com/superfly/flyctl/internal/flag"
2518
"github.com/superfly/flyctl/internal/flag/flagnames"
2619
)
@@ -52,6 +45,10 @@ var sharedProxyFlags = flag.Set{
5245
Default: "127.0.0.1",
5346
Description: "Local address to bind to",
5447
},
48+
flag.String{
49+
Name: "instance",
50+
Description: "Use fly-force-instance-id to connect to a specific instance",
51+
},
5552
}
5653

5754
func NewProxy() *cobra.Command {
@@ -111,30 +108,25 @@ func NewInspect() *cobra.Command {
111108
return cmd
112109
}
113110

114-
type ProxyInfo struct {
115-
url string
116-
bearerToken string
117-
user string
118-
password string
119-
}
120-
121111
func runProxy(ctx context.Context) error {
122-
proxyInfo := ProxyInfo{
123-
url: flag.GetString(ctx, "url"),
124-
bearerToken: flag.GetString(ctx, "bearer-token"),
125-
user: flag.GetString(ctx, "user"),
126-
password: flag.GetString(ctx, "password"),
112+
proxyInfo := mcpProxy.ProxyInfo{
113+
Url: flag.GetString(ctx, "url"),
114+
BearerToken: flag.GetString(ctx, "bearer-token"),
115+
User: flag.GetString(ctx, "user"),
116+
Password: flag.GetString(ctx, "password"),
117+
Instance: flag.GetString(ctx, "instance"),
127118
}
128119

129120
return runProxyOrInspect(ctx, proxyInfo, flag.GetBool(ctx, "inspector"))
130121
}
131122

132123
func runInspect(ctx context.Context) error {
133-
proxyInfo := ProxyInfo{
134-
url: flag.GetString(ctx, "url"),
135-
bearerToken: flag.GetString(ctx, "bearer-token"),
136-
user: flag.GetString(ctx, "user"),
137-
password: flag.GetString(ctx, "password"),
124+
proxyInfo := mcpProxy.ProxyInfo{
125+
Url: flag.GetString(ctx, "url"),
126+
BearerToken: flag.GetString(ctx, "bearer-token"),
127+
User: flag.GetString(ctx, "user"),
128+
Password: flag.GetString(ctx, "password"),
129+
Instance: flag.GetString(ctx, "instance"),
138130
}
139131

140132
server := flag.GetString(ctx, "server")
@@ -150,17 +142,17 @@ func runInspect(ctx context.Context) error {
150142
return err
151143
}
152144

153-
if proxyInfo.url == "" {
154-
proxyInfo.url, _ = mcpConfig["url"].(string)
145+
if proxyInfo.Url == "" {
146+
proxyInfo.Url, _ = mcpConfig["url"].(string)
155147
}
156-
if proxyInfo.bearerToken == "" {
157-
proxyInfo.bearerToken, _ = mcpConfig["bearer-token"].(string)
148+
if proxyInfo.BearerToken == "" {
149+
proxyInfo.BearerToken, _ = mcpConfig["bearer-token"].(string)
158150
}
159-
if proxyInfo.user == "" {
160-
proxyInfo.user, _ = mcpConfig["user"].(string)
151+
if proxyInfo.User == "" {
152+
proxyInfo.User, _ = mcpConfig["user"].(string)
161153
}
162-
if proxyInfo.password == "" {
163-
proxyInfo.password, _ = mcpConfig["password"].(string)
154+
if proxyInfo.Password == "" {
155+
proxyInfo.Password, _ = mcpConfig["password"].(string)
164156
}
165157
} else if len(configPaths) > 1 {
166158
return fmt.Errorf("multiple MCP client configuration files specifed. Please specify at most one")
@@ -169,21 +161,21 @@ func runInspect(ctx context.Context) error {
169161
return runProxyOrInspect(ctx, proxyInfo, true)
170162
}
171163

172-
func runProxyOrInspect(ctx context.Context, proxyInfo ProxyInfo, inspect bool) error {
164+
func runProxyOrInspect(ctx context.Context, proxyInfo mcpProxy.ProxyInfo, inspect bool) error {
173165

174166
// If no URL is provided, try to get it from the app config
175167
// If that fails, return an error
176-
if proxyInfo.url == "" {
168+
if proxyInfo.Url == "" {
177169
appConfig := appconfig.ConfigFromContext(ctx)
178170

179171
if appConfig != nil {
180172
appUrl := appConfig.URL()
181173
if appUrl != nil {
182-
proxyInfo.url = appUrl.String()
174+
proxyInfo.Url = appUrl.String()
183175
}
184176
}
185177

186-
if proxyInfo.url == "" {
178+
if proxyInfo.Url == "" {
187179
log.Fatal("The app config could not be found and no URL was provided")
188180
}
189181
}
@@ -194,16 +186,16 @@ func runProxyOrInspect(ctx context.Context, proxyInfo ProxyInfo, inspect bool) e
194186
return fmt.Errorf("failed to find executable: %w", err)
195187
}
196188

197-
args := []string{"@modelcontextprotocol/inspector@latest", flyctl, "mcp", "proxy", "--url", proxyInfo.url}
189+
args := []string{"@modelcontextprotocol/inspector@latest", flyctl, "mcp", "proxy", "--url", proxyInfo.Url}
198190

199-
if proxyInfo.bearerToken != "" {
200-
args = append(args, "--bearer-token", proxyInfo.bearerToken)
191+
if proxyInfo.BearerToken != "" {
192+
args = append(args, "--bearer-token", proxyInfo.BearerToken)
201193
}
202-
if proxyInfo.user != "" {
203-
args = append(args, "--user", proxyInfo.user)
194+
if proxyInfo.User != "" {
195+
args = append(args, "--user", proxyInfo.User)
204196
}
205-
if proxyInfo.password != "" {
206-
args = append(args, "--password", proxyInfo.password)
197+
if proxyInfo.Password != "" {
198+
args = append(args, "--password", proxyInfo.Password)
207199
}
208200

209201
// Launch MCP inspector
@@ -217,51 +209,19 @@ func runProxyOrInspect(ctx context.Context, proxyInfo ProxyInfo, inspect bool) e
217209
return nil
218210
}
219211

220-
url, proxyCmd, err := resolveProxy(ctx, proxyInfo.url)
212+
url, proxyCmd, err := resolveProxy(ctx, proxyInfo.Url)
221213
if err != nil {
222214
log.Fatalf("Error resolving proxy URL: %v", err)
223215
}
224216

225-
proxyInfo.url = url
217+
proxyInfo.Url = url
226218

227219
// Configure logging to go to stderr only
228220
log.SetOutput(os.Stderr)
229221

230-
err = waitForServer(ctx, proxyInfo)
222+
err = mcpProxy.Passthru(ctx, proxyInfo)
231223
if err != nil {
232-
log.Fatalf("Error waiting for server: %v", err)
233-
}
234-
235-
// Store whether the SSE connection is ready
236-
// This may become unready if the connection is closed
237-
ready := false
238-
readyMutex := sync.Mutex{}
239-
readyCond := sync.NewCond(&readyMutex)
240-
241-
// Start the HTTP client
242-
go func() {
243-
start := time.Now()
244-
for {
245-
getFromServer(ctx, proxyInfo, &ready, readyCond)
246-
247-
// Ready should be set to false when the connection is closed
248-
readyCond.L.Lock()
249-
ready = false
250-
readyCond.Broadcast()
251-
readyCond.L.Unlock()
252-
253-
// Wait a minimum of 10 seconds before the next request
254-
elapsed := time.Since(start)
255-
if elapsed < 10*time.Second {
256-
time.Sleep(10*time.Second - elapsed)
257-
}
258-
start = time.Now()
259-
}
260-
}()
261-
262-
// Start processing stdin
263-
if err := processStdin(ctx, proxyInfo, &ready, readyCond); err != nil {
264-
log.Fatalf("Error processing stdin: %v", err)
224+
log.Fatal(err)
265225
}
266226

267227
// Kill the proxy process if it was started
@@ -275,73 +235,6 @@ func runProxyOrInspect(ctx context.Context, proxyInfo ProxyInfo, inspect bool) e
275235
return nil
276236
}
277237

278-
// waitForServer waits for the server to be up and running
279-
func waitForServer(ctx context.Context, proxyInfo ProxyInfo) error {
280-
// Continue to post nothing until the server is up
281-
delay := 100 * time.Millisecond
282-
var err error
283-
for delay < 60*time.Second {
284-
err = sendToServer(ctx, "", proxyInfo)
285-
286-
if err == nil {
287-
break
288-
} else if !strings.Contains(err.Error(), "connection refused") {
289-
log.Printf("Error sending message to server: %v", err)
290-
break
291-
}
292-
293-
time.Sleep(delay)
294-
delay *= 2
295-
}
296-
297-
return err
298-
}
299-
300-
// ProcessStdin reads messages from stdin and forwards them to the server
301-
func processStdin(ctx context.Context, proxyInfo ProxyInfo, ready *bool, readyCond *sync.Cond) error {
302-
stp := make(chan os.Signal, 1)
303-
signal.Notify(stp, syscall.SIGINT, syscall.SIGTERM)
304-
go func() {
305-
<-stp
306-
os.Exit(0)
307-
}()
308-
309-
scanner := bufio.NewScanner(os.Stdin)
310-
for scanner.Scan() {
311-
line := scanner.Text() + "\n"
312-
313-
// Skip empty lines
314-
if strings.TrimSpace(line) == "" {
315-
continue
316-
}
317-
318-
// Wait for the server to be ready
319-
readyCond.L.Lock()
320-
for !*ready {
321-
readyCond.Wait()
322-
}
323-
readyCond.L.Unlock()
324-
325-
// Forward raw message to server
326-
err := sendToServer(ctx, line, proxyInfo)
327-
if err != nil {
328-
// Log error but continue processing
329-
log.Printf("Error sending message to server: %v", err)
330-
// We could format an error message here, but since we're operating at the raw string level,
331-
// we'll return a generic error JSON
332-
errMsg := fmt.Sprintf(`{"type":"error","content":"Failed to send to server: %v"}`, err)
333-
fmt.Fprintln(os.Stdout, errMsg)
334-
continue
335-
}
336-
}
337-
338-
if err := scanner.Err(); err != nil {
339-
return fmt.Errorf("error reading from stdin: %w", err)
340-
}
341-
342-
return nil
343-
}
344-
345238
// resolveProxy starts the proxy process and returns the new URL
346239
func resolveProxy(ctx context.Context, originalUrl string) (string, *exec.Cmd, error) {
347240
appName := flag.GetString(ctx, "app")
@@ -431,88 +324,3 @@ func getAvailablePort() (int, error) {
431324

432325
return listener.Addr().(*net.TCPAddr).Port, nil
433326
}
434-
435-
// getFromServer sends a GET request to the server and streams the response to stdout
436-
func getFromServer(ctx context.Context, proxyInfo ProxyInfo, ready *bool, readyCond *sync.Cond) error {
437-
// Create HTTP request
438-
req, err := http.NewRequest("GET", proxyInfo.url, nil)
439-
if err != nil {
440-
return fmt.Errorf("error creating request: %w", err)
441-
}
442-
req.Header.Set("User-Agent", "mcp-bridge-client")
443-
req.Header.Set("Accept", "application/json")
444-
445-
// Set basic authentication if bearer token or user is provided
446-
if proxyInfo.bearerToken != "" {
447-
req.Header.Set("Authorization", "Bearer "+proxyInfo.bearerToken)
448-
} else if proxyInfo.user != "" {
449-
req.SetBasicAuth(proxyInfo.user, proxyInfo.password)
450-
}
451-
452-
// Send request
453-
client := &http.Client{}
454-
resp, err := client.Do(req)
455-
if err != nil {
456-
return fmt.Errorf("error sending request: %w", err)
457-
}
458-
defer resp.Body.Close()
459-
460-
// Check response status
461-
if resp.StatusCode != http.StatusOK {
462-
return fmt.Errorf("server returned error: %s (status %d)", resp.Status, resp.StatusCode)
463-
}
464-
465-
// We're now ready to receive messages
466-
readyCond.L.Lock()
467-
*ready = true
468-
readyCond.Broadcast()
469-
readyCond.L.Unlock()
470-
471-
// Stream response body to stdout
472-
if _, err := io.Copy(os.Stdout, resp.Body); err != nil {
473-
return fmt.Errorf("error streaming response to stdout: %w", err)
474-
}
475-
476-
return nil
477-
}
478-
479-
// SendToServer sends a raw message to the server and returns the raw response
480-
func sendToServer(ctx context.Context, message string, proxyInfo ProxyInfo) error {
481-
// Create HTTP request with raw message
482-
req, err := http.NewRequest("POST", proxyInfo.url, bytes.NewBufferString(message))
483-
if err != nil {
484-
return fmt.Errorf("error creating request: %w", err)
485-
}
486-
req.Header.Set("Content-Type", "application/json")
487-
req.Header.Set("User-Agent", "mcp-bridge-client")
488-
req.Header.Set("Accept", "application/json, text/event-stream")
489-
490-
// Set basic authentication if bearer token or user is provided
491-
if proxyInfo.bearerToken != "" {
492-
req.Header.Set("Authorization", "Bearer "+proxyInfo.bearerToken)
493-
} else if proxyInfo.user != "" {
494-
req.SetBasicAuth(proxyInfo.user, proxyInfo.password)
495-
}
496-
497-
// Send request
498-
client := &http.Client{}
499-
resp, err := client.Do(req)
500-
if err != nil {
501-
return fmt.Errorf("error sending request: %w", err)
502-
}
503-
defer resp.Body.Close()
504-
505-
// Check response status
506-
if resp.StatusCode != http.StatusAccepted {
507-
// Read response body
508-
body, err := io.ReadAll(resp.Body)
509-
if err != nil {
510-
return fmt.Errorf("error reading response: %w", err)
511-
}
512-
513-
return fmt.Errorf("server returned error: %s (status %d)", body, resp.StatusCode)
514-
}
515-
516-
// Request was accepted
517-
return nil
518-
}

0 commit comments

Comments
 (0)