Skip to content

Add support for --host flag #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/thv/app/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/stacklok/toolhive/pkg/labels"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/secrets"
"github.com/stacklok/toolhive/pkg/transport"
)

var configCmd = &cobra.Command{
Expand Down Expand Up @@ -262,7 +263,7 @@ func addRunningMCPsToClient(ctx context.Context, clientName string) error {
}

// Generate URL for the MCP server
url := client.GenerateMCPServerURL("localhost", port, name)
url := client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name)

// Update each configuration file
for _, clientConfig := range clientConfigs {
Expand Down
30 changes: 14 additions & 16 deletions cmd/thv/app/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/stacklok/toolhive/pkg/labels"
"github.com/stacklok/toolhive/pkg/lifecycle"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/transport"
)

var listCmd = &cobra.Command{
Expand All @@ -28,10 +29,7 @@ var (
)

// Constants for list command
const (
defaultHost = "localhost"
unknownTransport = "unknown"
)
const unknownTransport = "unknown"

// ContainerOutput represents container information for JSON output
type ContainerOutput struct {
Expand Down Expand Up @@ -73,7 +71,7 @@ func listCmdFunc(cmd *cobra.Command, _ []string) error {
// Output based on format
switch listFormat {
//nolint:goconst
case "json":
case FormatJSON:
return printJSONOutput(toolHiveContainers)
case "mcpservers":
return printMCPServersOutput(toolHiveContainers)
Expand Down Expand Up @@ -101,9 +99,9 @@ func printJSONOutput(containers []rt.ContainerInfo) error {
}

// Get transport type from labels
transport := labels.GetTransportType(c.Labels)
if transport == "" {
transport = unknownTransport
t := labels.GetTransportType(c.Labels)
if t == "" {
t = unknownTransport
}

// Get tool type from labels
Expand All @@ -118,15 +116,15 @@ func printJSONOutput(containers []rt.ContainerInfo) error {
// Generate URL for the MCP server
url := ""
if port > 0 {
url = client.GenerateMCPServerURL(defaultHost, port, name)
url = client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name)
}

output = append(output, ContainerOutput{
ID: truncatedID,
Name: name,
Image: c.Image,
State: c.State,
Transport: transport,
Transport: t,
ToolType: toolType,
Port: port,
URL: url,
Expand Down Expand Up @@ -174,7 +172,7 @@ func printMCPServersOutput(containers []rt.ContainerInfo) error {
// Generate URL for the MCP server
url := ""
if port > 0 {
url = client.GenerateMCPServerURL(defaultHost, port, name)
url = client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name)
}

// Add the MCP server to the map
Expand Down Expand Up @@ -217,9 +215,9 @@ func printTextOutput(containers []rt.ContainerInfo) {
}

// Get transport type from labels
transport := labels.GetTransportType(c.Labels)
if transport == "" {
transport = unknownTransport
t := labels.GetTransportType(c.Labels)
if t == "" {
t = unknownTransport
}

// Get port from labels
Expand All @@ -231,7 +229,7 @@ func printTextOutput(containers []rt.ContainerInfo) {
// Generate URL for the MCP server
url := ""
if port > 0 {
url = client.GenerateMCPServerURL(defaultHost, port, name)
url = client.GenerateMCPServerURL(transport.LocalhostIPv4, port, name)
}

// Print container information
Expand All @@ -240,7 +238,7 @@ func printTextOutput(containers []rt.ContainerInfo) {
name,
c.Image,
c.State,
transport,
t,
port,
url,
)
Expand Down
12 changes: 11 additions & 1 deletion cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/networking"
"github.com/stacklok/toolhive/pkg/transport"
"github.com/stacklok/toolhive/pkg/transport/proxy/transparent"
"github.com/stacklok/toolhive/pkg/transport/types"
)
Expand All @@ -25,11 +26,13 @@ This command creates a standalone proxy without starting a container.`,
}

var (
proxyHost string
proxyPort int
proxyTargetURI string
)

func init() {
proxyCmd.Flags().StringVar(&proxyHost, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)")
proxyCmd.Flags().IntVar(&proxyPort, "port", 0, "Port for the HTTP proxy to listen on (host port)")
proxyCmd.Flags().StringVar(
&proxyTargetURI,
Expand All @@ -52,6 +55,13 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
// Get the server name
serverName := args[0]

// Validate the host flag and default resolving to IP in case hostname is provided
validatedHost, err := ValidateAndNormaliseHostFlag(proxyHost)
if err != nil {
return fmt.Errorf("invalid host: %s", proxyHost)
}
proxyHost = validatedHost

// Select a port for the HTTP proxy (host port)
port, err := networking.FindOrUsePort(proxyPort)
if err != nil {
Expand Down Expand Up @@ -94,7 +104,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
port, proxyTargetURI)

// Create the transparent proxy with middlewares
proxy := transparent.NewTransparentProxy(port, serverName, proxyTargetURI, middlewares...)
proxy := transparent.NewTransparentProxy(proxyHost, port, serverName, proxyTargetURI, middlewares...)
if err := proxy.Start(ctx); err != nil {
return fmt.Errorf("failed to start proxy: %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions cmd/thv/app/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func init() {
registryCmd.AddCommand(registryInfoCmd)

// Add flags for list and info commands
registryListCmd.Flags().StringVar(&registryFormat, "format", "text", "Output format (json or text)")
registryInfoCmd.Flags().StringVar(&registryFormat, "format", "text", "Output format (json or text)")
registryListCmd.Flags().StringVar(&registryFormat, "format", FormatText, "Output format (json or text)")
registryInfoCmd.Flags().StringVar(&registryFormat, "format", FormatText, "Output format (json or text)")
}

func registryListCmdFunc(_ *cobra.Command, _ []string) error {
Expand All @@ -66,7 +66,7 @@ func registryListCmdFunc(_ *cobra.Command, _ []string) error {

// Output based on format
switch registryFormat {
case "json":
case FormatJSON:
return printJSONServers(servers)
default:
printTextServers(servers)
Expand All @@ -84,7 +84,7 @@ func registryInfoCmdFunc(_ *cobra.Command, args []string) error {

// Output based on format
switch registryFormat {
case "json":
case FormatJSON:
return printJSONServer(server)
default:
printTextServerInfo(serverName, server)
Expand Down
43 changes: 42 additions & 1 deletion cmd/thv/app/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package app
import (
"context"
"fmt"
"net"
"os"
"strings"

Expand All @@ -17,6 +18,7 @@ import (
"github.com/stacklok/toolhive/pkg/permissions"
"github.com/stacklok/toolhive/pkg/registry"
"github.com/stacklok/toolhive/pkg/runner"
"github.com/stacklok/toolhive/pkg/transport"
)

var runCmd = &cobra.Command{
Expand Down Expand Up @@ -56,6 +58,7 @@ permission profile. Additional configuration can be provided via flags.`,
var (
runTransport string
runName string
runHost string
runPort int
runTargetPort int
runTargetHost string
Expand All @@ -72,12 +75,13 @@ var (
func init() {
runCmd.Flags().StringVar(&runTransport, "transport", "stdio", "Transport mode (sse or stdio)")
runCmd.Flags().StringVar(&runName, "name", "", "Name of the MCP server (auto-generated from image if not provided)")
runCmd.Flags().StringVar(&runHost, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)")
runCmd.Flags().IntVar(&runPort, "port", 0, "Port for the HTTP proxy to listen on (host port)")
runCmd.Flags().IntVar(&runTargetPort, "target-port", 0, "Port for the container to expose (only applicable to SSE transport)")
runCmd.Flags().StringVar(
&runTargetHost,
"target-host",
"localhost",
transport.LocalhostIPv4,
"Host to forward traffic to (only applicable to SSE transport)")
runCmd.Flags().StringVar(
&runPermissionProfile,
Expand Down Expand Up @@ -131,6 +135,14 @@ func init() {

func runCmdFunc(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

// Validate the host flag and default resolving to IP in case hostname is provided
validatedHost, err := ValidateAndNormaliseHostFlag(runHost)
if err != nil {
return fmt.Errorf("invalid host: %s", runHost)
}
runHost = validatedHost

// Get the server name or image
serverOrImage := args[0]

Expand Down Expand Up @@ -173,6 +185,7 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
rt,
cmdArgs,
runName,
runHost,
debugMode,
runVolumes,
runSecrets,
Expand Down Expand Up @@ -437,3 +450,31 @@ func parseCommandArguments(args []string) []string {
}
return cmdArgs
}

// ValidateAndNormaliseHostFlag validates and normalizes the host flag resolving it to an IP address if hostname is provided
func ValidateAndNormaliseHostFlag(host string) (string, error) {
// Check if the host is a valid IP address
ip := net.ParseIP(host)
if ip != nil {
if ip.To4() == nil {
return "", fmt.Errorf("IPv6 addresses are not supported: %s", host)
}
return host, nil
}

// If not an IP address, resolve the hostname to an IP address
addrs, err := net.LookupHost(host)
if err != nil {
return "", fmt.Errorf("invalid host: %s", host)
}

// Use the first IPv4 address found
for _, addr := range addrs {
ip := net.ParseIP(addr)
if ip != nil && ip.To4() != nil {
return ip.String(), nil
}
}

return "", fmt.Errorf("could not resolve host: %s", host)
}
1 change: 1 addition & 0 deletions docs/cli/thv_proxy.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion docs/cli/thv_run.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pkg/lifecycle/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,14 @@ func (*defaultManager) RunContainerDetached(runConfig *runner.RunConfig) error {
detachedArgs = append(detachedArgs, "--name", runConfig.ContainerName)
}

if runConfig.Host != "" {
detachedArgs = append(detachedArgs, "--host", runConfig.Host)
}

if runConfig.Port != 0 {
detachedArgs = append(detachedArgs, "--port", fmt.Sprintf("%d", runConfig.Port))
}

if runConfig.TargetPort != 0 {
detachedArgs = append(detachedArgs, "--target-port", fmt.Sprintf("%d", runConfig.TargetPort))
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type RunConfig struct {
// Transport is the transport mode (sse or stdio)
Transport types.TransportType `json:"transport" yaml:"transport"`

// Host is the host for the HTTP proxy
Host string `json:"host" yaml:"host"`

// Port is the port for the HTTP proxy to listen on (host port)
Port int `json:"port" yaml:"port"`

Expand Down Expand Up @@ -120,6 +123,7 @@ func NewRunConfigFromFlags(
runtime rt.Runtime,
cmdArgs []string,
name string,
host string,
debug bool,
volumes []string,
secretsList []string,
Expand All @@ -143,6 +147,7 @@ func NewRunConfigFromFlags(
TargetHost: targetHost,
ContainerLabels: make(map[string]string),
EnvVars: make(map[string]string),
Host: host,
}

// Set OIDC config if any values are provided
Expand Down
2 changes: 2 additions & 0 deletions pkg/runner/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ func TestNewRunConfigFromFlags(t *testing.T) {
runtime := &mockRuntime{}
cmdArgs := []string{"arg1", "arg2"}
name := "test-server"
host := "localhost"
debug := true
volumes := []string{"/host:/container"}
secretsList := []string{"secret1,target=ENV_VAR1"}
Expand All @@ -819,6 +820,7 @@ func TestNewRunConfigFromFlags(t *testing.T) {
runtime,
cmdArgs,
name,
host,
debug,
volumes,
secretsList,
Expand Down
2 changes: 1 addition & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (r *Runner) Run(ctx context.Context) error {
Type: r.Config.Transport,
Port: r.Config.Port,
TargetPort: r.Config.TargetPort,
Host: "localhost",
Host: r.Config.Host,
TargetHost: r.Config.TargetHost,
Runtime: r.Config.Runtime,
Debug: r.Config.Debug,
Expand Down
2 changes: 1 addition & 1 deletion pkg/transport/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func NewFactory() *Factory {
func (*Factory) Create(config types.Config) (types.Transport, error) {
switch config.Type {
case types.TransportTypeStdio:
return NewStdioTransport(config.Port, config.Runtime, config.Debug, config.Middlewares...), nil
return NewStdioTransport(config.Host, config.Port, config.Runtime, config.Debug, config.Middlewares...), nil
case types.TransportTypeSSE:
return NewSSETransport(
config.Host,
Expand Down
Loading