diff --git a/cmd/thv/app/config.go b/cmd/thv/app/config.go index 9448514..c5f7790 100644 --- a/cmd/thv/app/config.go +++ b/cmd/thv/app/config.go @@ -13,6 +13,7 @@ import ( "github.com/stacklok/toolhive/pkg/labels" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" + "github.com/stacklok/toolhive/pkg/transport" ) var configCmd = &cobra.Command{ @@ -262,7 +263,7 @@ func addRunningMCPsToClient(ctx context.Context, clientName string) error { } // Generate URL for the MCP server - url := client.GenerateMCPServerURL("localhost", port, name) + url := client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name) // Update each configuration file for _, clientConfig := range clientConfigs { diff --git a/cmd/thv/app/list.go b/cmd/thv/app/list.go index 4662ec4..64d39a0 100644 --- a/cmd/thv/app/list.go +++ b/cmd/thv/app/list.go @@ -13,6 +13,7 @@ import ( "github.com/stacklok/toolhive/pkg/labels" "github.com/stacklok/toolhive/pkg/lifecycle" "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport" ) var listCmd = &cobra.Command{ @@ -28,10 +29,7 @@ var ( ) // Constants for list command -const ( - defaultHost = "localhost" - unknownTransport = "unknown" -) +const unknownTransport = "unknown" // ContainerOutput represents container information for JSON output type ContainerOutput struct { @@ -73,7 +71,7 @@ func listCmdFunc(cmd *cobra.Command, _ []string) error { // Output based on format switch listFormat { //nolint:goconst - case "json": + case FormatJSON: return printJSONOutput(toolHiveContainers) case "mcpservers": return printMCPServersOutput(toolHiveContainers) @@ -101,9 +99,9 @@ func printJSONOutput(containers []rt.ContainerInfo) error { } // Get transport type from labels - transport := labels.GetTransportType(c.Labels) - if transport == "" { - transport = unknownTransport + t := labels.GetTransportType(c.Labels) + if t == "" { + t = unknownTransport } // Get tool type from labels @@ -118,7 +116,7 @@ func printJSONOutput(containers []rt.ContainerInfo) error { // Generate URL for the MCP server url := "" if port > 0 { - url = client.GenerateMCPServerURL(defaultHost, port, name) + url = client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name) } output = append(output, ContainerOutput{ @@ -126,7 +124,7 @@ func printJSONOutput(containers []rt.ContainerInfo) error { Name: name, Image: c.Image, State: c.State, - Transport: transport, + Transport: t, ToolType: toolType, Port: port, URL: url, @@ -174,7 +172,7 @@ func printMCPServersOutput(containers []rt.ContainerInfo) error { // Generate URL for the MCP server url := "" if port > 0 { - url = client.GenerateMCPServerURL(defaultHost, port, name) + url = client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name) } // Add the MCP server to the map @@ -217,9 +215,9 @@ func printTextOutput(containers []rt.ContainerInfo) { } // Get transport type from labels - transport := labels.GetTransportType(c.Labels) - if transport == "" { - transport = unknownTransport + t := labels.GetTransportType(c.Labels) + if t == "" { + t = unknownTransport } // Get port from labels @@ -231,7 +229,7 @@ func printTextOutput(containers []rt.ContainerInfo) { // Generate URL for the MCP server url := "" if port > 0 { - url = client.GenerateMCPServerURL(defaultHost, port, name) + url = client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name) } // Print container information @@ -240,7 +238,7 @@ func printTextOutput(containers []rt.ContainerInfo) { name, c.Image, c.State, - transport, + t, port, url, ) diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index 97c2e38..390c1f4 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -11,6 +11,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/transport" "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -25,11 +26,13 @@ This command creates a standalone proxy without starting a container.`, } var ( + proxyHost string proxyPort int proxyTargetURI string ) func init() { + proxyCmd.Flags().StringVar(&proxyHost, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)") proxyCmd.Flags().IntVar(&proxyPort, "port", 0, "Port for the HTTP proxy to listen on (host port)") proxyCmd.Flags().StringVar( &proxyTargetURI, @@ -52,6 +55,13 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { // Get the server name serverName := args[0] + // Validate the host flag and default resolving to IP in case hostname is provided + validatedHost, err := ValidateAndNormaliseHostFlag(proxyHost) + if err != nil { + return fmt.Errorf("invalid host: %s", proxyHost) + } + proxyHost = validatedHost + // Select a port for the HTTP proxy (host port) port, err := networking.FindOrUsePort(proxyPort) if err != nil { @@ -94,7 +104,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { port, proxyTargetURI) // Create the transparent proxy with middlewares - proxy := transparent.NewTransparentProxy(port, serverName, proxyTargetURI, middlewares...) + proxy := transparent.NewTransparentProxy(proxyHost, port, serverName, proxyTargetURI, middlewares...) if err := proxy.Start(ctx); err != nil { return fmt.Errorf("failed to start proxy: %v", err) } diff --git a/cmd/thv/app/registry.go b/cmd/thv/app/registry.go index 46bb917..78859a3 100644 --- a/cmd/thv/app/registry.go +++ b/cmd/thv/app/registry.go @@ -48,8 +48,8 @@ func init() { registryCmd.AddCommand(registryInfoCmd) // Add flags for list and info commands - registryListCmd.Flags().StringVar(®istryFormat, "format", "text", "Output format (json or text)") - registryInfoCmd.Flags().StringVar(®istryFormat, "format", "text", "Output format (json or text)") + registryListCmd.Flags().StringVar(®istryFormat, "format", FormatText, "Output format (json or text)") + registryInfoCmd.Flags().StringVar(®istryFormat, "format", FormatText, "Output format (json or text)") } func registryListCmdFunc(_ *cobra.Command, _ []string) error { @@ -66,7 +66,7 @@ func registryListCmdFunc(_ *cobra.Command, _ []string) error { // Output based on format switch registryFormat { - case "json": + case FormatJSON: return printJSONServers(servers) default: printTextServers(servers) @@ -84,7 +84,7 @@ func registryInfoCmdFunc(_ *cobra.Command, args []string) error { // Output based on format switch registryFormat { - case "json": + case FormatJSON: return printJSONServer(server) default: printTextServerInfo(serverName, server) diff --git a/cmd/thv/app/run.go b/cmd/thv/app/run.go index 9df2f38..4f48df1 100644 --- a/cmd/thv/app/run.go +++ b/cmd/thv/app/run.go @@ -3,6 +3,7 @@ package app import ( "context" "fmt" + "net" "os" "strings" @@ -17,6 +18,7 @@ import ( "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/runner" + "github.com/stacklok/toolhive/pkg/transport" ) var runCmd = &cobra.Command{ @@ -56,6 +58,7 @@ permission profile. Additional configuration can be provided via flags.`, var ( runTransport string runName string + runHost string runPort int runTargetPort int runTargetHost string @@ -72,12 +75,13 @@ var ( func init() { runCmd.Flags().StringVar(&runTransport, "transport", "stdio", "Transport mode (sse or stdio)") runCmd.Flags().StringVar(&runName, "name", "", "Name of the MCP server (auto-generated from image if not provided)") + runCmd.Flags().StringVar(&runHost, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)") runCmd.Flags().IntVar(&runPort, "port", 0, "Port for the HTTP proxy to listen on (host port)") runCmd.Flags().IntVar(&runTargetPort, "target-port", 0, "Port for the container to expose (only applicable to SSE transport)") runCmd.Flags().StringVar( &runTargetHost, "target-host", - "localhost", + transport.LocalhostIPv4, "Host to forward traffic to (only applicable to SSE transport)") runCmd.Flags().StringVar( &runPermissionProfile, @@ -131,6 +135,14 @@ func init() { func runCmdFunc(cmd *cobra.Command, args []string) error { ctx := cmd.Context() + + // Validate the host flag and default resolving to IP in case hostname is provided + validatedHost, err := ValidateAndNormaliseHostFlag(runHost) + if err != nil { + return fmt.Errorf("invalid host: %s", runHost) + } + runHost = validatedHost + // Get the server name or image serverOrImage := args[0] @@ -173,6 +185,7 @@ func runCmdFunc(cmd *cobra.Command, args []string) error { rt, cmdArgs, runName, + runHost, debugMode, runVolumes, runSecrets, @@ -437,3 +450,31 @@ func parseCommandArguments(args []string) []string { } return cmdArgs } + +// ValidateAndNormaliseHostFlag validates and normalizes the host flag resolving it to an IP address if hostname is provided +func ValidateAndNormaliseHostFlag(host string) (string, error) { + // Check if the host is a valid IP address + ip := net.ParseIP(host) + if ip != nil { + if ip.To4() == nil { + return "", fmt.Errorf("IPv6 addresses are not supported: %s", host) + } + return host, nil + } + + // If not an IP address, resolve the hostname to an IP address + addrs, err := net.LookupHost(host) + if err != nil { + return "", fmt.Errorf("invalid host: %s", host) + } + + // Use the first IPv4 address found + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip != nil && ip.To4() != nil { + return ip.String(), nil + } + } + + return "", fmt.Errorf("could not resolve host: %s", host) +} diff --git a/docs/cli/thv_proxy.md b/docs/cli/thv_proxy.md index 807d17b..190b65f 100644 --- a/docs/cli/thv_proxy.md +++ b/docs/cli/thv_proxy.md @@ -15,6 +15,7 @@ thv proxy [flags] SERVER_NAME ``` -h, --help help for proxy + --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") --oidc-audience string Expected audience for the token --oidc-client-id string OIDC client ID --oidc-issuer string OIDC issuer URL (e.g., https://accounts.google.com) diff --git a/docs/cli/thv_run.md b/docs/cli/thv_run.md index 1565c1b..caaffb7 100644 --- a/docs/cli/thv_run.md +++ b/docs/cli/thv_run.md @@ -40,6 +40,7 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] -e, --env stringArray Environment variables to pass to the MCP server (format: KEY=VALUE) -f, --foreground Run in foreground mode (block until container exits) -h, --help help for run + --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") --k8s-pod-patch string JSON string to patch the Kubernetes pod template (only applicable when using Kubernetes runtime) --name string Name of the MCP server (auto-generated from image if not provided) --oidc-audience string Expected audience for the token @@ -49,7 +50,7 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] --permission-profile string Permission profile to use (none, network, or path to JSON file) (default "network") --port int Port for the HTTP proxy to listen on (host port) --secret stringArray Specify a secret to be fetched from the secrets manager and set as an environment variable (format: NAME,target=TARGET) - --target-host string Host to forward traffic to (only applicable to SSE transport) (default "localhost") + --target-host string Host to forward traffic to (only applicable to SSE transport) (default "127.0.0.1") --target-port int Port for the container to expose (only applicable to SSE transport) --transport string Transport mode (sse or stdio) (default "stdio") -v, --volume stringArray Mount a volume into the container (format: host-path:container-path[:ro]) diff --git a/pkg/lifecycle/manager.go b/pkg/lifecycle/manager.go index 4b9c4b7..3a314f7 100644 --- a/pkg/lifecycle/manager.go +++ b/pkg/lifecycle/manager.go @@ -202,9 +202,14 @@ func (*defaultManager) RunContainerDetached(runConfig *runner.RunConfig) error { detachedArgs = append(detachedArgs, "--name", runConfig.ContainerName) } + if runConfig.Host != "" { + detachedArgs = append(detachedArgs, "--host", runConfig.Host) + } + if runConfig.Port != 0 { detachedArgs = append(detachedArgs, "--port", fmt.Sprintf("%d", runConfig.Port)) } + if runConfig.TargetPort != 0 { detachedArgs = append(detachedArgs, "--target-port", fmt.Sprintf("%d", runConfig.TargetPort)) } diff --git a/pkg/runner/config.go b/pkg/runner/config.go index f381c2b..db84e8e 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -41,6 +41,9 @@ type RunConfig struct { // Transport is the transport mode (sse or stdio) Transport types.TransportType `json:"transport" yaml:"transport"` + // Host is the host for the HTTP proxy + Host string `json:"host" yaml:"host"` + // Port is the port for the HTTP proxy to listen on (host port) Port int `json:"port" yaml:"port"` @@ -120,6 +123,7 @@ func NewRunConfigFromFlags( runtime rt.Runtime, cmdArgs []string, name string, + host string, debug bool, volumes []string, secretsList []string, @@ -143,6 +147,7 @@ func NewRunConfigFromFlags( TargetHost: targetHost, ContainerLabels: make(map[string]string), EnvVars: make(map[string]string), + Host: host, } // Set OIDC config if any values are provided diff --git a/pkg/runner/config_test.go b/pkg/runner/config_test.go index 43e0b11..5be95b6 100644 --- a/pkg/runner/config_test.go +++ b/pkg/runner/config_test.go @@ -804,6 +804,7 @@ func TestNewRunConfigFromFlags(t *testing.T) { runtime := &mockRuntime{} cmdArgs := []string{"arg1", "arg2"} name := "test-server" + host := "localhost" debug := true volumes := []string{"/host:/container"} secretsList := []string{"secret1,target=ENV_VAR1"} @@ -819,6 +820,7 @@ func TestNewRunConfigFromFlags(t *testing.T) { runtime, cmdArgs, name, + host, debug, volumes, secretsList, diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index b928a1d..b4616b4 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -41,7 +41,7 @@ func (r *Runner) Run(ctx context.Context) error { Type: r.Config.Transport, Port: r.Config.Port, TargetPort: r.Config.TargetPort, - Host: "localhost", + Host: r.Config.Host, TargetHost: r.Config.TargetHost, Runtime: r.Config.Runtime, Debug: r.Config.Debug, diff --git a/pkg/transport/factory.go b/pkg/transport/factory.go index adfd6fe..0983a91 100644 --- a/pkg/transport/factory.go +++ b/pkg/transport/factory.go @@ -19,7 +19,7 @@ func NewFactory() *Factory { func (*Factory) Create(config types.Config) (types.Transport, error) { switch config.Type { case types.TransportTypeStdio: - return NewStdioTransport(config.Port, config.Runtime, config.Debug, config.Middlewares...), nil + return NewStdioTransport(config.Host, config.Port, config.Runtime, config.Debug, config.Middlewares...), nil case types.TransportTypeSSE: return NewSSETransport( config.Host, diff --git a/pkg/transport/proxy/httpsse/http_proxy.go b/pkg/transport/proxy/httpsse/http_proxy.go index f9bd7a6..5ea1a4f 100644 --- a/pkg/transport/proxy/httpsse/http_proxy.go +++ b/pkg/transport/proxy/httpsse/http_proxy.go @@ -48,6 +48,7 @@ type Proxy interface { //nolint:revive // Intentionally named HTTPSSEProxy despite package name type HTTPSSEProxy struct { // Basic configuration + host string port int containerName string middlewares []types.Middleware @@ -69,9 +70,10 @@ type HTTPSSEProxy struct { } // NewHTTPSSEProxy creates a new HTTP SSE proxy for transports. -func NewHTTPSSEProxy(port int, containerName string, middlewares ...types.Middleware) *HTTPSSEProxy { +func NewHTTPSSEProxy(host string, port int, containerName string, middlewares ...types.Middleware) *HTTPSSEProxy { return &HTTPSSEProxy{ middlewares: middlewares, + host: host, port: port, containerName: containerName, shutdownCh: make(chan struct{}), @@ -109,7 +111,7 @@ func (p *HTTPSSEProxy) Start(_ context.Context) error { // Create the server p.server = &http.Server{ - Addr: fmt.Sprintf(":%d", p.port), + Addr: fmt.Sprintf("%s:%d", p.host, p.port), Handler: mux, ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks } @@ -117,8 +119,8 @@ func (p *HTTPSSEProxy) Start(_ context.Context) error { // Start the server in a goroutine go func() { logger.Infof("HTTP proxy started for container %s on port %d", p.containerName, p.port) - logger.Infof("SSE endpoint: http://localhost:%d%s", p.port, ssecommon.HTTPSSEEndpoint) - logger.Infof("JSON-RPC endpoint: http://localhost:%d%s", p.port, ssecommon.HTTPMessagesEndpoint) + logger.Infof("SSE endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPSSEEndpoint) + logger.Infof("JSON-RPC endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPMessagesEndpoint) if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Errorf("HTTP server error: %v", err) diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 8dc8268..c945778 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -24,6 +24,7 @@ import ( //nolint:revive // Intentionally named TransparentProxy despite package name type TransparentProxy struct { // Basic configuration + host string port int containerName string targetURI string @@ -43,12 +44,14 @@ type TransparentProxy struct { // NewTransparentProxy creates a new transparent proxy with optional middlewares. func NewTransparentProxy( + host string, port int, containerName string, targetURI string, middlewares ...types.Middleware, ) *TransparentProxy { return &TransparentProxy{ + host: host, port: port, containerName: containerName, targetURI: targetURI, @@ -86,15 +89,15 @@ func (p *TransparentProxy) Start(_ context.Context) error { // Create the server p.server = &http.Server{ - Addr: fmt.Sprintf(":%d", p.port), + Addr: fmt.Sprintf("%s:%d", p.host, p.port), Handler: finalHandler, ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks } // Start the server in a goroutine go func() { - logger.Infof("Transparent proxy started for container %s on port %d -> %s", - p.containerName, p.port, p.targetURI) + logger.Infof("Transparent proxy started for container %s on %s:%d -> %s", + p.containerName, p.host, p.port, p.targetURI) if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Errorf("Transparent proxy error: %v", err) diff --git a/pkg/transport/sse.go b/pkg/transport/sse.go index 293b5b0..f89316a 100644 --- a/pkg/transport/sse.go +++ b/pkg/transport/sse.go @@ -8,7 +8,6 @@ import ( "github.com/stacklok/toolhive/pkg/container" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/transport/errors" "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" @@ -18,6 +17,8 @@ import ( const ( // LocalhostName is the standard hostname for localhost LocalhostName = "localhost" + // LocalhostIPv4 is the standard IPv4 address for localhost + LocalhostIPv4 = "127.0.0.1" ) // SSETransport implements the Transport interface using Server-Sent Events. @@ -57,12 +58,12 @@ func NewSSETransport( middlewares ...types.Middleware, ) *SSETransport { if host == "" { - host = LocalhostName + host = LocalhostIPv4 } // If targetHost is not specified, default to localhost if targetHost == "" { - targetHost = LocalhostName + targetHost = LocalhostIPv4 } return &SSETransport{ @@ -112,21 +113,21 @@ func (t *SSETransport) Setup(ctx context.Context, runtime rt.Runtime, containerN containerPortStr := fmt.Sprintf("%d/tcp", t.targetPort) containerOptions.ExposedPorts[containerPortStr] = struct{}{} - // Create port bindings for localhost + // Create host port bindings (configurable through the --host flag) portBindings := []rt.PortBinding{ { - HostIP: "127.0.0.1", // IPv4 localhost + HostIP: t.host, HostPort: fmt.Sprintf("%d", t.targetPort), }, } - // Check if IPv6 is available and add IPv6 localhost binding - if networking.IsIPv6Available() { - portBindings = append(portBindings, rt.PortBinding{ - HostIP: "::1", // IPv6 localhost - HostPort: fmt.Sprintf("%d", t.targetPort), - }) - } + // Check if IPv6 is available and add IPv6 localhost binding (commented out for now) + //if networking.IsIPv6Available() { + // portBindings = append(portBindings, rt.PortBinding{ + // HostIP: "::1", // IPv6 localhost + // HostPort: fmt.Sprintf("%d", t.targetPort), + // }) + //} // Set the port bindings containerOptions.PortBindings[containerPortStr] = portBindings @@ -195,7 +196,7 @@ func (t *SSETransport) Start(ctx context.Context) error { t.port, targetURI) // Create the transparent proxy with middlewares - t.proxy = transparent.NewTransparentProxy(t.port, t.containerName, targetURI, t.middlewares...) + t.proxy = transparent.NewTransparentProxy(t.host, t.port, t.containerName, targetURI, t.middlewares...) if err := t.proxy.Start(ctx); err != nil { return err } diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go index b1b7768..dfd3fa0 100644 --- a/pkg/transport/stdio.go +++ b/pkg/transport/stdio.go @@ -24,6 +24,7 @@ import ( // StdioTransport implements the Transport interface using standard input/output. // It acts as a proxy between the MCP client and the container's stdin/stdout. type StdioTransport struct { + host string port int containerID string containerName string @@ -51,12 +52,14 @@ type StdioTransport struct { // NewStdioTransport creates a new stdio transport. func NewStdioTransport( + host string, port int, runtime rt.Runtime, debug bool, middlewares ...types.Middleware, ) *StdioTransport { return &StdioTransport{ + host: host, port: port, runtime: runtime, debug: debug, @@ -148,7 +151,7 @@ func (t *StdioTransport) Start(ctx context.Context) error { } // Create and start the HTTP SSE proxy with middlewares - t.httpProxy = httpsse.NewHTTPSSEProxy(t.port, t.containerName, t.middlewares...) + t.httpProxy = httpsse.NewHTTPSSEProxy(t.host, t.port, t.containerName, t.middlewares...) if err := t.httpProxy.Start(ctx); err != nil { return err }