Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/internal/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ func ServeFlags(flags *pflag.FlagSet, opts *ToolboxOptions) {
flags.StringVar(&opts.Cfg.ToolboxUrl, "toolbox-url", "", "Specifies the Toolbox URL. Used as the resource field in the MCP PRM file when MCP Auth is enabled. Falls back to TOOLBOX_URL environment variable.")
flags.StringVar(&opts.Cfg.McpPrmFile, "mcp-prm-file", "", "Path to a manual Protected Resource Metadata (PRM) JSON file. If provided, overrides auto-generation.")
flags.StringSliceVar(&opts.Cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. If unset, defaults to a loopback-only allowlist ('127.0.0.1', 'localhost', '::1') when --address is a loopback address (to prevent DNS rebinding in local development), and to '*' otherwise.")
}
4 changes: 4 additions & 0 deletions cmd/internal/serve/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func runServe(cmd *cobra.Command, opts *internal.ToolboxOptions) error {
_ = shutdown(ctx)
}()

// Record whether the user explicitly set --allowed-hosts so the server can
// apply a context-aware (loopback-only) default for local deployments.
opts.Cfg.AllowedHostsSet = cmd.Flags().Changed("allowed-hosts")

// start server
s, err := server.NewServer(ctx, opts.Cfg)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,10 @@ func run(cmd *cobra.Command, opts *internal.ToolboxOptions) error {
}
}

// Record whether the user explicitly set --allowed-hosts so the server can
// apply a context-aware (loopback-only) default for local deployments.
opts.Cfg.AllowedHostsSet = cmd.Flags().Changed("allowed-hosts")

// start server
s, err := server.NewServer(ctx, opts.Cfg)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ type ServerConfig struct {
AllowedOrigins []string
// Specifies a list of hosts permitted to access this server.
AllowedHosts []string
// AllowedHostsSet indicates whether the user explicitly set --allowed-hosts.
// When false, the server may apply a context-aware default (loopback-only)
// for local (loopback) deployments to mitigate DNS rebinding attacks.
AllowedHostsSet bool
// UserAgentMetadata specifies additional metadata to append to the User-Agent string.
UserAgentMetadata []string
// PollInterval sets the polling frequency for configuration file updates.
Expand Down
37 changes: 37 additions & 0 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,32 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
}

// loopbackAllowedHosts is the secure-by-default allowlist applied to
// --allowed-hosts when the server binds to a loopback address and the user did
// not explicitly set the flag. It blocks DNS rebinding for local development.
var loopbackAllowedHosts = []string{"127.0.0.1", "localhost", "::1"}

// isLoopbackAddress reports whether addr refers to the loopback interface
// (e.g. "127.0.0.1", any "127.*" address, "localhost", or "::1"). It returns
// false for wildcard binds ("0.0.0.0", "::") and specific public addresses.
func isLoopbackAddress(addr string) bool {
host := strings.TrimSpace(addr)
// Tolerate an address that includes a port (e.g. "127.0.0.1:5000").
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
// Strip brackets from IPv6 literals (e.g. "[::1]").
host = strings.Trim(host, "[]")
if host == "localhost" {
return true
}
if ip := net.ParseIP(host); ip != nil {
return ip.IsLoopback()
}
// Fall back to a 127.* prefix check for non-canonical IPv4 inputs.
return strings.HasPrefix(host, "127.")
Comment on lines +339 to +340

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The fallback prefix check strings.HasPrefix(host, "127.") can incorrectly match domain names that start with 127. (for example, 127.example.com or 127-server.local if resolved/configured as the address). If a remote deployment binds to such a domain, it would be incorrectly classified as a loopback address, causing the server to restrict --allowed-hosts to loopback only and breaking remote access.

To prevent this, we should ensure that the remaining characters after 127. only contain digits and dots (which are characteristic of non-canonical IPv4 representations like 127.1 or 127.0.1).

	// Fall back to a 127.* prefix check for non-canonical IPv4 inputs.
	if strings.HasPrefix(host, "127.") {
		for i := 4; i < len(host); i++ {
			if (host[i] < '0' || host[i] > '9') && host[i] != '.' {
				return false
			}
		}
		return len(host) > 4
	}
	return false

}

