Skip to content

Commit 129ebff

Browse files
committed
Add support for --host flag
Signed-off-by: Radoslav Dimitrov <[email protected]>
1 parent a4dd0d9 commit 129ebff

File tree

11 files changed

+92
-22
lines changed

11 files changed

+92
-22
lines changed

cmd/thv/app/proxy.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/stacklok/toolhive/pkg/auth"
1212
"github.com/stacklok/toolhive/pkg/logger"
1313
"github.com/stacklok/toolhive/pkg/networking"
14+
"github.com/stacklok/toolhive/pkg/transport"
1415
"github.com/stacklok/toolhive/pkg/transport/proxy/transparent"
1516
"github.com/stacklok/toolhive/pkg/transport/types"
1617
)
@@ -25,11 +26,13 @@ This command creates a standalone proxy without starting a container.`,
2526
}
2627

2728
var (
29+
proxyHost string
2830
proxyPort int
2931
proxyTargetURI string
3032
)
3133

3234
func init() {
35+
proxyCmd.Flags().StringVar(&proxyHost, "host", transport.LocalhostName, "Host for the HTTP proxy to listen on (IP or hostname)")
3336
proxyCmd.Flags().IntVar(&proxyPort, "port", 0, "Port for the HTTP proxy to listen on (host port)")
3437
proxyCmd.Flags().StringVar(
3538
&proxyTargetURI,
@@ -52,6 +55,13 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
5255
// Get the server name
5356
serverName := args[0]
5457

58+
// Validate the host flag and default resolving to IP in case hostname is provided
59+
validatedHost, err := ValidateAndNormaliseHostFlag(proxyHost)
60+
if err != nil {
61+
return fmt.Errorf("invalid host: %s", proxyHost)
62+
}
63+
proxyHost = validatedHost
64+
5565
// Select a port for the HTTP proxy (host port)
5666
port, err := networking.FindOrUsePort(proxyPort)
5767
if err != nil {
@@ -94,7 +104,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
94104
port, proxyTargetURI)
95105

96106
// Create the transparent proxy with middlewares
97-
proxy := transparent.NewTransparentProxy(port, serverName, proxyTargetURI, middlewares...)
107+
proxy := transparent.NewTransparentProxy(proxyHost, port, serverName, proxyTargetURI, middlewares...)
98108
if err := proxy.Start(ctx); err != nil {
99109
return fmt.Errorf("failed to start proxy: %v", err)
100110
}

cmd/thv/app/run.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package app
33
import (
44
"context"
55
"fmt"
6+
"net"
67
"os"
78
"strings"
89

@@ -17,6 +18,7 @@ import (
1718
"github.com/stacklok/toolhive/pkg/permissions"
1819
"github.com/stacklok/toolhive/pkg/registry"
1920
"github.com/stacklok/toolhive/pkg/runner"
21+
"github.com/stacklok/toolhive/pkg/transport"
2022
)
2123

2224
var runCmd = &cobra.Command{
@@ -56,6 +58,7 @@ permission profile. Additional configuration can be provided via flags.`,
5658
var (
5759
runTransport string
5860
runName string
61+
runHost string
5962
runPort int
6063
runTargetPort int
6164
runTargetHost string
@@ -72,6 +75,7 @@ var (
7275
func init() {
7376
runCmd.Flags().StringVar(&runTransport, "transport", "stdio", "Transport mode (sse or stdio)")
7477
runCmd.Flags().StringVar(&runName, "name", "", "Name of the MCP server (auto-generated from image if not provided)")
78+
runCmd.Flags().StringVar(&runHost, "host", transport.LocalhostName, "Host for the HTTP proxy to listen on (IP or hostname)")
7579
runCmd.Flags().IntVar(&runPort, "port", 0, "Port for the HTTP proxy to listen on (host port)")
7680
runCmd.Flags().IntVar(&runTargetPort, "target-port", 0, "Port for the container to expose (only applicable to SSE transport)")
7781
runCmd.Flags().StringVar(
@@ -131,6 +135,14 @@ func init() {
131135

132136
func runCmdFunc(cmd *cobra.Command, args []string) error {
133137
ctx := cmd.Context()
138+
139+
// Validate the host flag and default resolving to IP in case hostname is provided
140+
validatedHost, err := ValidateAndNormaliseHostFlag(runHost)
141+
if err != nil {
142+
return fmt.Errorf("invalid host: %s", runHost)
143+
}
144+
runHost = validatedHost
145+
134146
// Get the server name or image
135147
serverOrImage := args[0]
136148

@@ -173,6 +185,7 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
173185
rt,
174186
cmdArgs,
175187
runName,
188+
runHost,
176189
debugMode,
177190
runVolumes,
178191
runSecrets,
@@ -437,3 +450,31 @@ func parseCommandArguments(args []string) []string {
437450
}
438451
return cmdArgs
439452
}
453+
454+
// ValidateAndNormaliseHostFlag validates and normalizes the host flag resolving it to an IP address if hostname is provided
455+
func ValidateAndNormaliseHostFlag(host string) (string, error) {
456+
// Check if the host is a valid IP address
457+
ip := net.ParseIP(host)
458+
if ip != nil {
459+
if ip.To4() == nil {
460+
return "", fmt.Errorf("IPv6 addresses are not supported: %s", host)
461+
}
462+
return host, nil
463+
}
464+
465+
// If not an IP address, resolve the hostname to an IP address
466+
addrs, err := net.LookupHost(host)
467+
if err != nil {
468+
return "", fmt.Errorf("invalid host: %s", host)
469+
}
470+
471+
// Use the first IPv4 address found
472+
for _, addr := range addrs {
473+
ip := net.ParseIP(addr)
474+
if ip != nil && ip.To4() != nil {
475+
return ip.String(), nil
476+
}
477+
}
478+
479+
return "", fmt.Errorf("could not resolve host: %s", host)
480+
}

pkg/lifecycle/manager.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,14 @@ func (*defaultManager) RunContainerDetached(runConfig *runner.RunConfig) error {
197197
detachedArgs = append(detachedArgs, "--name", runConfig.Name)
198198
}
199199

200+
if runConfig.Host != "" {
201+
detachedArgs = append(detachedArgs, "--host", runConfig.Host)
202+
}
203+
200204
if runConfig.Port != 0 {
201205
detachedArgs = append(detachedArgs, "--port", fmt.Sprintf("%d", runConfig.Port))
202206
}
207+
203208
if runConfig.TargetPort != 0 {
204209
detachedArgs = append(detachedArgs, "--target-port", fmt.Sprintf("%d", runConfig.TargetPort))
205210
}

pkg/runner/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ type RunConfig struct {
4141
// Transport is the transport mode (sse or stdio)
4242
Transport types.TransportType `json:"transport" yaml:"transport"`
4343

44+
// Host is the host for the HTTP proxy
45+
Host string `json:"host" yaml:"host"`
46+
4447
// Port is the port for the HTTP proxy to listen on (host port)
4548
Port int `json:"port" yaml:"port"`
4649

@@ -120,6 +123,7 @@ func NewRunConfigFromFlags(
120123
runtime rt.Runtime,
121124
cmdArgs []string,
122125
name string,
126+
host string,
123127
debug bool,
124128
volumes []string,
125129
secretsList []string,
@@ -143,6 +147,7 @@ func NewRunConfigFromFlags(
143147
TargetHost: targetHost,
144148
ContainerLabels: make(map[string]string),
145149
EnvVars: make(map[string]string),
150+
Host: host,
146151
}
147152

148153
// Set OIDC config if any values are provided

pkg/runner/config_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ func TestNewRunConfigFromFlags(t *testing.T) {
804804
runtime := &mockRuntime{}
805805
cmdArgs := []string{"arg1", "arg2"}
806806
name := "test-server"
807+
host := "localhost"
807808
debug := true
808809
volumes := []string{"/host:/container"}
809810
secretsList := []string{"secret1,target=ENV_VAR1"}
@@ -819,6 +820,7 @@ func TestNewRunConfigFromFlags(t *testing.T) {
819820
runtime,
820821
cmdArgs,
821822
name,
823+
host,
822824
debug,
823825
volumes,
824826
secretsList,

pkg/runner/runner.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (r *Runner) Run(ctx context.Context) error {
4141
Type: r.Config.Transport,
4242
Port: r.Config.Port,
4343
TargetPort: r.Config.TargetPort,
44-
Host: "localhost",
44+
Host: r.Config.Host,
4545
TargetHost: r.Config.TargetHost,
4646
Runtime: r.Config.Runtime,
4747
Debug: r.Config.Debug,

pkg/transport/factory.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func NewFactory() *Factory {
1919
func (*Factory) Create(config types.Config) (types.Transport, error) {
2020
switch config.Type {
2121
case types.TransportTypeStdio:
22-
return NewStdioTransport(config.Port, config.Runtime, config.Debug, config.Middlewares...), nil
22+
return NewStdioTransport(config.Host, config.Port, config.Runtime, config.Debug, config.Middlewares...), nil
2323
case types.TransportTypeSSE:
2424
return NewSSETransport(
2525
config.Host,

pkg/transport/proxy/httpsse/http_proxy.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ type Proxy interface {
4848
//nolint:revive // Intentionally named HTTPSSEProxy despite package name
4949
type HTTPSSEProxy struct {
5050
// Basic configuration
51+
host string
5152
port int
5253
containerName string
5354
middlewares []types.Middleware
@@ -69,9 +70,10 @@ type HTTPSSEProxy struct {
6970
}
7071

7172
// NewHTTPSSEProxy creates a new HTTP SSE proxy for transports.
72-
func NewHTTPSSEProxy(port int, containerName string, middlewares ...types.Middleware) *HTTPSSEProxy {
73+
func NewHTTPSSEProxy(host string, port int, containerName string, middlewares ...types.Middleware) *HTTPSSEProxy {
7374
return &HTTPSSEProxy{
7475
middlewares: middlewares,
76+
host: host,
7577
port: port,
7678
containerName: containerName,
7779
shutdownCh: make(chan struct{}),
@@ -109,16 +111,16 @@ func (p *HTTPSSEProxy) Start(_ context.Context) error {
109111

110112
// Create the server
111113
p.server = &http.Server{
112-
Addr: fmt.Sprintf(":%d", p.port),
114+
Addr: fmt.Sprintf("%s:%d", p.host, p.port),
113115
Handler: mux,
114116
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
115117
}
116118

117119
// Start the server in a goroutine
118120
go func() {
119121
logger.Infof("HTTP proxy started for container %s on port %d", p.containerName, p.port)
120-
logger.Infof("SSE endpoint: http://localhost:%d%s", p.port, ssecommon.HTTPSSEEndpoint)
121-
logger.Infof("JSON-RPC endpoint: http://localhost:%d%s", p.port, ssecommon.HTTPMessagesEndpoint)
122+
logger.Infof("SSE endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPSSEEndpoint)
123+
logger.Infof("JSON-RPC endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPMessagesEndpoint)
122124

123125
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
124126
logger.Errorf("HTTP server error: %v", err)

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
//nolint:revive // Intentionally named TransparentProxy despite package name
2525
type TransparentProxy struct {
2626
// Basic configuration
27+
host string
2728
port int
2829
containerName string
2930
targetURI string
@@ -43,12 +44,14 @@ type TransparentProxy struct {
4344

4445
// NewTransparentProxy creates a new transparent proxy with optional middlewares.
4546
func NewTransparentProxy(
47+
host string,
4648
port int,
4749
containerName string,
4850
targetURI string,
4951
middlewares ...types.Middleware,
5052
) *TransparentProxy {
5153
return &TransparentProxy{
54+
host: host,
5255
port: port,
5356
containerName: containerName,
5457
targetURI: targetURI,
@@ -86,15 +89,15 @@ func (p *TransparentProxy) Start(_ context.Context) error {
8689

8790
// Create the server
8891
p.server = &http.Server{
89-
Addr: fmt.Sprintf(":%d", p.port),
92+
Addr: fmt.Sprintf("%s:%d", p.host, p.port),
9093
Handler: finalHandler,
9194
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
9295
}
9396

9497
// Start the server in a goroutine
9598
go func() {
96-
logger.Infof("Transparent proxy started for container %s on port %d -> %s",
97-
p.containerName, p.port, p.targetURI)
99+
logger.Infof("Transparent proxy started for container %s on %s:%d -> %s",
100+
p.containerName, p.host, p.port, p.targetURI)
98101

99102
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
100103
logger.Errorf("Transparent proxy error: %v", err)

pkg/transport/sse.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"github.com/stacklok/toolhive/pkg/container"
99
rt "github.com/stacklok/toolhive/pkg/container/runtime"
1010
"github.com/stacklok/toolhive/pkg/logger"
11-
"github.com/stacklok/toolhive/pkg/networking"
1211
"github.com/stacklok/toolhive/pkg/permissions"
1312
"github.com/stacklok/toolhive/pkg/transport/errors"
1413
"github.com/stacklok/toolhive/pkg/transport/proxy/transparent"
@@ -112,21 +111,21 @@ func (t *SSETransport) Setup(ctx context.Context, runtime rt.Runtime, containerN
112111
containerPortStr := fmt.Sprintf("%d/tcp", t.targetPort)
113112
containerOptions.ExposedPorts[containerPortStr] = struct{}{}
114113

115-
// Create port bindings for localhost
114+
// Create host port bindings (configurable through the --host flag)
116115
portBindings := []rt.PortBinding{
117116
{
118-
HostIP: "127.0.0.1", // IPv4 localhost
117+
HostIP: t.host,
119118
HostPort: fmt.Sprintf("%d", t.targetPort),
120119
},
121120
}
122121

123-
// Check if IPv6 is available and add IPv6 localhost binding
124-
if networking.IsIPv6Available() {
125-
portBindings = append(portBindings, rt.PortBinding{
126-
HostIP: "::1", // IPv6 localhost
127-
HostPort: fmt.Sprintf("%d", t.targetPort),
128-
})
129-
}
122+
// Check if IPv6 is available and add IPv6 localhost binding (commented out for now)
123+
//if networking.IsIPv6Available() {
124+
// portBindings = append(portBindings, rt.PortBinding{
125+
// HostIP: "::1", // IPv6 localhost
126+
// HostPort: fmt.Sprintf("%d", t.targetPort),
127+
// })
128+
//}
130129

131130
// Set the port bindings
132131
containerOptions.PortBindings[containerPortStr] = portBindings
@@ -195,7 +194,7 @@ func (t *SSETransport) Start(ctx context.Context) error {
195194
t.port, targetURI)
196195

197196
// Create the transparent proxy with middlewares
198-
t.proxy = transparent.NewTransparentProxy(t.port, t.containerName, targetURI, t.middlewares...)
197+
t.proxy = transparent.NewTransparentProxy(t.host, t.port, t.containerName, targetURI, t.middlewares...)
199198
if err := t.proxy.Start(ctx); err != nil {
200199
return err
201200
}

0 commit comments

Comments
 (0)