Skip to content

Commit 283e4e3

Browse files
authored
feat(auth): support manual PRM override (#2717)
Add flag to allow manual PRM override.
1 parent f8891b8 commit 283e4e3

File tree

6 files changed

+197
-1
lines changed

6 files changed

+197
-1
lines changed

cmd/internal/flags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ func ServeFlags(flags *pflag.FlagSet, opts *ToolboxOptions) {
6565
flags.BoolVar(&opts.Cfg.UI, "ui", false, "Launches the Toolbox UI web server.")
6666
flags.BoolVar(&opts.Cfg.EnableAPI, "enable-api", false, "Enable the /api endpoint.")
6767
flags.StringVar(&opts.Cfg.ToolboxUrl, "toolbox-url", "", "Specifies the Toolbox URL. Used as the resource field in the MCP PRM file when MCP Auth is enabled. Falls back to TOOLBOX_URL environment variable.")
68+
flags.StringVar(&opts.Cfg.McpPrmFile, "mcp-prm-file", "", "Path to a manual Protected Resource Metadata (PRM) JSON file. If provided, overrides auto-generation.")
6869
flags.StringSliceVar(&opts.Cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
6970
flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
7071
}

docs/en/reference/cli.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ description: >
1515
| `-h` | `--help` | help for toolbox | |
1616
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
1717
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
18+
| | `--mcp-prm-file` | Path to a manual Protected Resource Metadata (PRM) JSON file. If provided, overrides auto-generation for MCP Server-Wide Authentication. | |
1819
| `-p` | `--port` | Port the server will listen on. | `5000` |
1920
| | `--prebuilt` | Use one or more prebuilt tool configuration by source type. See [Prebuilt Tools Reference](../documentation/configuration/prebuilt-configs/_index.md) for allowed values. | |
2021
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |

internal/server/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ type ServerConfig struct {
7474
EnableAPI bool
7575
// ToolboxUrl specifies the URL to advertise in the MCP PRM file as the resource field.
7676
ToolboxUrl string
77+
// McpPrmFile specifies the path to a manual Protected Resource Metadata (PRM) JSON file. If provided, overrides auto-generation.
78+
McpPrmFile string
7779
// Specifies a list of origins permitted to access this server.
7880
AllowedOrigins []string
7981
// Specifies a list of hosts permitted to access this server.

internal/server/prm.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package server
16+
17+
// ProtectedResourceMetadata represents the OAuth 2.0 Protected Resource Metadata document as defined in RFC 9728.
18+
// Reference: https://datatracker.ietf.org/doc/html/rfc9728
19+
type ProtectedResourceMetadata struct {
20+
// REQUIRED. The protected resource's resource identifier (a URL using the https scheme).
21+
Resource string `json:"resource"`
22+
23+
// REQUIRED. Array containing a list of OAuth authorization server issuer identifiers.
24+
AuthorizationServers []string `json:"authorization_servers,omitempty"`
25+
26+
// OPTIONAL. URL of the protected resource's JSON Web Key (JWK) Set document.
27+
JWKSURI string `json:"jwks_uri,omitempty"`
28+
29+
// RECOMMENDED. Array containing a list of scope values used to request access.
30+
ScopesSupported []string `json:"scopes_supported,omitempty"`
31+
32+
// OPTIONAL. Array containing a list of the supported methods of sending an
33+
// OAuth 2.0 bearer token (e.g., "header", "body", "query").
34+
BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"`
35+
36+
// OPTIONAL. Array containing a list of the JWS signing algorithms (alg values)
37+
// supported by the protected resource for signing resource responses.
38+
ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"`
39+
40+
// RECOMMENDED. Human-readable name of the protected resource intended for display.
41+
ResourceName string `json:"resource_name,omitempty"`
42+
43+
// OPTIONAL. URL of a page containing human-readable developer documentation.
44+
ResourceDocumentation string `json:"resource_documentation,omitempty"`
45+
46+
// OPTIONAL. URL of a page containing human-readable policy requirements.
47+
ResourcePolicyURI string `json:"resource_policy_uri,omitempty"`
48+
49+
// OPTIONAL. URL of a page containing human-readable terms of service.
50+
ResourceTOSURI string `json:"resource_tos_uri,omitempty"`
51+
52+
// OPTIONAL. Boolean indicating support for mutual-TLS client certificate-bound
53+
// access tokens. If omitted, the default is false.
54+
TLSClientCertificateBoundAccessTokens *bool `json:"tls_client_certificate_bound_access_tokens,omitempty"`
55+
56+
// OPTIONAL. Array containing a list of authorization details type values supported.
57+
AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"`
58+
59+
// OPTIONAL. Array containing a list of JWS alg values supported for DPoP proof JWTs.
60+
DPoPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"`
61+
62+
// OPTIONAL. Boolean specifying whether the protected resource always requires
63+
// the use of DPoP-bound access tokens. If omitted, the default is false.
64+
DPoPBoundAccessTokensRequired *bool `json:"dpop_bound_access_tokens_required,omitempty"`
65+
66+
// OPTIONAL. A JWT containing metadata parameters about the protected resource
67+
// as claims. Consists of the entire signed JWT string.
68+
SignedMetadata string `json:"signed_metadata,omitempty"`
69+
}

internal/server/server.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ package server
1616

1717
import (
1818
"context"
19+
"encoding/json"
1920
"errors"
2021
"fmt"
2122
"io"
2223
"net"
2324
"net/http"
25+
"os"
2426
"slices"
2527
"strconv"
2628
"strings"
@@ -55,6 +57,7 @@ type Server struct {
5557
instrumentation *telemetry.Instrumentation
5658
sseManager *sseManager
5759
ResourceMgr *resources.ResourceManager
60+
mcpPrmFile string
5861
}
5962

6063
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
@@ -382,6 +385,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
382385
sseManager: sseManager,
383386
ResourceMgr: resourceManager,
384387
toolboxUrl: cfg.ToolboxUrl,
388+
mcpPrmFile: cfg.McpPrmFile,
385389
}
386390

387391
// cors
@@ -419,8 +423,35 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
419423
break
420424
}
421425
}
422-
if mcpAuthEnabled {
426+
427+
// Manual PRM override
428+
var cachedPrmBytes []byte
429+
var prmConfig ProtectedResourceMetadata
430+
if s.mcpPrmFile != "" {
431+
var err error
432+
cachedPrmBytes, err = os.ReadFile(s.mcpPrmFile)
433+
if err != nil {
434+
return nil, fmt.Errorf("failed to read manual PRM file at startup: %w", err)
435+
}
436+
// Unmarshal into the struct to strictly validate the schema
437+
if err := json.Unmarshal(cachedPrmBytes, &prmConfig); err != nil {
438+
return nil, fmt.Errorf("manual PRM file does not match expected schema: %w", err)
439+
}
440+
}
441+
442+
// Register route if auth is enabled or a manual file is provided
443+
if mcpAuthEnabled || s.mcpPrmFile != "" {
423444
r.Get("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, req *http.Request) {
445+
// Serve from memory if file was loaded
446+
if s.mcpPrmFile != "" {
447+
w.Header().Set("Content-Type", "application/json")
448+
w.WriteHeader(http.StatusOK)
449+
if _, err := w.Write(cachedPrmBytes); err != nil {
450+
s.logger.ErrorContext(req.Context(), "failed to write manual PRM file response", "error", err)
451+
}
452+
return
453+
}
454+
424455
prmHandler(s, w, req)
425456
})
426457
}

internal/server/server_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,95 @@ func TestPRMEndpoint(t *testing.T) {
384384
t.Errorf("unexpected PRM response:\ngot %+v\nwant %+v", got, want)
385385
}
386386
}
387+
388+
func TestPRMOverride(t *testing.T) {
389+
ctx, cancel := context.WithCancel(context.Background())
390+
defer cancel()
391+
392+
// Setup a temporary PRM file
393+
prmContent := `{
394+
"resource": "https://override.example.com",
395+
"authorization_servers": ["https://auth.example.com"],
396+
"scopes_supported": ["read", "write"],
397+
"bearer_methods_supported": ["header"]
398+
}`
399+
tmpFile, err := os.CreateTemp("", "prm-*.json")
400+
if err != nil {
401+
t.Fatal(err)
402+
}
403+
defer os.Remove(tmpFile.Name())
404+
if err := os.WriteFile(tmpFile.Name(), []byte(prmContent), 0644); err != nil {
405+
t.Fatal(err)
406+
}
407+
408+
// Setup Logging and Instrumentation (Using Discard to act as Noop)
409+
testLogger, err := log.NewStdLogger(io.Discard, io.Discard, "info")
410+
if err != nil {
411+
t.Fatalf("unexpected error: %s", err)
412+
}
413+
ctx = util.WithLogger(ctx, testLogger)
414+
415+
instrumentation, err := telemetry.CreateTelemetryInstrumentation("0.0.0")
416+
if err != nil {
417+
t.Fatalf("unexpected error: %s", err)
418+
}
419+
ctx = util.WithInstrumentation(ctx, instrumentation)
420+
421+
// Configure the server with the Override Flag
422+
addr, port := "127.0.0.1", 5002
423+
cfg := server.ServerConfig{
424+
Version: "0.0.0",
425+
Address: addr,
426+
Port: port,
427+
McpPrmFile: tmpFile.Name(),
428+
AllowedHosts: []string{"*"},
429+
}
430+
431+
// Initialize and Start the Server
432+
s, err := server.NewServer(ctx, cfg)
433+
if err != nil {
434+
t.Fatalf("unable to initialize server: %v", err)
435+
}
436+
437+
if err := s.Listen(ctx); err != nil {
438+
t.Fatalf("unable to start listener: %v", err)
439+
}
440+
441+
go func() {
442+
if err := s.Serve(ctx); err != nil && err != http.ErrServerClosed {
443+
fmt.Printf("Server serve error: %v\n", err)
444+
}
445+
}()
446+
defer func() {
447+
if err := s.Shutdown(ctx); err != nil {
448+
t.Errorf("failed to cleanly shutdown server: %v", err)
449+
}
450+
}()
451+
452+
// Perform the request to the well-known endpoint
453+
url := fmt.Sprintf("http://%s:%d/.well-known/oauth-protected-resource", addr, port)
454+
resp, err := http.Get(url)
455+
if err != nil {
456+
t.Fatalf("error when sending request: %s", err)
457+
}
458+
defer resp.Body.Close()
459+
460+
if resp.StatusCode != http.StatusOK {
461+
t.Fatalf("expected status 200, got %d", resp.StatusCode)
462+
}
463+
464+
body, err := io.ReadAll(resp.Body)
465+
if err != nil {
466+
t.Fatalf("error reading body: %s", err)
467+
}
468+
469+
// Verification
470+
var got map[string]any
471+
if err := json.Unmarshal(body, &got); err != nil {
472+
t.Fatalf("invalid json response: %s", err)
473+
}
474+
475+
if got["resource"] != "https://override.example.com" {
476+
t.Errorf("expected resource 'https://override.example.com', got '%v'", got["resource"])
477+
}
478+
}

0 commit comments

Comments
 (0)