func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -407,6 +433,17 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
}
r.Use(cors.Handler(corsOpts))
// validate hosts for DNS rebinding attacks
//
// Secure-by-default for local development: if the user did not explicitly
// set --allowed-hosts and the server binds to a loopback address, downgrade
// the wildcard default to a loopback-only allowlist so a malicious web page
// cannot reach the server via DNS rebinding. Non-loopback binds (e.g.
// 0.0.0.0 / :: on Cloud Run or in a container) keep the wildcard default so
// remote access is not broken.
if !cfg.AllowedHostsSet && isLoopbackAddress(cfg.Address) && slices.Contains(cfg.AllowedHosts, "*") {
cfg.AllowedHosts = slices.Clone(loopbackAllowedHosts)
s.logger.InfoContext(ctx, "Defaulting --allowed-hosts to loopback for local development; pass --allowed-hosts explicitly to override.")
}
if slices.Contains(cfg.AllowedHosts, "*") {
s.logger.WarnContext(ctx, "wildcard (*) hosts allow any domain to access this resource, making it vulnerable to DNS rebinding attacks regardless of whether you are in a production or local development environment. For improved security, use the --allowed-hosts flag to specify trusted domains.")
}
Expand Down
206 changes: 172 additions & 34 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ func TestServe(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := server.ServerConfig{
Version: "0.0.0",
Address: tt.addr,
Port: tt.port,
AllowedHosts: []string{"*"},
Version: "0.0.0",
Address: tt.addr,
Port: tt.port,
AllowedHosts: []string{"*"},
AllowedHostsSet: true,
}

instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
Expand Down Expand Up @@ -336,12 +337,13 @@ func TestEndpointSecurityAllowedOrigin(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
addr, port := "127.0.0.1", 0
cfg := server.ServerConfig{
Version: "0.0.0",
Address: addr,
Port: port,
EnableAPI: true,
AllowedOrigins: tc.allowedOrigins,
AllowedHosts: []string{"*"},
Version: "0.0.0",
Address: addr,
Port: port,
EnableAPI: true,
AllowedOrigins: tc.allowedOrigins,
AllowedHosts: []string{"*"},
AllowedHostsSet: true,
}

instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
Expand Down Expand Up @@ -485,11 +487,12 @@ func TestEndpointSecurityAllowedHost(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
addr, port := "127.0.0.1", 0
cfg := server.ServerConfig{
Version: "0.0.0",
Address: addr,
Port: port,
EnableAPI: true,
AllowedHosts: tc.allowedHosts,
Version: "0.0.0",
Address: addr,
Port: port,
EnableAPI: true,
AllowedHosts: tc.allowedHosts,
AllowedHostsSet: true,
}

instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
Expand Down Expand Up @@ -604,6 +607,137 @@ func TestEndpointSecurityAllowedHost(t *testing.T) {
}
}

// TestAllowedHostsContextAwareDefault verifies the secure-by-default behavior of
// --allowed-hosts: when the user does not explicitly set the flag (the wildcard
// default is in place) and the server binds to a loopback address, the host
// allowlist is downgraded to loopback-only to block DNS rebinding. For
// non-loopback binds the wildcard is preserved, and an explicit flag value is
// always respected.
func TestAllowedHostsContextAwareDefault(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("error setting up logger: %s", err)
}

testCases := []struct {
desc string
address string
allowedHosts []string
allowedHostsSet bool
host string
wantStatus int
}{
{
desc: "loopback bind, no flag => loopback default blocks rebinding host",
address: "127.0.0.1",
allowedHosts: []string{"*"},
allowedHostsSet: false,
host: "evil.com",
wantStatus: http.StatusForbidden,
},
{
desc: "loopback bind, no flag => loopback default allows localhost",
address: "127.0.0.1",
allowedHosts: []string{"*"},
allowedHostsSet: false,
host: "localhost",
wantStatus: http.StatusOK,
},
{
desc: "loopback bind, no flag => loopback default allows 127.0.0.1",
address: "127.0.0.1",
allowedHosts: []string{"*"},
allowedHostsSet: false,
host: "127.0.0.1",
wantStatus: http.StatusOK,
},
{
desc: "non-loopback bind, no flag => wildcard preserved",
address: "0.0.0.0",
allowedHosts: []string{"*"},
allowedHostsSet: false,
host: "evil.com",
wantStatus: http.StatusOK,
},
{
desc: "loopback bind, explicit wildcard flag => respected",
address: "127.0.0.1",
allowedHosts: []string{"*"},
allowedHostsSet: true,
host: "evil.com",
wantStatus: http.StatusOK,
},
{
desc: "loopback bind, explicit specific flag => respected",
address: "127.0.0.1",
allowedHosts: []string{"trusted.com"},
allowedHostsSet: true,
host: "trusted.com",
wantStatus: http.StatusOK,
},
}

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cfg := server.ServerConfig{
Version: "0.0.0",
Address: tc.address,
Port: 0,
EnableAPI: true,
AllowedHosts: tc.allowedHosts,
AllowedHostsSet: tc.allowedHostsSet,
}

instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx = util.WithInstrumentation(ctx, instrumentation)

s, err := server.NewServer(ctx, cfg)
if err != nil {
t.Fatalf("error setting up server: %s", err)
}

if err := s.Listen(ctx, "", ""); err != nil {
t.Fatalf("unable to start server: %v", err)
}

_, actualPort, err := net.SplitHostPort(s.Addr())
if err != nil {
t.Fatalf("failed to parse server address: %v", err)
}

go func() {
if err := s.Serve(ctx); err != nil && err != http.ErrServerClosed {
t.Errorf("server serve error: %v", err)
}
}()

// Always dial the loopback interface (the listener accepts on it for
// both 127.0.0.1 and 0.0.0.0 binds); the host-check keys off the
// Host header, which we set independently below.
reqURL := fmt.Sprintf("http://127.0.0.1:%s/api/toolset", actualPort)
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Host = net.JoinHostPort(tc.host, actualPort)

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("failed to send request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != tc.wantStatus {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("expected status %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(body))
}
})
}
}

func TestNameValidation(t *testing.T) {
testCases := []struct {
desc string
Expand Down Expand Up @@ -709,11 +843,12 @@ func TestPRMEndpoint(t *testing.T) {
// Configure the server
addr, port := "127.0.0.1", 5003
cfg := server.ServerConfig{
Version: "0.0.0",
Address: addr,
Port: port,
ToolboxUrl: "https://my-toolbox.example.com",
AllowedHosts: []string{"*"},
Version: "0.0.0",
Address: addr,
Port: port,
ToolboxUrl: "https://my-toolbox.example.com",
AllowedHosts: []string{"*"},
AllowedHostsSet: true,
AuthServiceConfigs: map[string]auth.AuthServiceConfig{
"generic1": generic.Config{
Name: "generic1",
Expand Down Expand Up @@ -818,11 +953,12 @@ func TestPRMOverride(t *testing.T) {
// Configure the server with the Override Flag
addr, port := "127.0.0.1", 5004
cfg := server.ServerConfig{
Version: "0.0.0",
Address: addr,
Port: port,
McpPrmFile: tmpFile.Name(),
AllowedHosts: []string{"*"},
Version: "0.0.0",
Address: addr,
Port: port,
McpPrmFile: tmpFile.Name(),
AllowedHosts: []string{"*"},
AllowedHostsSet: true,
}

// Initialize and Start the Server
Expand Down Expand Up @@ -895,10 +1031,11 @@ func TestLegacyAPIGone(t *testing.T) {
// Configure the server (EnableAPI defaults to false)
addr, port := "127.0.0.1", 5005
cfg := server.ServerConfig{
Version: "0.0.0",
Address: addr,
Port: port,
AllowedHosts: []string{"*"},
Version: "0.0.0",
Address: addr,
Port: port,
AllowedHosts: []string{"*"},
AllowedHostsSet: true,
}

// Initialize and Start the Server
Expand Down Expand Up @@ -1005,11 +1142,12 @@ func TestMCPAuthMiddleware(t *testing.T) {
// Configure the server
addr, port := "127.0.0.1", 5004
cfg := server.ServerConfig{
Version: "0.0.0",
Address: addr,
Port: port,
ToolboxUrl: "https://my-toolbox.example.com",
AllowedHosts: []string{"*"},
Version: "0.0.0",
Address: addr,
Port: port,
ToolboxUrl: "https://my-toolbox.example.com",
AllowedHosts: []string{"*"},
AllowedHostsSet: true,
AuthServiceConfigs: map[string]auth.AuthServiceConfig{
"generic1": generic.Config{
Name: "generic1",
Expand Down