diff --git a/examples/sqlite-vec/go.mod b/examples/sqlite-vec/go.mod index c0fabdcdb..686e69412 100644 --- a/examples/sqlite-vec/go.mod +++ b/examples/sqlite-vec/go.mod @@ -3,9 +3,15 @@ module github.com/docker/mcp-gateway/examples/sqlite-vec go 1.24 require ( - github.com/google/jsonschema-go v0.3.0 + github.com/google/jsonschema-go v0.4.2 github.com/mattn/go-sqlite3 v1.14.22 - github.com/modelcontextprotocol/go-sdk v1.0.0 + github.com/modelcontextprotocol/go-sdk v1.3.1 ) -require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect +require ( + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.3 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect +) diff --git a/go.mod b/go.mod index afe90c39a..ce255dbfd 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/jmoiron/sqlx v1.4.0 github.com/mikefarah/yq/v4 v4.45.4 - github.com/modelcontextprotocol/go-sdk v1.3.1 + github.com/modelcontextprotocol/go-sdk v1.4.0 github.com/modelcontextprotocol/registry v0.0.0-00010101000000-000000000000 github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.1.1 @@ -37,7 +37,7 @@ require ( go.opentelemetry.io/otel/sdk v1.38.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 - golang.org/x/oauth2 v0.32.0 + golang.org/x/oauth2 v0.34.0 golang.org/x/sync v0.19.0 gopkg.in/op/go-logging.v1 v1.0.0-20160211212156-b2cb9fa56473 gopkg.in/yaml.v3 v3.0.1 @@ -200,7 +200,7 @@ require ( golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.47.0 // indirect - golang.org/x/sys v0.39.0 // indirect + golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.38.0 // indirect golang.org/x/text v0.32.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c // indirect diff --git a/go.sum b/go.sum index d77417f95..55f5755f7 100644 --- a/go.sum +++ b/go.sum @@ -564,8 +564,8 @@ github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7z github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= -github.com/modelcontextprotocol/go-sdk v1.3.1 h1:TfqtNKOIWN4Z1oqmPAiWDC2Jq7K9OdJaooe0teoXASI= -github.com/modelcontextprotocol/go-sdk v1.3.1/go.mod h1:DgVX498dMD8UJlseK1S5i1T4tFz2fkBk4xogC3D15nw= +github.com/modelcontextprotocol/go-sdk v1.4.0 h1:u0kr8lbJc1oBcawK7Df+/ajNMpIDFE41OEPxdeTLOn8= +github.com/modelcontextprotocol/go-sdk v1.4.0/go.mod h1:Nxc2n+n/GdCebUaqCOhTetptS17SXXNu9IfNTaLDi1E= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -869,8 +869,8 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= -golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -902,8 +902,8 @@ golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -922,8 +922,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE index 508be9266..5791499cb 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE +++ b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE @@ -1,6 +1,193 @@ +The MCP project is undergoing a licensing transition from the MIT License to the Apache License, Version 2.0 ("Apache-2.0"). All new code and specification contributions to the project are licensed under Apache-2.0. Documentation contributions (excluding specifications) are licensed under CC-BY-4.0. + +Contributions for which relicensing consent has been obtained are licensed under Apache-2.0. Contributions made by authors who originally licensed their work under the MIT License and who have not yet granted explicit permission to relicense remain licensed under the MIT License. + +No rights beyond those granted by the applicable original license are conveyed for such contributions. + +--- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright + owner or by an individual or Legal Entity authorized to submit on behalf + of the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + +--- + MIT License -Copyright (c) 2025 Go MCP SDK Authors +Copyright (c) 2024-2025 Model Context Protocol a Series of LF Projects, LLC. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -19,3 +206,11 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- + +Creative Commons Attribution 4.0 International (CC-BY-4.0) + +Documentation in this project (excluding specifications) is licensed under +CC-BY-4.0. See https://creativecommons.org/licenses/by/4.0/legalcode for +the full license text. diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go index 87665121c..36ff259e9 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go @@ -25,8 +25,7 @@ type TokenInfo struct { // session hijacking by ensuring that all requests for a given session // come from the same user. UserID string - // TODO: add standard JWT fields - Extra map[string]any + Extra map[string]any } // The error that a TokenVerifier should return if the token cannot be verified. @@ -106,6 +105,9 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO } return nil, err.Error(), http.StatusInternalServerError } + if tokenInfo == nil { + return nil, "token validation failed", http.StatusInternalServerError + } // Check scopes. All must be present. if opts != nil { diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/authorization_code.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/authorization_code.go new file mode 100644 index 000000000..2a6ed32b7 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/authorization_code.go @@ -0,0 +1,548 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// ClientSecretAuthConfig is used to configure client authentication using client_secret. +// Authentication method will be selected based on the authorization server's supported methods, +// according to the following preference order: +// 1. client_secret_post +// 2. client_secret_basic +type ClientSecretAuthConfig struct { + // ClientID is the client ID to be used for client authentication. + ClientID string + // ClientSecret is the client secret to be used for client authentication. + ClientSecret string +} + +// ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document +// based client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents. +// See https://client.dev/ for more information. +type ClientIDMetadataDocumentConfig struct { + // URL is the client identifier URL as per + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-client-id-metadata-document-00#section-3. + URL string +} + +// PreregisteredClientConfig is used to configure a pre-registered client per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration. +// Currently only "client_secret_basic" and "client_secret_post" authentication methods are supported. +type PreregisteredClientConfig struct { + // ClientSecretAuthConfig is the client_secret based configuration to be used for client authentication. + ClientSecretAuthConfig *ClientSecretAuthConfig +} + +// DynamicClientRegistrationConfig is used to configure dynamic client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration. +type DynamicClientRegistrationConfig struct { + // Metadata to be used in dynamic client registration request as per + // https://datatracker.ietf.org/doc/html/rfc7591#section-2. + Metadata *oauthex.ClientRegistrationMetadata +} + +// AuthorizationResult is the result of an authorization flow. +// It is returned by [AuthorizationCodeHandler].AuthorizationCodeFetcher implementations. +type AuthorizationResult struct { + // Code is the authorization code obtained from the authorization server. + Code string + // State string returned by the authorization server. + State string +} + +// AuthorizationArgs is the input to [AuthorizationCodeHandlerConfig].AuthorizationCodeFetcher. +type AuthorizationArgs struct { + // Authorization URL to be opened in a browser for the user to start the authorization process. + URL string +} + +// AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeHandler]. +type AuthorizationCodeHandlerConfig struct { + // Client registration configuration. + // It is attempted in the following order: + // 1. Client ID Metadata Document + // 2. Preregistration + // 3. Dynamic Client Registration + // At least one method must be configured. + ClientIDMetadataDocumentConfig *ClientIDMetadataDocumentConfig + PreregisteredClientConfig *PreregisteredClientConfig + DynamicClientRegistrationConfig *DynamicClientRegistrationConfig + + // RedirectURL is a required URL to redirect to after authorization. + // The caller is responsible for handling the redirect out of band. + // + // If Dynamic Client Registration is used: + // - this field is permitted to be empty, in which case it will be set + // to the first redirect URI from + // DynamicClientRegistrationConfig.Metadata.RedirectURIs. + // - if the field is not empty, it must be one of the redirect URIs in + // DynamicClientRegistrationConfig.Metadata.RedirectURIs. + RedirectURL string + + // AuthorizationCodeFetcher is a required function called to initiate the authorization flow. + // It is responsible for opening the URL in a browser for the user to start the authorization process. + // It should return the authorization code and state once the Authorization Server + // redirects back to the RedirectURL. + AuthorizationCodeFetcher func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) +} + +// AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses +// the authorization code flow to obtain access tokens. +type AuthorizationCodeHandler struct { + config *AuthorizationCodeHandlerConfig + + // tokenSource is the token source to use for authorization. + tokenSource oauth2.TokenSource +} + +var _ OAuthHandler = (*AuthorizationCodeHandler)(nil) + +func (h *AuthorizationCodeHandler) isOAuthHandler() {} + +func (h *AuthorizationCodeHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h.tokenSource, nil +} + +// NewAuthorizationCodeHandler creates a new AuthorizationCodeHandler. +// It performs validation of the configuration and returns an error if it is invalid. +// The passed config is consumed by the handler and should not be modified after. +func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*AuthorizationCodeHandler, error) { + if config == nil { + return nil, errors.New("config must be provided") + } + if config.ClientIDMetadataDocumentConfig == nil && + config.PreregisteredClientConfig == nil && + config.DynamicClientRegistrationConfig == nil { + return nil, errors.New("at least one client registration configuration must be provided") + } + if config.AuthorizationCodeFetcher == nil { + return nil, errors.New("AuthorizationCodeFetcher is required") + } + if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { + return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") + } + preCfg := config.PreregisteredClientConfig + if preCfg != nil { + if preCfg.ClientSecretAuthConfig == nil { + return nil, errors.New("ClientSecretAuthConfig is required for pre-registered client") + } + if preCfg.ClientSecretAuthConfig.ClientID == "" || preCfg.ClientSecretAuthConfig.ClientSecret == "" { + return nil, fmt.Errorf("pre-registered client ID or secret is empty") + } + } + dCfg := config.DynamicClientRegistrationConfig + if dCfg != nil { + if dCfg.Metadata == nil { + return nil, errors.New("Metadata is required for dynamic client registration") + } + if len(dCfg.Metadata.RedirectURIs) == 0 { + return nil, errors.New("Metadata.RedirectURIs is required for dynamic client registration") + } + if config.RedirectURL == "" { + config.RedirectURL = dCfg.Metadata.RedirectURIs[0] + } else if !slices.Contains(dCfg.Metadata.RedirectURIs, config.RedirectURL) { + return nil, fmt.Errorf("RedirectURL %q is not in the list of allowed redirect URIs for dynamic client registration", config.RedirectURL) + } + } + if config.RedirectURL == "" { + // If the RedirectURL was supposed to be set by the dynamic client registration, + // it should have been set by now. Otherwise, it is required. + return nil, errors.New("RedirectURL is required") + } + return &AuthorizationCodeHandler{config: config}, nil +} + +func isNonRootHTTPSURL(u string) bool { + pu, err := url.Parse(u) + if err != nil { + return false + } + return pu.Scheme == "https" && pu.Path != "" +} + +// Authorize performs the authorization flow. +// It is designed to perform the whole Authorization Code Grant flow. +// On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. +func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + defer resp.Body.Close() + + wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + if err != nil { + return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) + } + + if resp.StatusCode == http.StatusForbidden && errorFromChallenges(wwwChallenges) != "insufficient_scope" { + // We only want to perform step-up authorization for insufficient_scope errors. + // Returning nil, so that the call is retried immediately and the response + // is handled appropriately by the connection. + // Step-up authorization is defined at + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#step-up-authorization-flow + return nil + } + + prm, err := h.getProtectedResourceMetadata(ctx, wwwChallenges, req.URL.String()) + if err != nil { + return err + } + + asm, err := h.getAuthServerMetadata(ctx, prm) + if err != nil { + return err + } + + resolvedClientConfig, err := h.handleRegistration(ctx, asm) + if err != nil { + return err + } + + scps := scopesFromChallenges(wwwChallenges) + if len(scps) == 0 && len(prm.ScopesSupported) > 0 { + scps = prm.ScopesSupported + } + + cfg := &oauth2.Config{ + ClientID: resolvedClientConfig.clientID, + ClientSecret: resolvedClientConfig.clientSecret, + + Endpoint: oauth2.Endpoint{ + AuthURL: asm.AuthorizationEndpoint, + TokenURL: asm.TokenEndpoint, + AuthStyle: resolvedClientConfig.authStyle, + }, + RedirectURL: h.config.RedirectURL, + Scopes: scps, + } + + authRes, err := h.getAuthorizationCode(ctx, cfg, req.URL.String()) + if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. + return err + } + + return h.exchangeAuthorizationCode(ctx, cfg, authRes, prm.Resource) +} + +// resourceMetadataURLFromChallenges returns a resource metadata URL from the given "WWW-Authenticate" header challenges, +// or the empty string if there is none. +func resourceMetadataURLFromChallenges(cs []oauthex.Challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// scopesFromChallenges returns the scopes from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. +func scopesFromChallenges(cs []oauthex.Challenge) []string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["scope"] != "" { + return strings.Fields(c.Params["scope"]) + } + } + return nil +} + +// errorFromChallenges returns the error from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. +func errorFromChallenges(cs []oauthex.Challenge) string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["error"] != "" { + return c.Params["error"] + } + } + return "" +} + +// getProtectedResourceMetadata returns the protected resource metadata. +// If no metadata was found or the fetched metadata fails security checks, +// it returns an error. +func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Context, wwwChallenges []oauthex.Challenge, mcpServerURL string) (*oauthex.ProtectedResourceMetadata, error) { + var errs []error + // Use MCP server URL as the resource URI per + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. + for _, url := range protectedResourceMetadataURLs(resourceMetadataURLFromChallenges(wwwChallenges), mcpServerURL) { + prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, http.DefaultClient) + if err != nil { + errs = append(errs, err) + continue + } + if prm == nil { + errs = append(errs, fmt.Errorf("protected resource metadata is nil")) + continue + } + return prm, nil + } + return nil, fmt.Errorf("failed to get protected resource metadata: %v", errors.Join(errs...)) +} + +type prmURL struct { + // URL represents a URL where Protected Resource Metadata may be retrieved. + URL string + // Resource represents the corresponding resource URL for [URL]. + // It is required to perform validation described in RFC 9728, section 3.3. + Resource string +} + +// protectedResourceMetadataURLs returns a list of URLs to try when looking for +// protected resource metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#protected-resource-metadata-discovery-requirements +func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL { + var urls []prmURL + if metadataURL != "" { + urls = append(urls, prmURL{ + URL: metadataURL, + Resource: resourceURL, + }) + } + ru, err := url.Parse(resourceURL) + if err != nil { + return urls + } + mu := *ru + // "At the path of the server's MCP endpoint". + mu.Path = "/.well-known/oauth-protected-resource/" + strings.TrimLeft(ru.Path, "/") + urls = append(urls, prmURL{ + URL: mu.String(), + Resource: resourceURL, + }) + // "At the root". + mu.Path = "/.well-known/oauth-protected-resource" + ru.Path = "" + urls = append(urls, prmURL{ + URL: mu.String(), + Resource: ru.String(), + }) + return urls +} + +// getAuthServerMetadata returns the authorization server metadata. +// The provided Protected Resource Metadata must not be nil. +// It returns an error if the metadata request fails with non-4xx HTTP status code +// or the fetched metadata fails security checks. +// If no metadata was found, it returns a minimal set of endpoints +// as a fallback to 2025-03-26 spec. +func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata) (*oauthex.AuthServerMeta, error) { + var authServerURL string + if len(prm.AuthorizationServers) > 0 { + // Use the first authorization server, similarly to other SDKs. + authServerURL = prm.AuthorizationServers[0] + } else { + // Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server. + authURL, err := url.Parse(prm.Resource) + if err != nil { + return nil, fmt.Errorf("failed to parse resource URL: %v", err) + } + authURL.Path = "" + authServerURL = authURL.String() + } + + for _, u := range authorizationServerMetadataURLs(authServerURL) { + asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) + } + if asm != nil { + return asm, nil + } + } + + // Fallback to 2025-03-26 spec: predefined endpoints. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery + asm := &oauthex.AuthServerMeta{ + Issuer: authServerURL, + AuthorizationEndpoint: authServerURL + "/authorize", + TokenEndpoint: authServerURL + "/token", + RegistrationEndpoint: authServerURL + "/register", + } + return asm, nil +} + +// authorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +func authorizationServerMetadataURLs(issuerURL string) []string { + var urls []string + + baseURL, err := url.Parse(issuerURL) + if err != nil { + return nil + } + + if baseURL.Path == "" { + // "OAuth 2.0 Authorization Server Metadata". + baseURL.Path = "/.well-known/oauth-authorization-server" + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0". + baseURL.Path = "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls + } + + originalPath := baseURL.Path + // "OAuth 2.0 Authorization Server Metadata with path insertion". + baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path insertion". + baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path appending". + baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls +} + +type registrationType int + +const ( + registrationTypeClientIDMetadataDocument registrationType = iota + registrationTypePreregistered + registrationTypeDynamic +) + +type resolvedClientConfig struct { + registrationType registrationType + clientID string + clientSecret string + authStyle oauth2.AuthStyle +} + +func selectTokenAuthMethod(supported []string) oauth2.AuthStyle { + prefOrder := []string{ + // Preferred in OAuth 2.1 draft: https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-14.html#name-client-secret. + "client_secret_post", + "client_secret_basic", + } + for _, method := range prefOrder { + if slices.Contains(supported, method) { + return authMethodToStyle(method) + } + } + return oauth2.AuthStyleAutoDetect +} + +func authMethodToStyle(method string) oauth2.AuthStyle { + switch method { + case "client_secret_post": + return oauth2.AuthStyleInParams + case "client_secret_basic": + return oauth2.AuthStyleInHeader + case "none": + // "none" is equivalent to "client_secret_post" but without sending client secret. + return oauth2.AuthStyleInParams + default: + // "client_secret_basic" is the default per https://datatracker.ietf.org/doc/html/rfc7591#section-2. + return oauth2.AuthStyleInHeader + } +} + +// handleRegistration handles client registration. +// The provided authorization server metadata must be non-nil. +// Support for different registration methods is defined as follows: +// - Client ID Metadata Document: metadata must have +// `ClientIDMetadataDocumentSupported` set to true. +// - Pre-registered client: assumed to be supported. +// - Dynamic client registration: metadata must have +// `RegistrationEndpoint` set to a non-empty value. +func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) (*resolvedClientConfig, error) { + // 1. Attempt to use Client ID Metadata Document (SEP-991). + cimdCfg := h.config.ClientIDMetadataDocumentConfig + if cimdCfg != nil && asm.ClientIDMetadataDocumentSupported { + return &resolvedClientConfig{ + registrationType: registrationTypeClientIDMetadataDocument, + clientID: cimdCfg.URL, + }, nil + } + // 2. Attempt to use pre-registered client configuration. + pCfg := h.config.PreregisteredClientConfig + if pCfg != nil { + authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported) + return &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: pCfg.ClientSecretAuthConfig.ClientID, + clientSecret: pCfg.ClientSecretAuthConfig.ClientSecret, + authStyle: authStyle, + }, nil + } + // 3. Attempt to use dynamic client registration. + dcrCfg := h.config.DynamicClientRegistrationConfig + if dcrCfg != nil && asm.RegistrationEndpoint != "" { + regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + cfg := &resolvedClientConfig{ + registrationType: registrationTypeDynamic, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + authStyle: authMethodToStyle(regResp.TokenEndpointAuthMethod), + } + return cfg, nil + } + return nil, fmt.Errorf("no configured client registration methods are supported by the authorization server") +} + +type authResult struct { + *AuthorizationResult + // usedCodeVerifier is the PKCE code verifier used to obtain the authorization code. + // It is preserved for the token exchange step. + usedCodeVerifier string +} + +// getAuthorizationCode uses the [AuthorizationCodeHandler.AuthorizationCodeFetcher] +// to obtain an authorization code. +func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { + codeVerifier := oauth2.GenerateVerifier() + state := rand.Text() + + authURL := cfg.AuthCodeURL(state, + oauth2.S256ChallengeOption(codeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + ) + + authRes, err := h.config.AuthorizationCodeFetcher(ctx, &AuthorizationArgs{URL: authURL}) + if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. + return nil, err + } + if authRes.State != state { + return nil, fmt.Errorf("state mismatch") + } + return &authResult{ + AuthorizationResult: authRes, + usedCodeVerifier: codeVerifier, + }, nil +} + +// exchangeAuthorizationCode exchanges the authorization code for a token +// and stores it in a token source. +func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { + opts := []oauth2.AuthCodeOption{ + oauth2.VerifierOption(authResult.usedCodeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + } + token, err := cfg.Exchange(ctx, authResult.Code, opts...) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + h.tokenSource = cfg.TokenSource(ctx, token) + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go index acadc51be..0af6963fc 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go @@ -2,122 +2,41 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -//go:build mcp_go_client_oauth - package auth import ( - "bytes" - "errors" - "io" + "context" "net/http" - "sync" "golang.org/x/oauth2" ) -// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization -// is approved, or an error if not. -// The handler receives the HTTP request and response that triggered the authentication flow. -// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. -type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) - -// HTTPTransport is an [http.RoundTripper] that follows the MCP -// OAuth protocol when it encounters a 401 Unauthorized response. -type HTTPTransport struct { - handler OAuthHandler - mu sync.Mutex // protects opts.Base - opts HTTPTransportOptions -} - -// NewHTTPTransport returns a new [*HTTPTransport]. -// The handler is invoked when an HTTP request results in a 401 Unauthorized status. -// It is called only once per transport. Once a TokenSource is obtained, it is used -// for the lifetime of the transport; subsequent 401s are not processed. -func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { - if handler == nil { - return nil, errors.New("handler cannot be nil") - } - t := &HTTPTransport{ - handler: handler, - } - if opts != nil { - t.opts = *opts - } - if t.opts.Base == nil { - t.opts.Base = http.DefaultTransport - } - return t, nil -} - -// HTTPTransportOptions are options to [NewHTTPTransport]. -type HTTPTransportOptions struct { - // Base is the [http.RoundTripper] to use. - // If nil, [http.DefaultTransport] is used. - Base http.RoundTripper -} - -func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - t.mu.Lock() - base := t.opts.Base - t.mu.Unlock() - - var ( - // If haveBody is set, the request has a nontrivial body, and we need avoid - // reading (or closing) it multiple times. In that case, bodyBytes is its - // content. - haveBody bool - bodyBytes []byte - ) - if req.Body != nil && req.Body != http.NoBody { - // if we're setting Body, we must mutate first. - req = req.Clone(req.Context()) - haveBody = true - var err error - bodyBytes, err = io.ReadAll(req.Body) - if err != nil { - return nil, err - } - // Now that we've read the request body, http.RoundTripper requires that we - // close it. - req.Body.Close() // ignore error - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - resp, err := base.RoundTrip(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusUnauthorized { - return resp, nil - } - if _, ok := base.(*oauth2.Transport); ok { - // We failed to authorize even with a token source; give up. - return resp, nil - } - - resp.Body.Close() - // Try to authorize. - t.mu.Lock() - defer t.mu.Unlock() - // If we don't have a token source, get one by following the OAuth flow. - // (We may have obtained one while t.mu was not held above.) - // TODO: We hold the lock for the entire OAuth flow. This could be a long - // time. Is there a better way? - if _, ok := t.opts.Base.(*oauth2.Transport); !ok { - ts, err := t.handler(req, resp) - if err != nil { - return nil, err - } - t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} - } - - // If we don't have a body, the request is reusable, though it will be cloned - // by the base. However, if we've had to read the body, we must clone. - if haveBody { - req = req.Clone(req.Context()) - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - return t.opts.Base.RoundTrip(req) +// OAuthHandler is an interface for handling OAuth flows. +// +// If a transport wishes to support OAuth 2 authorization, it should support +// being configured with an OAuthHandler. It should call the handler's +// TokenSource method whenever it sends an HTTP request to set the +// Authorization header. If a request fails with a 401 or 403, it should call +// Authorize, and if that returns nil, it should retry the request. It should +// not call Authorize after the second failure. See +// [github.com/modelcontextprotocol/go-sdk/mcp.StreamableClientTransport] +// for an example. +type OAuthHandler interface { + isOAuthHandler() + + // TokenSource returns a token source to be used for outgoing requests. + // Returned token source might be nil. In that case, the transport will not + // add any authorization headers to the request. + TokenSource(context.Context) (oauth2.TokenSource, error) + + // Authorize is called when an HTTP request results in an error that may + // be addressed by the authorization flow (currently 401 Unauthorized and 403 Forbidden). + // It is responsible for performing the OAuth flow to obtain an access token. + // The arguments are the request that failed and the response that was received for it. + // The headers of the request are available, but the body will have already been consumed + // when Authorize is called. + // If the returned error is nil, TokenSource is expected to return a non-nil token source. + // After a successful call to Authorize, the HTTP request will be retried by the transport. + // The function is responsible for closing the response body. + Authorize(context.Context, *http.Request, *http.Response) error } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/client_private.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client_private.go new file mode 100644 index 000000000..767c59eea --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client_private.go @@ -0,0 +1,135 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "bytes" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/oauth2" +) + +// An OAuthHandlerLegacy conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +type HTTPTransport struct { + handler OAuthHandlerLegacy + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + ts, err := t.handler(req, resp) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/json/json.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/json/json.go index f06609cb3..1148770e3 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/internal/json/json.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/json/json.go @@ -1,6 +1,6 @@ // Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. // Package json provides internal JSON utilities. diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go index 46fcc9db9..72527cb9c 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go @@ -163,8 +163,8 @@ func (r *headerReader) Read(ctx context.Context) (Message, error) { return nil, fmt.Errorf("invalid header line %q", line) } name, value := line[:colon], strings.TrimSpace(line[colon+1:]) - switch name { - case "Content-Length": + switch { + case strings.EqualFold(name, "Content-Length"): if contentLength, err = strconv.ParseInt(value, 10, 32); err != nil { return nil, fmt.Errorf("failed parsing Content-Length: %v", value) } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go index ae5b64099..b424780eb 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go @@ -5,11 +5,14 @@ package jsonrpc2 import ( + "bytes" "encoding/json" "errors" "fmt" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" + + "github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" ) // ID is a Request identifier, which is defined by the spec to be a string, integer, or null. @@ -147,9 +150,9 @@ func toWireError(err error) *WireError { func EncodeMessage(msg Message) ([]byte, error) { wire := wireCombined{VersionTag: wireVersion} msg.marshal(&wire) - data, err := json.Marshal(&wire) + data, err := jsonMarshal(&wire) if err != nil { - return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + return nil, fmt.Errorf("marshaling jsonrpc message: %w", err) } return data, nil } @@ -160,11 +163,14 @@ func EncodeMessage(msg Message) ([]byte, error) { func EncodeIndent(msg Message, prefix, indent string) ([]byte, error) { wire := wireCombined{VersionTag: wireVersion} msg.marshal(&wire) - data, err := json.MarshalIndent(&wire, prefix, indent) - if err != nil { - return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + enc.SetIndent(prefix, indent) + if err := enc.Encode(&wire); err != nil { + return nil, fmt.Errorf("marshaling jsonrpc message: %w", err) } - return data, nil + return bytes.TrimRight(buf.Bytes(), "\n"), nil } func DecodeMessage(data []byte) (Message, error) { @@ -206,9 +212,31 @@ func marshalToRaw(obj any) (json.RawMessage, error) { if obj == nil { return nil, nil } - data, err := json.Marshal(obj) + data, err := jsonMarshal(obj) if err != nil { return nil, err } return json.RawMessage(data), nil } + +// jsonescaping is a compatibility parameter that allows to restore +// JSON escaping in the JSON marshaling, which stopped being the default +// in the 1.4.0 version of the SDK. See the documentation for the +// mcpgodebug package for instructions how to enable it. +// The option will be removed in the 1.6.0 version of the SDK. +var jsonescaping = mcpgodebug.Value("jsonescaping") + +// jsonMarshal marshals obj to JSON like json.Marshal but without HTML escaping. +func jsonMarshal(obj any) ([]byte, error) { + if jsonescaping == "1" { + return json.Marshal(obj) + } + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + if err := enc.Encode(obj); err != nil { + return nil, err + } + // json.Encoder.Encode adds a trailing newline. Trim it to be consistent with json.Marshal. + return bytes.TrimRight(buf.Bytes(), "\n"), nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug/mcpgodebug.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug/mcpgodebug.go new file mode 100644 index 000000000..7f8f7ca35 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug/mcpgodebug.go @@ -0,0 +1,52 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +// Package mcpgodebug provides a mechanism to configure compatibility parameters +// via the MCPGODEBUG environment variable. +// +// The value of MCPGODEBUG is a comma-separated list of key=value pairs. +// For example: +// +// MCPGODEBUG=someoption=1,otheroption=value +package mcpgodebug + +import ( + "fmt" + "os" + "strings" +) + +const compatibilityEnvKey = "MCPGODEBUG" + +var compatibilityParams map[string]string + +func init() { + var err error + compatibilityParams, err = parseCompatibility(os.Getenv(compatibilityEnvKey)) + if err != nil { + panic(err) + } +} + +// Value returns the value of the compatibility parameter with the given key. +// It returns an empty string if the key is not set. +func Value(key string) string { + return compatibilityParams[key] +} + +func parseCompatibility(envValue string) (map[string]string, error) { + if envValue == "" { + return nil, nil + } + + params := make(map[string]string) + for part := range strings.SplitSeq(envValue, ",") { + k, v, ok := strings.Cut(part, "=") + if !ok { + return nil, fmt.Errorf("MCPGODEBUG: invalid format: %q", part) + } + params[strings.TrimSpace(k)] = strings.TrimSpace(v) + } + return params, nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/net.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/net.go new file mode 100644 index 000000000..6858614eb --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/net.go @@ -0,0 +1,26 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. +package util + +import ( + "net" + "net/netip" + "strings" +) + +func IsLoopback(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + // If SplitHostPort fails, it might be just a host without a port. + host = strings.Trim(addr, "[]") + } + if host == "localhost" { + return true + } + ip, err := netip.ParseAddr(host) + if err != nil { + return false + } + return ip.IsLoopback() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go index 63ffa0af7..74900b1c7 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go @@ -51,6 +51,9 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { } options = nil // prevent reuse + if opts.CreateMessageHandler != nil && opts.CreateMessageWithToolsHandler != nil { + panic("cannot set both CreateMessageHandler and CreateMessageWithToolsHandler; use CreateMessageWithToolsHandler for tool support, or CreateMessageHandler for basic sampling") + } if opts.Logger == nil { // ensure we have a logger opts.Logger = ensureLogger(nil) } @@ -76,6 +79,19 @@ type ClientOptions struct { // non nil value for [ClientCapabilities.Sampling], that value overrides the // inferred capability. CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // CreateMessageWithToolsHandler handles incoming sampling/createMessage + // requests that may involve tool use. It returns + // [CreateMessageWithToolsResult], which supports array content for parallel + // tool calls. + // + // Setting this handler causes the client to advertise the sampling + // capability with tools support (sampling.tools). As with + // [CreateMessageHandler], [ClientOptions.Capabilities].Sampling overrides + // the inferred capability. + // + // It is a panic to set both CreateMessageHandler and + // CreateMessageWithToolsHandler. + CreateMessageWithToolsHandler func(context.Context, *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) // ElicitationHandler handles incoming requests for elicitation/create. // // Setting ElicitationHandler to a non-nil value automatically causes the @@ -108,7 +124,16 @@ type ClientOptions struct { // are set in the Capabilities field, their values override the inferred // value. // - // For example, to to configure elicitation modes: + // For example, to advertise sampling with tools and context support: + // + // Capabilities: &ClientCapabilities{ + // Sampling: &SamplingCapabilities{ + // Tools: &SamplingToolsCapabilities{}, + // Context: &SamplingContextCapabilities{}, + // }, + // } + // + // Or to configure elicitation modes: // // Capabilities: &ClientCapabilities{ // Elicitation: &ElicitationCapabilities{ @@ -118,8 +143,7 @@ type ClientOptions struct { // } // // Conversely, if Capabilities does not set a field (for example, if the - // Elicitation field is nil), the inferred elicitation capability will be - // used. + // Elicitation field is nil), the inferred capability will be used. Capabilities *ClientCapabilities // ElicitationCompleteHandler handles incoming notifications for notifications/elicitation/complete. ElicitationCompleteHandler func(context.Context, *ElicitationCompleteNotificationRequest) @@ -197,10 +221,13 @@ func (c *Client) capabilities(protocolVersion string) *ClientCapabilities { caps.Roots = *caps.RootsV2 } - // Augment with sampling capability if handler is set. - if c.opts.CreateMessageHandler != nil { + // Augment with sampling capability if a handler is set. + if c.opts.CreateMessageHandler != nil || c.opts.CreateMessageWithToolsHandler != nil { if caps.Sampling == nil { caps.Sampling = &SamplingCapabilities{} + if c.opts.CreateMessageWithToolsHandler != nil { + caps.Sampling.Tools = &SamplingToolsCapabilities{} + } } } @@ -452,12 +479,27 @@ func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRoots }, nil } -func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { - if c.opts.CreateMessageHandler == nil { - // TODO: wrap or annotate this error? Pick a standard code? - return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} +func (c *Client) createMessage(ctx context.Context, req *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { + if c.opts.CreateMessageWithToolsHandler != nil { + return c.opts.CreateMessageWithToolsHandler(ctx, req) + } + if c.opts.CreateMessageHandler != nil { + // Downconvert the request for the basic handler. + baseParams, err := req.Params.toBase() + if err != nil { + return nil, err + } + baseReq := &CreateMessageRequest{ + Session: req.Session, + Params: baseParams, + } + res, err := c.opts.CreateMessageHandler(ctx, baseReq) + if err != nil { + return nil, err + } + return res.toWithTools(), nil } - return c.opts.CreateMessageHandler(ctx, req) + return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} } // urlElicitationMiddleware returns middleware that automatically handles URL elicitation @@ -589,7 +631,7 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, return nil, err } // Validate elicitation result content against requested schema. - if schema != nil && res.Content != nil { + if res.Action == "accept" && schema != nil && res.Content != nil { resolved, err := schema.Resolve(nil) if err != nil { return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to resolve requested schema: %v", err)} @@ -667,8 +709,10 @@ func validateElicitProperty(propName string, propSchema *jsonschema.Schema) erro return validateElicitNumberProperty(propName, propSchema) case "boolean": return validateElicitBooleanProperty(propName, propSchema) + case "array": + return validateElicitArrayProperty(propName, propSchema) default: - return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, and boolean are allowed", propName, propSchema.Type) + return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, boolean, and array are allowed", propName, propSchema.Type) } } @@ -681,7 +725,7 @@ func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema return fmt.Errorf("elicit schema property %q has enum values but type is %q, enums are only supported for string type", propName, propSchema.Type) } // Enum values themselves are validated by the JSON schema library - // Validate enumNames if present - must match enum length + // Validate legacy enumNames if present - must match enum length. if propSchema.Extra != nil { if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists { // Type check enumNames - should be a slice @@ -696,6 +740,15 @@ func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema } return nil } + // Handle new style of titled enums. + if propSchema.OneOf != nil { + for _, entry := range propSchema.OneOf { + if err := validateTitledEnumEntry(entry); err != nil { + return fmt.Errorf("elicit schema property %q oneOf has invalid entry: %v", propName, err) + } + } + return nil + } // Validate format if specified - only specific formats are allowed if propSchema.Format != "" { @@ -748,6 +801,53 @@ func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema return nil } +// validateElicitArrayProperty validates multi-select enum properties. +func validateElicitArrayProperty(propName string, propSchema *jsonschema.Schema) error { + if propSchema.Items == nil { + return fmt.Errorf("elicit schema property %q is array but missing 'items' definition", propName) + } + + items := propSchema.Items + switch items.Type { + case "string": + // Untitled enums. + if items.Enum == nil { + return fmt.Errorf("elicit schema property %q items must specify enum for untitled enums", propName) + } + return nil + case "": + // Titled enums. + if len(items.AnyOf) == 0 { + return fmt.Errorf("elicit schema property %q items must specify anyOf for titled enums", propName) + } + for _, entry := range items.AnyOf { + if err := validateTitledEnumEntry(entry); err != nil { + return fmt.Errorf("elicit schema property %q items has invalid entry: %v", propName, err) + } + } + return nil + default: + return fmt.Errorf("elicit schema property %q items have unsupported type %q", propName, items.Type) + } +} + +func validateTitledEnumEntry(entry *jsonschema.Schema) error { + if entry.Const == nil { + return fmt.Errorf("const is required for titled enum entries") + } + constVal, ok := (*entry.Const).(string) + if !ok { + return fmt.Errorf("const must be a string for titled enum entries") + } + if constVal == "" { + return fmt.Errorf("const cannot be empty for titled enum entries") + } + if entry.Title == "" { + return fmt.Errorf("title is required for titled enum entries") + } + return nil +} + // validateElicitBooleanProperty validates boolean-type properties. func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error { return validateDefaultProperty[bool](propName, propSchema) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go index fb1a0d1e5..95ea40d80 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go @@ -9,12 +9,16 @@ package mcp import ( "encoding/json" - "errors" "fmt" + + internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" ) // A Content is a [TextContent], [ImageContent], [AudioContent], -// [ResourceLink], or [EmbeddedResource]. +// [ResourceLink], [EmbeddedResource], [ToolUseContent], or [ToolResultContent]. +// +// Note: [ToolUseContent] and [ToolResultContent] are only valid in sampling +// message contexts (CreateMessageParams/CreateMessageResult). type Content interface { MarshalJSON() ([]byte, error) fromWire(*wireContent) @@ -183,69 +187,165 @@ func (c *EmbeddedResource) fromWire(wire *wireContent) { c.Annotations = wire.Annotations } -// ResourceContents contains the contents of a specific resource or -// sub-resource. -type ResourceContents struct { - URI string `json:"uri"` - MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text,omitempty"` - Blob []byte `json:"blob,omitempty"` - Meta Meta `json:"_meta,omitempty"` +// ToolUseContent represents a request from the assistant to invoke a tool. +// This content type is only valid in sampling messages. +type ToolUseContent struct { + // ID is a unique identifier for this tool use, used to match with ToolResultContent. + ID string + // Name is the name of the tool to invoke. + Name string + // Input contains the tool arguments as a JSON object. + Input map[string]any + Meta Meta } -func (r *ResourceContents) MarshalJSON() ([]byte, error) { - // If we could assume Go 1.24, we could use omitzero for Blob and avoid this method. - if r.URI == "" { - return nil, errors.New("ResourceContents missing URI") +func (c *ToolUseContent) MarshalJSON() ([]byte, error) { + input := c.Input + if input == nil { + input = map[string]any{} + } + wire := struct { + Type string `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + Input map[string]any `json:"input"` + Meta Meta `json:"_meta,omitempty"` + }{ + Type: "tool_use", + ID: c.ID, + Name: c.Name, + Input: input, + Meta: c.Meta, } - if r.Blob == nil { - // Text. Marshal normally. - type wireResourceContents ResourceContents // (lacks MarshalJSON method) - return json.Marshal((wireResourceContents)(*r)) + return json.Marshal(wire) +} + +func (c *ToolUseContent) fromWire(wire *wireContent) { + c.ID = wire.ID + c.Name = wire.Name + c.Input = wire.Input + c.Meta = wire.Meta +} + +// ToolResultContent represents the result of a tool invocation. +// This content type is only valid in sampling messages with role "user". +type ToolResultContent struct { + // ToolUseID references the ID from the corresponding ToolUseContent. + ToolUseID string + // Content holds the unstructured result of the tool call. + Content []Content + // StructuredContent holds an optional structured result as a JSON object. + StructuredContent any + // IsError indicates whether the tool call ended in an error. + IsError bool + Meta Meta +} + +func (c *ToolResultContent) MarshalJSON() ([]byte, error) { + // Marshal nested content + var contentWire []*wireContent + for _, content := range c.Content { + data, err := content.MarshalJSON() + if err != nil { + return nil, err + } + var w wireContent + if err := internaljson.Unmarshal(data, &w); err != nil { + return nil, err + } + contentWire = append(contentWire, &w) } - // Blob. - if r.Text != "" { - return nil, errors.New("ResourceContents has non-zero Text and Blob fields") + if contentWire == nil { + contentWire = []*wireContent{} // avoid JSON null } - // r.Blob may be the empty slice, so marshal with an alternative definition. - br := struct { - URI string `json:"uri,omitempty"` - MIMEType string `json:"mimeType,omitempty"` - Blob []byte `json:"blob"` - Meta Meta `json:"_meta,omitempty"` + + wire := struct { + Type string `json:"type"` + ToolUseID string `json:"toolUseId"` + Content []*wireContent `json:"content"` + StructuredContent any `json:"structuredContent,omitempty"` + IsError bool `json:"isError,omitempty"` + Meta Meta `json:"_meta,omitempty"` }{ - URI: r.URI, - MIMEType: r.MIMEType, - Blob: r.Blob, - Meta: r.Meta, + Type: "tool_result", + ToolUseID: c.ToolUseID, + Content: contentWire, + StructuredContent: c.StructuredContent, + IsError: c.IsError, + Meta: c.Meta, } - return json.Marshal(br) + return json.Marshal(wire) +} + +func (c *ToolResultContent) fromWire(wire *wireContent) { + c.ToolUseID = wire.ToolUseID + c.StructuredContent = wire.StructuredContent + c.IsError = wire.IsError + c.Meta = wire.Meta + // Content is handled separately in contentFromWire due to nested content +} + +// ResourceContents contains the contents of a specific resource or +// sub-resource. +type ResourceContents struct { + URI string `json:"uri"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitzero"` + Meta Meta `json:"_meta,omitempty"` } // wireContent is the wire format for content. // It represents the protocol types TextContent, ImageContent, AudioContent, -// ResourceLink, and EmbeddedResource. +// ResourceLink, EmbeddedResource, ToolUseContent, and ToolResultContent. // The Type field distinguishes them. In the protocol, each type has a constant // value for the field. -// At most one of Text, Data, Resource, and URI is non-zero. type wireContent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - MIMEType string `json:"mimeType,omitempty"` - Data []byte `json:"data,omitempty"` - Resource *ResourceContents `json:"resource,omitempty"` - URI string `json:"uri,omitempty"` - Name string `json:"name,omitempty"` - Title string `json:"title,omitempty"` - Description string `json:"description,omitempty"` - Size *int64 `json:"size,omitempty"` - Meta Meta `json:"_meta,omitempty"` - Annotations *Annotations `json:"annotations,omitempty"` - Icons []Icon `json:"icons,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` // TextContent + MIMEType string `json:"mimeType,omitempty"` // ImageContent, AudioContent, ResourceLink + Data []byte `json:"data,omitempty"` // ImageContent, AudioContent + Resource *ResourceContents `json:"resource,omitempty"` // EmbeddedResource + URI string `json:"uri,omitempty"` // ResourceLink + Name string `json:"name,omitempty"` // ResourceLink, ToolUseContent + Title string `json:"title,omitempty"` // ResourceLink + Description string `json:"description,omitempty"` // ResourceLink + Size *int64 `json:"size,omitempty"` // ResourceLink + Meta Meta `json:"_meta,omitempty"` // all types + Annotations *Annotations `json:"annotations,omitempty"` // all types except ToolUseContent, ToolResultContent + Icons []Icon `json:"icons,omitempty"` // ResourceLink + ID string `json:"id,omitempty"` // ToolUseContent + Input map[string]any `json:"input,omitempty"` // ToolUseContent + ToolUseID string `json:"toolUseId,omitempty"` // ToolResultContent + NestedContent []*wireContent `json:"content,omitempty"` // ToolResultContent + StructuredContent any `json:"structuredContent,omitempty"` // ToolResultContent + IsError bool `json:"isError,omitempty"` // ToolResultContent +} + +// unmarshalContent unmarshals JSON that is either a single content object or +// an array of content objects. A single object is wrapped in a one-element slice. +func unmarshalContent(raw json.RawMessage, allow map[string]bool) ([]Content, error) { + if len(raw) == 0 || string(raw) == "null" { + return nil, fmt.Errorf("nil content") + } + // Try array first, then fall back to single object. + var wires []*wireContent + if err := internaljson.Unmarshal(raw, &wires); err == nil { + return contentsFromWire(wires, allow) + } + var wire wireContent + if err := internaljson.Unmarshal(raw, &wire); err != nil { + return nil, err + } + c, err := contentFromWire(&wire, allow) + if err != nil { + return nil, err + } + return []Content{c}, nil } func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) { - var blocks []Content + blocks := make([]Content, 0, len(wires)) for _, wire := range wires { block, err := contentFromWire(wire, allow) if err != nil { @@ -284,6 +384,27 @@ func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) v := new(EmbeddedResource) v.fromWire(wire) return v, nil + case "tool_use": + v := new(ToolUseContent) + v.fromWire(wire) + return v, nil + case "tool_result": + v := new(ToolResultContent) + v.fromWire(wire) + // Handle nested content - tool_result content can contain text, image, audio, + // resource_link, and resource (same as CallToolResult.content) + if wire.NestedContent != nil { + toolResultContentAllow := map[string]bool{ + "text": true, "image": true, "audio": true, + "resource_link": true, "resource": true, + } + nestedContent, err := contentsFromWire(wire.NestedContent, toolResultContentAllow) + if err != nil { + return nil, fmt.Errorf("tool_result nested content: %w", err) + } + v.Content = nestedContent + } + return v, nil } - return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type) + return nil, fmt.Errorf("unrecognized content type %q", wire.Type) } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go index 5c322c4a3..62dd2ad2b 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go @@ -67,9 +67,7 @@ func writeEvent(w io.Writer, evt Event) (int, error) { // TODO(rfindley): consider a different API here that makes failure modes more // apparent. func scanEvents(r io.Reader) iter.Seq2[Event, error] { - scanner := bufio.NewScanner(r) - const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size - scanner.Buffer(nil, maxTokenSize) + reader := bufio.NewReader(r) // TODO: investigate proper behavior when events are out of order, or have // non-standard names. @@ -94,31 +92,43 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { evt Event dataBuf *bytes.Buffer // if non-nil, preceding field was also data ) - flushData := func() { + yieldEvent := func() bool { if dataBuf != nil { evt.Data = dataBuf.Bytes() dataBuf = nil } + if evt.Empty() { + return true + } + if !yield(evt, nil) { + return false + } + evt = Event{} + return true } - for scanner.Scan() { - line := scanner.Bytes() + for { + line, err := reader.ReadBytes('\n') + if err != nil && !errors.Is(err, io.EOF) { + yield(Event{}, fmt.Errorf("error reading event: %v", err)) + return + } + line = bytes.TrimRight(line, "\r\n") + isEOF := errors.Is(err, io.EOF) + if len(line) == 0 { - flushData() - // \n\n is the record delimiter - if !evt.Empty() && !yield(evt, nil) { + if !yieldEvent() { + return + } + if isEOF { return } - evt = Event{} continue } before, after, found := bytes.Cut(line, []byte{':'}) if !found { - yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) + yield(Event{}, fmt.Errorf("%w: malformed line in SSE stream: %q", errMalformedEvent, string(line))) return } - if !bytes.Equal(before, dataKey) { - flushData() - } switch { case bytes.Equal(before, eventKey): evt.Name = strings.TrimSpace(string(after)) @@ -128,27 +138,19 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { evt.Retry = strings.TrimSpace(string(after)) case bytes.Equal(before, dataKey): data := bytes.TrimSpace(after) - if dataBuf != nil { - dataBuf.WriteByte('\n') - dataBuf.Write(data) - } else { + if dataBuf == nil { dataBuf = new(bytes.Buffer) - dataBuf.Write(data) + } else { + dataBuf.WriteByte('\n') } + dataBuf.Write(data) } - } - if err := scanner.Err(); err != nil { - if errors.Is(err, bufio.ErrTooLong) { - err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) - } - if !yield(Event{}, err) { + + if isEOF { + yieldEvent() return } } - flushData() - if !evt.Empty() { - yield(evt, nil) - } } } @@ -310,6 +312,11 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, // index is no longer available. var ErrEventsPurged = errors.New("data purged") +// errMalformedEvent is returned when an SSE event cannot be parsed due to format violations. +// This is a hard error indicating corrupted data or protocol violations, as opposed to +// transient I/O errors which may be retryable. +var errMalformedEvent = errors.New("malformed event") + // After implements [EventStore.After]. func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] { // Return the data items to yield. diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go index 96a96b828..b1bd82b15 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go @@ -89,21 +89,12 @@ type LoggingHandler struct { handler slog.Handler } -// discardHandler is a slog.Handler that drops all logs. -// TODO: use slog.DiscardHandler when we require Go 1.24+. -type discardHandler struct{} - -func (discardHandler) Enabled(context.Context, slog.Level) bool { return false } -func (discardHandler) Handle(context.Context, slog.Record) error { return nil } -func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} } -func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} } - // ensureLogger returns l if non-nil, otherwise a discard logger. func ensureLogger(l *slog.Logger) *slog.Logger { if l != nil { return l } - return slog.New(discardHandler{}) + return slog.New(slog.DiscardHandler) } // NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go index 0e07d6706..837ce7843 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go @@ -13,6 +13,7 @@ package mcp import ( "encoding/json" "fmt" + "maps" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" ) @@ -190,12 +191,18 @@ type RootCapabilities struct { // this schema, but this is not a closed set: any client can define its own, // additional capabilities. type ClientCapabilities struct { - // NOTE: any addition to ClientCapabilities must also be reflected in // [ClientCapabilities.clone]. // Experimental reports non-standard capabilities that the client supports. + // The caller should not modify the map after assigning it. Experimental map[string]any `json:"experimental,omitempty"` + // Extensions reports extensions that the client supports. + // Keys are extension identifiers in "{vendor-prefix}/{extension-name}" format. + // Values are per-extension settings objects; use [ClientCapabilities.AddExtension] + // to ensure nil settings are normalized to empty objects. + // The caller should not modify the map or its values after assigning it. + Extensions map[string]any `json:"extensions,omitempty"` // Roots describes the client's support for roots. // // Deprecated: use RootsV2. As described in #607, Roots should have been a @@ -214,11 +221,33 @@ type ClientCapabilities struct { Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"` } -// clone returns a deep copy of the ClientCapabilities. +// AddExtension adds an extension with the given name and settings. +// If settings is nil, an empty map is used to ensure valid JSON serialization +// (the spec requires an object, not null). +// The settings map should not be modified after the call. +func (c *ClientCapabilities) AddExtension(name string, settings map[string]any) { + if c.Extensions == nil { + c.Extensions = make(map[string]any) + } + if settings == nil { + settings = map[string]any{} + } + c.Extensions[name] = settings +} + +// clone returns a copy of the ClientCapabilities. +// Values in the Extensions and Experimental maps are shallow-copied. func (c *ClientCapabilities) clone() *ClientCapabilities { cp := *c + cp.Experimental = maps.Clone(c.Experimental) + cp.Extensions = maps.Clone(c.Extensions) cp.RootsV2 = shallowClone(c.RootsV2) - cp.Sampling = shallowClone(c.Sampling) + if c.Sampling != nil { + x := *c.Sampling + x.Tools = shallowClone(c.Sampling.Tools) + x.Context = shallowClone(c.Sampling.Context) + cp.Sampling = &x + } if c.Elicitation != nil { x := *c.Elicitation x.Form = shallowClone(c.Elicitation.Form) @@ -359,6 +388,11 @@ type CreateMessageParams struct { Meta `json:"_meta,omitempty"` // A request to include context from one or more MCP servers (including the // caller), to be attached to the prompt. The client may ignore this request. + // + // The default is "none". Values "thisServer" and + // "allServers" are soft-deprecated. Servers SHOULD only use these values if + // the client declares ClientCapabilities.sampling.context. These values may + // be removed in future spec releases. IncludeContext string `json:"includeContext,omitempty"` // The maximum number of tokens to sample, as requested by the server. The // client may choose to sample fewer tokens than requested. @@ -381,6 +415,106 @@ func (x *CreateMessageParams) isParams() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } +// CreateMessageWithToolsParams is a sampling request that includes tools. +// It extends the basic [CreateMessageParams] fields with tools, tool choice, +// and messages that support array content (for parallel tool calls). +// +// Use with [ServerSession.CreateMessageWithTools]. +type CreateMessageWithToolsParams struct { + Meta `json:"_meta,omitempty"` + IncludeContext string `json:"includeContext,omitempty"` + MaxTokens int64 `json:"maxTokens"` + // Messages supports array content for tool_use and tool_result blocks. + Messages []*SamplingMessageV2 `json:"messages"` + Metadata any `json:"metadata,omitempty"` + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + // Tools is the list of tools available for the model to use. + Tools []*Tool `json:"tools,omitempty"` + // ToolChoice controls how the model should use tools. + ToolChoice *ToolChoice `json:"toolChoice,omitempty"` +} + +func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// toBase converts to CreateMessageParams by taking the content block from each +// message. Tools and ToolChoice are dropped. Returns an error if any message +// has multiple content blocks, since SamplingMessage only supports one. +func (p *CreateMessageWithToolsParams) toBase() (*CreateMessageParams, error) { + var msgs []*SamplingMessage + for _, m := range p.Messages { + if len(m.Content) > 1 { + return nil, fmt.Errorf("message has %d content blocks; use CreateMessageWithToolsHandler to support multiple content", len(m.Content)) + } + var content Content + if len(m.Content) > 0 { + content = m.Content[0] + } + msgs = append(msgs, &SamplingMessage{Content: content, Role: m.Role}) + } + return &CreateMessageParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + }, nil +} + +// SamplingMessageV2 describes a message issued to or received from an +// LLM API, supporting array content for parallel tool calls. The "V2" refers +// to the 2025-11-25 spec, which changed content from a single block to +// single-or-array. In v2 of the SDK, this will replace [SamplingMessage]. +// +// When marshaling, a single-element Content slice is marshaled as a single +// object for compatibility with pre-2025-11-25 implementations. When +// unmarshaling, a single JSON content object is accepted and wrapped in a +// one-element slice. +type SamplingMessageV2 struct { + Content []Content `json:"content"` + Role Role `json:"role"` +} + +var samplingWithToolsAllow = map[string]bool{ + "text": true, "image": true, "audio": true, + "tool_use": true, "tool_result": true, +} + +// MarshalJSON marshals the message. A single-element Content slice is marshaled +// as a single object for backward compatibility. +func (m *SamplingMessageV2) MarshalJSON() ([]byte, error) { + if len(m.Content) == 1 { + return json.Marshal(&SamplingMessage{Content: m.Content[0], Role: m.Role}) + } + type msg SamplingMessageV2 // avoid recursion + return json.Marshal((*msg)(m)) +} + +func (m *SamplingMessageV2) UnmarshalJSON(data []byte) error { + type msg SamplingMessageV2 // avoid recursion + var wire struct { + msg + Content json.RawMessage `json:"content"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.msg.Content, err = unmarshalContent(wire.Content, samplingWithToolsAllow); err != nil { + return err + } + *m = SamplingMessageV2(wire.msg) + return nil +} + // The client's response to a sampling/create_message request from the server. // The client should inform the user before returning the sampled message, to // allow them to inspect the response (human in the loop) and decide whether to @@ -394,6 +528,12 @@ type CreateMessageResult struct { Model string `json:"model"` Role Role `json:"role"` // The reason why sampling stopped, if known. + // + // Standard values: + // - "endTurn": natural end of the assistant's turn + // - "stopSequence": a stop sequence was encountered + // - "maxTokens": reached the maximum token limit + // - "toolUse": the model wants to use one or more tools StopReason string `json:"stopReason,omitempty"` } @@ -415,6 +555,84 @@ func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { return nil } +// CreateMessageWithToolsResult is the client's response to a +// sampling/create_message request that included tools. Content is a slice to +// support parallel tool calls (multiple tool_use blocks in one response). +// +// Use [ServerSession.CreateMessageWithTools] to send a sampling request with +// tools and receive this result type. +// +// When unmarshaling, a single JSON content object is accepted and wrapped in a +// one-element slice, for compatibility with clients that return a single block. +type CreateMessageWithToolsResult struct { + Meta `json:"_meta,omitempty"` + Content []Content `json:"content"` + Model string `json:"model"` + Role Role `json:"role"` + // The reason why sampling stopped. + // + // Standard values: "endTurn", "stopSequence", "maxTokens", "toolUse". + StopReason string `json:"stopReason,omitempty"` +} + +// createMessageWithToolsResultAllow lists content types valid in assistant responses. +// tool_result is excluded: it only appears in user messages. +var createMessageWithToolsResultAllow = map[string]bool{ + "text": true, "image": true, "audio": true, + "tool_use": true, +} + +func (*CreateMessageWithToolsResult) isResult() {} + +// MarshalJSON marshals the result. When Content has a single element, it is +// marshaled as a single object for compatibility with pre-2025-11-25 +// implementations that expect a single content block. +func (r *CreateMessageWithToolsResult) MarshalJSON() ([]byte, error) { + if len(r.Content) == 1 { + return json.Marshal(&CreateMessageResult{ + Meta: r.Meta, + Content: r.Content[0], + Model: r.Model, + Role: r.Role, + StopReason: r.StopReason, + }) + } + type result CreateMessageWithToolsResult // avoid recursion + return json.Marshal((*result)(r)) +} + +func (r *CreateMessageWithToolsResult) UnmarshalJSON(data []byte) error { + type result CreateMessageWithToolsResult // avoid recursion + var wire struct { + result + Content json.RawMessage `json:"content"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = unmarshalContent(wire.Content, createMessageWithToolsResultAllow); err != nil { + return err + } + *r = CreateMessageWithToolsResult(wire.result) + return nil +} + +// toWithTools converts a CreateMessageResult to CreateMessageWithToolsResult. +func (r *CreateMessageResult) toWithTools() *CreateMessageWithToolsResult { + var content []Content + if r.Content != nil { + content = []Content{r.Content} + } + return &CreateMessageWithToolsResult{ + Meta: r.Meta, + Content: content, + Model: r.Model, + Role: r.Role, + StopReason: r.StopReason, + } +} + type GetPromptParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -984,25 +1202,46 @@ func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t // below directly above ClientCapabilities. // SamplingCapabilities describes the client's support for sampling. -type SamplingCapabilities struct{} +type SamplingCapabilities struct { + // Context indicates the client supports includeContext values other than "none". + Context *SamplingContextCapabilities `json:"context,omitempty"` + // Tools indicates the client supports tools and toolChoice in sampling requests. + Tools *SamplingToolsCapabilities `json:"tools,omitempty"` +} + +// SamplingContextCapabilities indicates the client supports context inclusion. +type SamplingContextCapabilities struct{} + +// SamplingToolsCapabilities indicates the client supports tool use in sampling. +type SamplingToolsCapabilities struct{} + +// ToolChoice controls how the model uses tools during sampling. +type ToolChoice struct { + // Mode controls tool invocation behavior: + // - "auto": Model decides whether to use tools (default) + // - "required": Model must use at least one tool + // - "none": Model must not use any tools + Mode string `json:"mode,omitempty"` +} // ElicitationCapabilities describes the capabilities for elicitation. // // If neither Form nor URL is set, the 'Form' capabilitiy is assumed. type ElicitationCapabilities struct { - Form *FormElicitationCapabilities - URL *URLElicitationCapabilities + Form *FormElicitationCapabilities `json:"form,omitempty"` + URL *URLElicitationCapabilities `json:"url,omitempty"` } // FormElicitationCapabilities describes capabilities for form elicitation. -type FormElicitationCapabilities struct { -} +type FormElicitationCapabilities struct{} // URLElicitationCapabilities describes capabilities for url elicitation. -type URLElicitationCapabilities struct { -} +type URLElicitationCapabilities struct{} // Describes a message issued to or received from an LLM API. +// +// For assistant messages, Content may be text, image, audio, or tool_use. +// For user messages, Content may be text, image, audio, or tool_result. type SamplingMessage struct { Content Content `json:"content"` Role Role `json:"role"` @@ -1019,8 +1258,9 @@ func (m *SamplingMessage) UnmarshalJSON(data []byte) error { if err := internaljson.Unmarshal(data, &wire); err != nil { return err } + // Allow text, image, audio, tool_use, and tool_result in sampling messages var err error - if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true, "tool_result": true}); err != nil { return err } *m = SamplingMessage(wire.msg) @@ -1299,12 +1539,18 @@ type ToolCapabilities struct { // ServerCapabilities describes capabilities that a server supports. type ServerCapabilities struct { - // NOTE: any addition to ServerCapabilities must also be reflected in // [ServerCapabilities.clone]. // Experimental reports non-standard capabilities that the server supports. + // The caller should not modify the map after assigning it. Experimental map[string]any `json:"experimental,omitempty"` + // Extensions reports extensions that the server supports. + // Keys are extension identifiers in "{vendor-prefix}/{extension-name}" format. + // Values are per-extension settings objects; use [ServerCapabilities.AddExtension] + // to ensure nil settings are normalized to empty objects. + // The caller should not modify the map or its values after assigning it. + Extensions map[string]any `json:"extensions,omitempty"` // Completions is present if the server supports argument autocompletion // suggestions. Completions *CompletionCapabilities `json:"completions,omitempty"` @@ -1318,9 +1564,26 @@ type ServerCapabilities struct { Tools *ToolCapabilities `json:"tools,omitempty"` } -// clone returns a deep copy of the ServerCapabilities. +// AddExtension adds an extension with the given name and settings. +// If settings is nil, an empty map is used to ensure valid JSON serialization +// (the spec requires an object, not null). +// The settings map should not be modified after the call. +func (c *ServerCapabilities) AddExtension(name string, settings map[string]any) { + if c.Extensions == nil { + c.Extensions = make(map[string]any) + } + if settings == nil { + settings = map[string]any{} + } + c.Extensions[name] = settings +} + +// clone returns a copy of the ServerCapabilities. +// Values in the Extensions and Experimental maps are shallow-copied. func (c *ServerCapabilities) clone() *ServerCapabilities { cp := *c + cp.Experimental = maps.Clone(c.Experimental) + cp.Extensions = maps.Clone(c.Extensions) cp.Completions = shallowClone(c.Completions) cp.Logging = shallowClone(c.Logging) cp.Prompts = shallowClone(c.Prompts) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go index f64d6fb62..428094136 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go @@ -24,6 +24,7 @@ type ( type ( CreateMessageRequest = ClientRequest[*CreateMessageParams] + CreateMessageWithToolsRequest = ClientRequest[*CreateMessageWithToolsParams] ElicitRequest = ClientRequest[*ElicitParams] initializedClientRequest = ClientRequest[*InitializedParams] InitializeRequest = ClientRequest[*InitializeParams] diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go index dc657f5dd..bc4b3cb1f 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go @@ -113,6 +113,23 @@ func computeURIFilepath(rawURI, dirFilepath string, rootFilepaths []string) (str return uriFilepathRel, nil } +// withFile calls f on the file at join(dir, rel), +// protecting against path traversal attacks. +func withFile(dir, rel string, f func(*os.File) error) (err error) { + r, err := os.OpenRoot(dir) + if err != nil { + return err + } + defer r.Close() + file, err := r.Open(rel) + if err != nil { + return err + } + // Record error, in case f writes. + defer func() { err = errors.Join(err, file.Close()) }() + return f(file) +} + // fileRoots transforms the Roots obtained from the client into absolute paths on // the local filesystem. // TODO(jba): expose this functionality to user ResourceHandlers, diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go deleted file mode 100644 index 4a35603c6..000000000 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -//go:build go1.24 - -package mcp - -import ( - "errors" - "os" -) - -// withFile calls f on the file at join(dir, rel), -// protecting against path traversal attacks. -func withFile(dir, rel string, f func(*os.File) error) (err error) { - r, err := os.OpenRoot(dir) - if err != nil { - return err - } - defer r.Close() - file, err := r.Open(rel) - if err != nil { - return err - } - // Record error, in case f writes. - defer func() { err = errors.Join(err, file.Close()) }() - return f(file) -} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go deleted file mode 100644 index d1f72eedc..000000000 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -//go:build !go1.24 - -package mcp - -import ( - "errors" - "os" - "path/filepath" -) - -// withFile calls f on the file at join(dir, rel). -// It does not protect against path traversal attacks. -func withFile(dir, rel string, f func(*os.File) error) (err error) { - file, err := os.Open(filepath.Join(dir, rel)) - if err != nil { - return err - } - // Record error, in case f writes. - defer func() { err = errors.Join(err, file.Close()) }() - return f(file) -} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go index cbed5b116..e3c03e278 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go @@ -7,6 +7,7 @@ package mcp import ( "bytes" "context" + "crypto/rand" "encoding/base64" "encoding/gob" "encoding/json" @@ -176,7 +177,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { } if opts.GetSessionID == nil { - opts.GetSessionID = randText + opts.GetSessionID = rand.Text } if opts.Logger == nil { // ensure we have a logger @@ -1163,6 +1164,10 @@ func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) } // CreateMessage sends a sampling request to the client. +// +// If the client returns multiple content blocks (e.g. parallel tool calls), +// CreateMessage returns an error. Use [ServerSession.CreateMessageWithTools] +// for tool-enabled sampling. func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { if err := ss.checkInitialized(methodCreateMessage); err != nil { return nil, err @@ -1175,7 +1180,44 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag p2.Messages = []*SamplingMessage{} // avoid JSON "null" params = &p2 } - return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) + res, err := handleSend[*CreateMessageWithToolsResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) + if err != nil { + return nil, err + } + // Downconvert to singular content. + if len(res.Content) > 1 { + return nil, fmt.Errorf("CreateMessage result has %d content blocks; use CreateMessageWithTools for multiple content", len(res.Content)) + } + var content Content + if len(res.Content) > 0 { + content = res.Content[0] + } + return &CreateMessageResult{ + Meta: res.Meta, + Content: content, + Model: res.Model, + Role: res.Role, + StopReason: res.StopReason, + }, nil +} + +// CreateMessageWithTools sends a sampling request with tools to the client, +// returning a [CreateMessageWithToolsResult] that supports array content +// (for parallel tool calls). Use this instead of [ServerSession.CreateMessage] +// when the request includes tools. +func (ss *ServerSession) CreateMessageWithTools(ctx context.Context, params *CreateMessageWithToolsParams) (*CreateMessageWithToolsResult, error) { + if err := ss.checkInitialized(methodCreateMessage); err != nil { + return nil, err + } + if params == nil { + params = &CreateMessageWithToolsParams{Messages: []*SamplingMessageV2{}} + } + if params.Messages == nil { + p2 := *params + p2.Messages = []*SamplingMessageV2{} // avoid JSON "null" + params = &p2 + } + return handleSend[*CreateMessageWithToolsResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) } // Elicit sends an elicitation request to the client asking for user input. @@ -1219,6 +1261,10 @@ func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*Eli return nil, err } + if res.Action != "accept" { + return res, nil + } + if params.RequestedSchema == nil { return res, nil } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go index ae65c16cb..e57dad102 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go @@ -7,6 +7,7 @@ package mcp import ( "bytes" "context" + "crypto/rand" "fmt" "io" "net/http" @@ -216,7 +217,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - sessionID = randText() + sessionID = rand.Text() endpoint, err := req.URL.Parse("?sessionid=" + sessionID) if err != nil { http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go index 1fdf97334..0b11eff00 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go @@ -11,6 +11,7 @@ package mcp import ( "bytes" "context" + crand "crypto/rand" "encoding/json" "errors" "fmt" @@ -19,17 +20,19 @@ import ( "maps" "math" "math/rand/v2" + "net" "net/http" "slices" "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/modelcontextprotocol/go-sdk/auth" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" + "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -161,6 +164,16 @@ type StreamableHTTPOptions struct { // // If SessionTimeout is the zero value, idle sessions are never closed. SessionTimeout time.Duration + + // DisableLocalhostProtection disables automatic DNS rebinding protection. + // By default, requests arriving via a localhost address (127.0.0.1, [::1]) + // that have a non-localhost Host header are rejected with 403 Forbidden. + // This protects against DNS rebinding attacks regardless of whether the + // server is listening on localhost specifically or on 0.0.0.0. + // + // Only disable this if you understand the security implications. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + DisableLocalhostProtection bool } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -207,7 +220,24 @@ func (h *StreamableHTTPHandler) closeAll() { } } +// disablelocalhostprotection is a compatibility parameter that allows to disable +// DNS rebinding protection, which was added in the 1.4.0 version of the SDK. +// See the documentation for the mcpgodebug package for instructions how to enable it. +// The option will be removed in the 1.6.0 version of the SDK. +var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection") + func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // DNS rebinding protection: auto-enabled for localhost servers. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + if !h.opts.DisableLocalhostProtection && disablelocalhostprotection != "1" { + if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok && localAddr != nil { + if util.IsLoopback(localAddr.String()) && !util.IsLoopback(req.Host) { + http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden) + return + } + } + } + // Allow multiple 'Accept' headers. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",") @@ -374,7 +404,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // stateless servers. body, err := io.ReadAll(req.Body) if err != nil { - http.Error(w, "failed to read body", http.StatusInternalServerError) + http.Error(w, "failed to read body", http.StatusBadRequest) return } req.Body.Close() @@ -1141,7 +1171,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // Important: don't publish the incoming messages until the stream is // registered, as the server may attempt to respond to imcoming messages as // soon as they're published. - stream, err := c.newStream(req.Context(), calls, randText()) + stream, err := c.newStream(req.Context(), calls, crand.Text()) if err != nil { http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) return @@ -1425,6 +1455,9 @@ type StreamableClientTransport struct { // - You want to avoid maintaining a persistent connection DisableStandaloneSSE bool + // OAuthHandler is an optional field that, if provided, will be used to authorize the requests. + OAuthHandler auth.OAuthHandler + // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation // of the MCP spec causes a failure. @@ -1500,6 +1533,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er cancel: cancel, failed: make(chan struct{}), disableStandaloneSSE: t.DisableStandaloneSSE, + oauthHandler: t.OAuthHandler, } return conn, nil } @@ -1518,6 +1552,9 @@ type streamableClientConn struct { // for receiving server-to-client notifications when no request is in flight. disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE] + // oauthHandler is the OAuth handler for the connection. + oauthHandler auth.OAuthHandler // from [StreamableClientTransport.OAuthHandler] + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error @@ -1540,17 +1577,6 @@ type streamableClientConn struct { sessionID string } -// errSessionMissing distinguishes if the session is known to not be present on -// the server (see [streamableClientConn.fail]). -// -// TODO(rfindley): should we expose this error value (and its corresponding -// API) to the user? -// -// The spec says that if the server returns 404, clients should reestablish -// a session. For now, we delegate that to the user, but do they need a way to -// differentiate a 'NotFound' error from other errors? -var errSessionMissing = errors.New("session not found") - var _ clientConnection = (*streamableClientConn)(nil) func (c *streamableClientConn) sessionUpdated(state clientSessionState) { @@ -1629,7 +1655,7 @@ func (c *streamableClientConn) connectStandaloneSSE() { // If err is non-nil, it is terminal, and subsequent (or pending) Reads will // fail. // -// If err wraps errSessionMissing, the failure indicates that the session is no +// If err wraps ErrSessionMissing, the failure indicates that the session is no // longer present on the server, and no final DELETE will be performed when // closing the connection. func (c *streamableClientConn) fail(err error) { @@ -1698,20 +1724,46 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("%s: %v", requestSummary, err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) + doRequest := func() (*http.Request, *http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) + if err != nil { + return nil, nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + if err := c.setMCPHeaders(req); err != nil { + // Failure to set headers means that the request was not sent. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + return nil, nil, fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + } + resp, err := c.client.Do(req) + if err != nil { + // Any error from client.Do means the request didn't reach the server. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + err = fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + } + return req, resp, err + } + + req, resp, err := doRequest() if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - c.setMCPHeaders(req) - resp, err := c.client.Do(req) - if err != nil { - // Any error from client.Do means the request didn't reach the server. - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil { + if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + // Wrap the authorization error as well for client inspection. + return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) + } + // Retry the request after successful authorization. + _, resp, err = doRequest() + if err != nil { + return err + } } if err := c.checkResponse(requestSummary, resp); err != nil { @@ -1779,23 +1831,32 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } -// testAuth controls whether a fake Authorization header is added to outgoing requests. -// TODO: replace with a better mechanism when client-side auth is in place. -var testAuth atomic.Bool - -func (c *streamableClientConn) setMCPHeaders(req *http.Request) { +func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { c.mu.Lock() defer c.mu.Unlock() + if c.oauthHandler != nil { + ts, err := c.oauthHandler.TokenSource(c.ctx) + if err != nil { + return err + } + if ts != nil { + token, err := ts.Token() + if err != nil { + return err + } + if token != nil { + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + } + } + } if c.initializedResult != nil { req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) } if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } - if testAuth.Load() { - req.Header.Set("Authorization", "Bearer foo") - } + return nil } func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { @@ -1824,15 +1885,14 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp // stream is complete when we receive its response. Otherwise, this is the // standalone stream. func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc2.Request) { + // Track the last event ID to detect progress. + // The retry counter is only reset when progress is made (lastEventID advances). + // This prevents infinite retry loops when a server repeatedly terminates + // connections without making progress (#679). + var prevLastEventID string + retriesWithoutProgress := 0 + for { - // Connection was successful. Continue the loop with the new response. - // - // TODO(#679): we should set a reasonable limit on the number of times - // we'll try getting a response for a given request, or enforce that we - // actually make progress. - // - // Eventually, if we don't get the response, we should stop trying and - // fail the request. lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall) // If the connection was closed by the client, we're done. @@ -1846,6 +1906,23 @@ func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary str return } + // Check if we made progress (lastEventID advanced). + // Only reset the retry counter when actual progress is made. + if lastEventID != "" && lastEventID != prevLastEventID { + // Progress was made: reset the retry counter. + retriesWithoutProgress = 0 + prevLastEventID = lastEventID + } else { + // No progress: increment the retry counter. + retriesWithoutProgress++ + if retriesWithoutProgress > c.maxRetries { + if ctx.Err() == nil { + c.fail(fmt.Errorf("%s: exceeded %d retries without progress (session ID: %v)", requestSummary, c.maxRetries, c.sessionID)) + } + return + } + } + // The stream was interrupted or ended by the server. Attempt to reconnect. newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false) if err != nil { @@ -1880,9 +1957,9 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // which it MUST respond to requests containing that session ID with HTTP // 404 Not Found." if resp.StatusCode == http.StatusNotFound { - // Return an errSessionMissing to avoid sending a redundant DELETE when the + // Return an ErrSessionMissing to avoid sending a redundant DELETE when the // session is already gone. - return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) + return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, ErrSessionMissing) } // Transient server errors (502, 503, 504, 429) should not break the connection. // Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr. @@ -1910,6 +1987,14 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary if ctx.Err() != nil { return "", 0, true // don't reconnect: client cancelled } + + // Malformed events are hard errors that indicate corrupted data or protocol + // violations. These should fail the connection permanently. + if errors.Is(err, errMalformedEvent) { + c.fail(fmt.Errorf("%s: %v", requestSummary, err)) + return "", 0, true + } + break } @@ -1922,6 +2007,15 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary reconnectDelay = time.Duration(n) * time.Millisecond } } + + // According to SSE specification + // (https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation) + // events with an empty data buffer are allowed. + // In MCP these can be priming events (SEP-1699) that carry only a Last-Event-ID for stream resumption. + if len(evt.Data) == 0 { + continue + } + // According to SSE spec, events with no name default to "message" if evt.Name != "" && evt.Name != "message" { continue @@ -2015,7 +2109,9 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin if err != nil { return nil, err } - c.setMCPHeaders(req) + if err := c.setMCPHeaders(req); err != nil { + return nil, err + } if lastEventID != "" { req.Header.Set(lastEventIDHeader, lastEventID) } @@ -2039,15 +2135,16 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin // Close implements the [Connection] interface. func (c *streamableClientConn) Close() error { c.closeOnce.Do(func() { - if errors.Is(c.failure(), errSessionMissing) { + if errors.Is(c.failure(), ErrSessionMissing) { // If the session is missing, no need to delete it. } else { req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil) if err != nil { c.closeErr = err } else { - c.setMCPHeaders(req) - if _, err := c.client.Do(req); err != nil { + if err := c.setMCPHeaders(req); err != nil { + c.closeErr = err + } else if _, err := c.client.Do(req); err != nil { c.closeErr = err } } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go index 41a100461..c2cc25b8a 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go @@ -161,7 +161,7 @@ The client must handle two response formats from POST requests: - DELETE: Terminate the session - Used by [streamableClientConn.Close] - - Skipped if session is already known to be gone ([errSessionMissing]) + - Skipped if session is already known to be gone ([ErrSessionMissing]) # Error Handling @@ -173,7 +173,7 @@ Errors are categorized and handled differently: - Triggers reconnection in [streamableClientConn.handleSSE] 2. Terminal (breaks the connection): - - 404 Not Found: Session terminated by server ([errSessionMissing]) + - 404 Not Found: Session terminated by server ([ErrSessionMissing]) - Message decode errors: Protocol violation - Context cancellation: Client closed connection - Mismatched session IDs: Protocol error @@ -183,7 +183,7 @@ Terminal errors are stored via [streamableClientConn.fail] and returned by subsequent [streamableClientConn.Read] calls. The [streamableClientConn.failed] channel signals that the connection is broken. -Special case: [errSessionMissing] indicates the server has terminated the session, +Special case: [ErrSessionMissing] indicates the server has terminated the session, so [streamableClientConn.Close] skips the DELETE request. # Protocol Version Header diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go index 585df87ff..5f2a50072 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go @@ -25,6 +25,10 @@ import ( // is closed or in the process of closing. var ErrConnectionClosed = errors.New("connection closed") +// ErrSessionMissing is returned when the session is known to not be present on +// the server. +var ErrSessionMissing = errors.New("session not found") + // A Transport is used to create a bidirectional connection between MCP client // and server. // diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go index be3f3c7c0..8ffaa74ef 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go @@ -5,7 +5,6 @@ package mcp import ( - "crypto/rand" "encoding/json" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" @@ -17,20 +16,6 @@ func assert(cond bool, msg string) { } } -// Copied from crypto/rand. -// TODO: once 1.24 is assured, just use crypto/rand. -const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" - -func randText() string { - // ⌈log₃₂ 2¹²⁸⌉ = 26 chars - src := make([]byte, 26) - rand.Read(src) - for i := range src { - src[i] = base32alphabet[src[i]%32] - } - return string(src) -} - // remarshal marshals from to JSON, and then unmarshals into to, which must be // a pointer type. func remarshal(from, to any) error { diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go index 9aa0c8d7d..b05d80b6d 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go @@ -14,6 +14,9 @@ import ( "errors" "fmt" "net/http" + "net/url" + + "github.com/modelcontextprotocol/go-sdk/internal/util" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -28,8 +31,6 @@ import ( // // [RFC 8414]: https://tools.ietf.org/html/rfc8414) type AuthServerMeta struct { - // GENERATED BY GEMINI 2.5. - // Issuer is the REQUIRED URL identifying the authorization server. Issuer string `json:"issuer"` @@ -113,51 +114,61 @@ type AuthServerMeta struct { // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of // PKCE code challenge methods supported by this authorization server. CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` -} -var wellKnownPaths = []string{ - "/.well-known/oauth-authorization-server", - "/.well-known/openid-configuration", + // ClientIDMetadataDocumentSupported is a boolean indicating whether the authorization server + // supports client ID metadata documents. + ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` } // GetAuthServerMeta issues a GET request to retrieve authorization server metadata -// from an OAuth authorization server with the given issuerURL. +// from an OAuth authorization server with the given metadataURL. // // It follows [RFC 8414]: -// - The well-known paths specified there are inserted into the URL's path, one at time. -// The first to succeed is used. -// - The Issuer field is checked against issuerURL. +// - The metadataURL must use HTTPS or be a local address. +// - The Issuer field is checked against metadataURL.Issuer. +// +// It also verifies that the authorization server supports PKCE and that the URLs +// in the metadata don't use dangerous schemes. +// +// It returns an error if the request fails with a non-4xx status code or the fetched +// metadata doesn't pass security validations. +// It returns nil if the request fails with a 4xx status code. // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 -func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { - var errs []error - for _, p := range wellKnownPaths { - u, err := prependToPath(issuerURL, p) - if err != nil { - // issuerURL is bad; no point in continuing. - return nil, err - } - asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) - if err == nil { - if asm.Issuer != issuerURL { // section 3.3 - // Security violation; don't keep trying. - return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) - } - - if len(asm.CodeChallengeMethodsSupported) == 0 { - return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) +func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http.Client) (*AuthServerMeta, error) { + u, err := url.Parse(metadataURL) + if err != nil { + return nil, err + } + // Only allow HTTP for local addresses (testing or development purposes). + if !util.IsLoopback(u.Host) && u.Scheme != "https" { + return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) + } + asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL, 1<<20) + if err != nil { + var httpErr *httpStatusError + if errors.As(err, &httpErr) { + if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { + return nil, nil } + } + return nil, fmt.Errorf("%v", err) // Do not expose error types. + } + if asm.Issuer != issuer { + // Validate the Issuer field (see RFC 8414, section 3.3). + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuer) + } - // Validate endpoint URLs to prevent XSS attacks (see #526). - if err := validateAuthServerMetaURLs(asm); err != nil { - return nil, err - } + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuer) + } - return asm, nil - } - errs = append(errs, err) + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err } - return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) + + return asm, nil } // validateAuthServerMetaURLs validates all URL fields in AuthServerMeta diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go index cdda695b7..836a4201b 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go @@ -19,21 +19,12 @@ import ( "strings" ) -// prependToPath prepends pre to the path of urlStr. -// When pre is the well-known path, this is the algorithm specified in both RFC 9728 -// section 3.1 and RFC 8414 section 3.1. -func prependToPath(urlStr, pre string) (string, error) { - u, err := url.Parse(urlStr) - if err != nil { - return "", err - } - p := "/" + strings.Trim(pre, "/") - if u.Path != "" { - p += "/" - } +type httpStatusError struct { + StatusCode int +} - u.Path = p + strings.TrimLeft(u.Path, "/") - return u.String(), nil +func (e *httpStatusError) Error() string { + return fmt.Sprintf("bad status %d", e.StatusCode) } // getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both @@ -53,11 +44,9 @@ func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64 } defer res.Body.Close() - // Specs require a 200. if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("bad status %s", res.Status) + return nil, &httpStatusError{StatusCode: res.StatusCode} } - // Specs require application/json. ct := res.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(ct) if err != nil || mediaType != "application/json" { diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go index 34ed55b59..151da7e51 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go @@ -4,89 +4,3 @@ // Package oauthex implements extensions to OAuth2. package oauthex - -// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, -// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. -// -// The following features are not supported: -// - additional keys (§2, last sentence) -// - human-readable metadata (§2.1) -// - signed metadata (§2.2) -type ProtectedResourceMetadata struct { - // GENERATED BY GEMINI 2.5. - - // Resource (resource) is the protected resource's resource identifier. - // Required. - Resource string `json:"resource"` - - // AuthorizationServers (authorization_servers) is an optional slice containing a list of - // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be - // used with this protected resource. - AuthorizationServers []string `json:"authorization_servers,omitempty"` - - // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set - // document. This contains public keys belonging to the protected resource, such as - // signing key(s) that the resource server uses to sign resource responses. - JWKSURI string `json:"jwks_uri,omitempty"` - - // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope - // values (as defined in RFC 6749) used in authorization requests to request access - // to this protected resource. - ScopesSupported []string `json:"scopes_supported,omitempty"` - - // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing - // a list of the supported methods of sending an OAuth 2.0 bearer token to the - // protected resource. Defined values are "header", "body", and "query". - BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` - - // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms (alg values) supported by the protected - // resource for signing resource responses. - ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` - - // ResourceName (resource_name) is a human-readable name of the protected resource - // intended for display to the end user. It is RECOMMENDED that this field be included. - // This value may be internationalized. - ResourceName string `json:"resource_name,omitempty"` - - // ResourceDocumentation (resource_documentation) is an optional URL of a page containing - // human-readable information for developers using the protected resource. - // This value may be internationalized. - ResourceDocumentation string `json:"resource_documentation,omitempty"` - - // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing - // human-readable policy information on how a client can use the data provided. - // This value may be internationalized. - ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` - - // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected - // resource's human-readable terms of service. This value may be internationalized. - ResourceTOSURI string `json:"resource_tos_uri,omitempty"` - - // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an - // optional boolean indicating support for mutual-TLS client certificate-bound - // access tokens (RFC 8705). Defaults to false if omitted. - TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` - - // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional - // slice of 'type' values supported by the resource server for the - // 'authorization_details' parameter (RFC 9396). - AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` - - // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms supported by the resource server for validating - // DPoP proof JWTs (RFC 9449). - DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` - - // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean - // specifying whether the protected resource always requires the use of DPoP-bound - // access tokens (RFC 9449). Defaults to false if omitted. - DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` - - // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters - // about the protected resource as claims. If present, these values take precedence - // over values conveyed in plain JSON. - // TODO:implement. - // Note that §2.2 says it's okay to ignore this. - // SignedMetadata string `json:"signed_metadata,omitempty"` -} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go index bb61f7974..8b911cad1 100644 --- a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go @@ -38,6 +38,8 @@ const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resour // // It then retrieves the metadata at that location using the given client (or the // default client if nil) and validates its resource field against resourceID. +// +// Deprecated: Use [GetProtectedResourceMetadata] instead. This function will be removed in v1.5.0. func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) @@ -47,7 +49,7 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, } // Insert well-known URI into URL. u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) - return getPRM(ctx, u.String(), c, resourceID) + return GetProtectedResourceMetadata(ctx, u.String(), resourceID, c) } // GetProtectedResourceMetadataFromHeader retrieves protected resource metadata @@ -57,8 +59,9 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, // Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata // matches the serverURL (the URL that the client used to make the original request to the resource server). // If there is no metadata URL in the header, it returns nil, nil. +// +// Deprecated: Use [GetProtectedResourceMetadata] instead. This function will be removed in v1.5.0. func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { - defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] if len(headers) == 0 { return nil, nil @@ -67,26 +70,49 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL strin if err != nil { return nil, err } - metadataURL := ResourceMetadataURL(cs) + metadataURL := resourceMetadataURL(cs) if metadataURL == "" { return nil, nil } - return getPRM(ctx, metadataURL, c, serverURL) + return GetProtectedResourceMetadata(ctx, metadataURL, serverURL, c) } -// getPRM makes a GET request to the given URL, and validates the response. -// As part of the validation, it compares the returned resource field to wantResource. -func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { - if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { - return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) +// resourceMetadataURL returns a resource metadata URL from the given "WWW-Authenticate" header challenges, +// or the empty string if there is none. +func resourceMetadataURL(cs []Challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server. +// The metadataURL is typically a URL with a host:port and possibly a path. +// The resourceURL is the resource URI the metadataURL is for. +// The following checks are performed: +// - The metadataURL must use HTTPS or be a local address. +// - The resource field of the resulting metadata must match the resourceURL. +// - The authorization_servers field of the resulting metadata is checked for dangerous URL schemes. +func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadata(%q)", metadataURL) + u, err := url.Parse(metadataURL) + if err != nil { + return nil, err + } + // Only allow HTTP for local addresses (testing or development purposes). + if !util.IsLoopback(u.Host) && u.Scheme != "https" { + return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) } - prm, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20) + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, metadataURL, 1<<20) if err != nil { return nil, err } // Validate the Resource field (see RFC 9728, section 3.3). - if prm.Resource != wantResource { - return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + if prm.Resource != resourceURL { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, resourceURL) } // Validate the authorization server URLs to prevent XSS attacks (see #526). for _, u := range prm.AuthorizationServers { @@ -97,37 +123,12 @@ func getPRM(ctx context.Context, purl string, c *http.Client, wantResource strin return prm, nil } -// challenge represents a single authentication challenge from a WWW-Authenticate header. -// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. -type challenge struct { - // GENERATED BY GEMINI 2.5. - // - // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). - // It is case-insensitive. A parsed value will always be lower-case. - Scheme string - // Params is a map of authentication parameters. - // Keys are case-insensitive. Parsed keys are always lower-case. - Params map[string]string -} - -// ResourceMetadataURL returns a resource metadata URL from the given challenges, -// or the empty string if there is none. -func ResourceMetadataURL(cs []challenge) string { - for _, c := range cs { - if u := c.Params["resource_metadata"]; u != "" { - return u - } - } - return "" -} - // ParseWWWAuthenticate parses a WWW-Authenticate header string. // The header format is defined in RFC 9110, Section 11.6.1, and can contain // one or more challenges, separated by commas. // It returns a slice of challenges or an error if one of the headers is malformed. -func ParseWWWAuthenticate(headers []string) ([]challenge, error) { - // GENERATED BY GEMINI 2.5 (human-tweaked) - var challenges []challenge +func ParseWWWAuthenticate(headers []string) ([]Challenge, error) { + var challenges []Challenge for _, h := range headers { challengeStrings, err := splitChallenges(h) if err != nil { @@ -151,7 +152,6 @@ func ParseWWWAuthenticate(headers []string) ([]challenge, error) { // It correctly handles commas within quoted strings and distinguishes between // commas separating auth-params and commas separating challenges. func splitChallenges(header string) ([]string, error) { - // GENERATED BY GEMINI 2.5. var challenges []string inQuotes := false start := 0 @@ -195,15 +195,14 @@ func splitChallenges(header string) ([]string, error) { // parseSingleChallenge parses a string containing exactly one challenge. // challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] -func parseSingleChallenge(s string) (challenge, error) { - // GENERATED BY GEMINI 2.5, human-tweaked. +func parseSingleChallenge(s string) (Challenge, error) { s = strings.TrimSpace(s) if s == "" { - return challenge{}, errors.New("empty challenge string") + return Challenge{}, errors.New("empty challenge string") } scheme, paramsStr, found := strings.Cut(s, " ") - c := challenge{Scheme: strings.ToLower(scheme)} + c := Challenge{Scheme: strings.ToLower(scheme)} if !found { return c, nil } @@ -215,7 +214,7 @@ func parseSingleChallenge(s string) (challenge, error) { // Find the end of the parameter key. keyEnd := strings.Index(paramsStr, "=") if keyEnd <= 0 { - return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + return Challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) } key := strings.TrimSpace(paramsStr[:keyEnd]) @@ -243,7 +242,7 @@ func parseSingleChallenge(s string) (challenge, error) { // A quoted string must be terminated. if i == len(paramsStr) { - return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + return Challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") } value = valBuilder.String() @@ -261,7 +260,7 @@ func parseSingleChallenge(s string) (challenge, error) { } } if value == "" { - return challenge{}, fmt.Errorf("no value for auth param %q", key) + return Challenge{}, fmt.Errorf("no value for auth param %q", key) } // Per RFC 9110, parameter keys are case-insensitive. @@ -272,10 +271,10 @@ func parseSingleChallenge(s string) (challenge, error) { paramsStr = strings.TrimSpace(paramsStr[1:]) } else if paramsStr != "" { // If there's content but it's not a new parameter, the format is wrong. - return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + return Challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) } } // Per RFC 9110, the scheme is case-insensitive. - return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil + return Challenge{Scheme: strings.ToLower(scheme), Params: params}, nil } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta_public.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta_public.go new file mode 100644 index 000000000..3bf7d9aca --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta_public.go @@ -0,0 +1,105 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +// This is a temporary file to expose the required objects to the main package. + +package oauthex + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. + JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. + ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. + // SignedMetadata string `json:"signed_metadata,omitempty"` +} + +// Challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type Challenge struct { + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} diff --git a/vendor/golang.org/x/oauth2/deviceauth.go b/vendor/golang.org/x/oauth2/deviceauth.go index e99c92f39..e783a9437 100644 --- a/vendor/golang.org/x/oauth2/deviceauth.go +++ b/vendor/golang.org/x/oauth2/deviceauth.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "mime" "net/http" "net/url" "strings" @@ -116,10 +117,38 @@ func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAu return nil, fmt.Errorf("oauth2: cannot auth device: %v", err) } if code := r.StatusCode; code < 200 || code > 299 { - return nil, &RetrieveError{ + retrieveError := &RetrieveError{ Response: r, Body: body, } + + content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) + switch content { + case "application/x-www-form-urlencoded", "text/plain": + // some endpoints return a query string + vals, err := url.ParseQuery(string(body)) + if err != nil { + return nil, retrieveError + } + retrieveError.ErrorCode = vals.Get("error") + retrieveError.ErrorDescription = vals.Get("error_description") + retrieveError.ErrorURI = vals.Get("error_uri") + default: + var tj struct { + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` + } + if json.Unmarshal(body, &tj) != nil { + return nil, retrieveError + } + retrieveError.ErrorCode = tj.ErrorCode + retrieveError.ErrorDescription = tj.ErrorDescription + retrieveError.ErrorURI = tj.ErrorURI + } + + return nil, retrieveError } da := &DeviceAuthResponse{} diff --git a/vendor/golang.org/x/oauth2/oauth2.go b/vendor/golang.org/x/oauth2/oauth2.go index 3e3b63069..5c527d31f 100644 --- a/vendor/golang.org/x/oauth2/oauth2.go +++ b/vendor/golang.org/x/oauth2/oauth2.go @@ -98,7 +98,7 @@ const ( // in the POST body as application/x-www-form-urlencoded parameters. AuthStyleInParams AuthStyle = 1 - // AuthStyleInHeader sends the client_id and client_password + // AuthStyleInHeader sends the client_id and client_secret // using HTTP Basic Authorization. This is an optional style // described in the OAuth2 RFC 6749 section 2.3.1. AuthStyleInHeader AuthStyle = 2 diff --git a/vendor/golang.org/x/oauth2/pkce.go b/vendor/golang.org/x/oauth2/pkce.go index cea8374d5..f99384f0f 100644 --- a/vendor/golang.org/x/oauth2/pkce.go +++ b/vendor/golang.org/x/oauth2/pkce.go @@ -51,7 +51,7 @@ func S256ChallengeFromVerifier(verifier string) string { return base64.RawURLEncoding.EncodeToString(sha[:]) } -// S256ChallengeOption derives a PKCE code challenge derived from verifier with +// S256ChallengeOption derives a PKCE code challenge from the verifier with // method S256. It should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] // only. func S256ChallengeOption(verifier string) AuthCodeOption { diff --git a/vendor/golang.org/x/oauth2/token.go b/vendor/golang.org/x/oauth2/token.go index 239ec3296..e995eebb5 100644 --- a/vendor/golang.org/x/oauth2/token.go +++ b/vendor/golang.org/x/oauth2/token.go @@ -103,7 +103,7 @@ func (t *Token) WithExtra(extra any) *Token { } // Extra returns an extra field. -// Extra fields are key-value pairs returned by the server as a +// Extra fields are key-value pairs returned by the server as // part of the token retrieval response. func (t *Token) Extra(key string) any { if raw, ok := t.raw.(map[string]any); ok { diff --git a/vendor/golang.org/x/oauth2/transport.go b/vendor/golang.org/x/oauth2/transport.go index 8bbebbac9..9922ec331 100644 --- a/vendor/golang.org/x/oauth2/transport.go +++ b/vendor/golang.org/x/oauth2/transport.go @@ -58,7 +58,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { var cancelOnce sync.Once // CancelRequest does nothing. It used to be a legacy cancellation mechanism -// but now only it only logs on first use to warn that it's deprecated. +// but now only logs on first use to warn that it's deprecated. // // Deprecated: use contexts for cancellation instead. func (t *Transport) CancelRequest(req *http.Request) { diff --git a/vendor/golang.org/x/sys/cpu/cpu_x86.go b/vendor/golang.org/x/sys/cpu/cpu_x86.go index 1e642f330..f5723d4f7 100644 --- a/vendor/golang.org/x/sys/cpu/cpu_x86.go +++ b/vendor/golang.org/x/sys/cpu/cpu_x86.go @@ -64,6 +64,80 @@ func initOptions() { func archInit() { + // From internal/cpu + const ( + // eax bits + cpuid_AVXVNNI = 1 << 4 + + // ecx bits + cpuid_SSE3 = 1 << 0 + cpuid_PCLMULQDQ = 1 << 1 + cpuid_AVX512VBMI = 1 << 1 + cpuid_AVX512VBMI2 = 1 << 6 + cpuid_SSSE3 = 1 << 9 + cpuid_AVX512GFNI = 1 << 8 + cpuid_AVX512VAES = 1 << 9 + cpuid_AVX512VNNI = 1 << 11 + cpuid_AVX512BITALG = 1 << 12 + cpuid_FMA = 1 << 12 + cpuid_AVX512VPOPCNTDQ = 1 << 14 + cpuid_SSE41 = 1 << 19 + cpuid_SSE42 = 1 << 20 + cpuid_POPCNT = 1 << 23 + cpuid_AES = 1 << 25 + cpuid_OSXSAVE = 1 << 27 + cpuid_AVX = 1 << 28 + + // "Extended Feature Flag" bits returned in EBX for CPUID EAX=0x7 ECX=0x0 + cpuid_BMI1 = 1 << 3 + cpuid_AVX2 = 1 << 5 + cpuid_BMI2 = 1 << 8 + cpuid_ERMS = 1 << 9 + cpuid_AVX512F = 1 << 16 + cpuid_AVX512DQ = 1 << 17 + cpuid_ADX = 1 << 19 + cpuid_AVX512CD = 1 << 28 + cpuid_SHA = 1 << 29 + cpuid_AVX512BW = 1 << 30 + cpuid_AVX512VL = 1 << 31 + + // "Extended Feature Flag" bits returned in ECX for CPUID EAX=0x7 ECX=0x0 + cpuid_AVX512_VBMI = 1 << 1 + cpuid_AVX512_VBMI2 = 1 << 6 + cpuid_GFNI = 1 << 8 + cpuid_AVX512VPCLMULQDQ = 1 << 10 + cpuid_AVX512_BITALG = 1 << 12 + + // edx bits + cpuid_FSRM = 1 << 4 + // edx bits for CPUID 0x80000001 + cpuid_RDTSCP = 1 << 27 + ) + // Additional constants not in internal/cpu + const ( + // eax=1: edx + cpuid_SSE2 = 1 << 26 + // eax=1: ecx + cpuid_CX16 = 1 << 13 + cpuid_RDRAND = 1 << 30 + // eax=7,ecx=0: ebx + cpuid_RDSEED = 1 << 18 + cpuid_AVX512IFMA = 1 << 21 + cpuid_AVX512PF = 1 << 26 + cpuid_AVX512ER = 1 << 27 + // eax=7,ecx=0: edx + cpuid_AVX5124VNNIW = 1 << 2 + cpuid_AVX5124FMAPS = 1 << 3 + cpuid_AMXBF16 = 1 << 22 + cpuid_AMXTile = 1 << 24 + cpuid_AMXInt8 = 1 << 25 + // eax=7,ecx=1: eax + cpuid_AVX512BF16 = 1 << 5 + cpuid_AVXIFMA = 1 << 23 + // eax=7,ecx=1: edx + cpuid_AVXVNNIInt8 = 1 << 4 + ) + Initialized = true maxID, _, _, _ := cpuid(0, 0) @@ -73,90 +147,90 @@ func archInit() { } _, _, ecx1, edx1 := cpuid(1, 0) - X86.HasSSE2 = isSet(26, edx1) - - X86.HasSSE3 = isSet(0, ecx1) - X86.HasPCLMULQDQ = isSet(1, ecx1) - X86.HasSSSE3 = isSet(9, ecx1) - X86.HasFMA = isSet(12, ecx1) - X86.HasCX16 = isSet(13, ecx1) - X86.HasSSE41 = isSet(19, ecx1) - X86.HasSSE42 = isSet(20, ecx1) - X86.HasPOPCNT = isSet(23, ecx1) - X86.HasAES = isSet(25, ecx1) - X86.HasOSXSAVE = isSet(27, ecx1) - X86.HasRDRAND = isSet(30, ecx1) + X86.HasSSE2 = isSet(edx1, cpuid_SSE2) + + X86.HasSSE3 = isSet(ecx1, cpuid_SSE3) + X86.HasPCLMULQDQ = isSet(ecx1, cpuid_PCLMULQDQ) + X86.HasSSSE3 = isSet(ecx1, cpuid_SSSE3) + X86.HasFMA = isSet(ecx1, cpuid_FMA) + X86.HasCX16 = isSet(ecx1, cpuid_CX16) + X86.HasSSE41 = isSet(ecx1, cpuid_SSE41) + X86.HasSSE42 = isSet(ecx1, cpuid_SSE42) + X86.HasPOPCNT = isSet(ecx1, cpuid_POPCNT) + X86.HasAES = isSet(ecx1, cpuid_AES) + X86.HasOSXSAVE = isSet(ecx1, cpuid_OSXSAVE) + X86.HasRDRAND = isSet(ecx1, cpuid_RDRAND) var osSupportsAVX, osSupportsAVX512 bool // For XGETBV, OSXSAVE bit is required and sufficient. if X86.HasOSXSAVE { eax, _ := xgetbv() // Check if XMM and YMM registers have OS support. - osSupportsAVX = isSet(1, eax) && isSet(2, eax) + osSupportsAVX = isSet(eax, 1<<1) && isSet(eax, 1<<2) if runtime.GOOS == "darwin" { // Darwin requires special AVX512 checks, see cpu_darwin_x86.go osSupportsAVX512 = osSupportsAVX && darwinSupportsAVX512() } else { // Check if OPMASK and ZMM registers have OS support. - osSupportsAVX512 = osSupportsAVX && isSet(5, eax) && isSet(6, eax) && isSet(7, eax) + osSupportsAVX512 = osSupportsAVX && isSet(eax, 1<<5) && isSet(eax, 1<<6) && isSet(eax, 1<<7) } } - X86.HasAVX = isSet(28, ecx1) && osSupportsAVX + X86.HasAVX = isSet(ecx1, cpuid_AVX) && osSupportsAVX if maxID < 7 { return } eax7, ebx7, ecx7, edx7 := cpuid(7, 0) - X86.HasBMI1 = isSet(3, ebx7) - X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX - X86.HasBMI2 = isSet(8, ebx7) - X86.HasERMS = isSet(9, ebx7) - X86.HasRDSEED = isSet(18, ebx7) - X86.HasADX = isSet(19, ebx7) - - X86.HasAVX512 = isSet(16, ebx7) && osSupportsAVX512 // Because avx-512 foundation is the core required extension + X86.HasBMI1 = isSet(ebx7, cpuid_BMI1) + X86.HasAVX2 = isSet(ebx7, cpuid_AVX2) && osSupportsAVX + X86.HasBMI2 = isSet(ebx7, cpuid_BMI2) + X86.HasERMS = isSet(ebx7, cpuid_ERMS) + X86.HasRDSEED = isSet(ebx7, cpuid_RDSEED) + X86.HasADX = isSet(ebx7, cpuid_ADX) + + X86.HasAVX512 = isSet(ebx7, cpuid_AVX512F) && osSupportsAVX512 // Because avx-512 foundation is the core required extension if X86.HasAVX512 { X86.HasAVX512F = true - X86.HasAVX512CD = isSet(28, ebx7) - X86.HasAVX512ER = isSet(27, ebx7) - X86.HasAVX512PF = isSet(26, ebx7) - X86.HasAVX512VL = isSet(31, ebx7) - X86.HasAVX512BW = isSet(30, ebx7) - X86.HasAVX512DQ = isSet(17, ebx7) - X86.HasAVX512IFMA = isSet(21, ebx7) - X86.HasAVX512VBMI = isSet(1, ecx7) - X86.HasAVX5124VNNIW = isSet(2, edx7) - X86.HasAVX5124FMAPS = isSet(3, edx7) - X86.HasAVX512VPOPCNTDQ = isSet(14, ecx7) - X86.HasAVX512VPCLMULQDQ = isSet(10, ecx7) - X86.HasAVX512VNNI = isSet(11, ecx7) - X86.HasAVX512GFNI = isSet(8, ecx7) - X86.HasAVX512VAES = isSet(9, ecx7) - X86.HasAVX512VBMI2 = isSet(6, ecx7) - X86.HasAVX512BITALG = isSet(12, ecx7) + X86.HasAVX512CD = isSet(ebx7, cpuid_AVX512CD) + X86.HasAVX512ER = isSet(ebx7, cpuid_AVX512ER) + X86.HasAVX512PF = isSet(ebx7, cpuid_AVX512PF) + X86.HasAVX512VL = isSet(ebx7, cpuid_AVX512VL) + X86.HasAVX512BW = isSet(ebx7, cpuid_AVX512BW) + X86.HasAVX512DQ = isSet(ebx7, cpuid_AVX512DQ) + X86.HasAVX512IFMA = isSet(ebx7, cpuid_AVX512IFMA) + X86.HasAVX512VBMI = isSet(ecx7, cpuid_AVX512_VBMI) + X86.HasAVX5124VNNIW = isSet(edx7, cpuid_AVX5124VNNIW) + X86.HasAVX5124FMAPS = isSet(edx7, cpuid_AVX5124FMAPS) + X86.HasAVX512VPOPCNTDQ = isSet(ecx7, cpuid_AVX512VPOPCNTDQ) + X86.HasAVX512VPCLMULQDQ = isSet(ecx7, cpuid_AVX512VPCLMULQDQ) + X86.HasAVX512VNNI = isSet(ecx7, cpuid_AVX512VNNI) + X86.HasAVX512GFNI = isSet(ecx7, cpuid_AVX512GFNI) + X86.HasAVX512VAES = isSet(ecx7, cpuid_AVX512VAES) + X86.HasAVX512VBMI2 = isSet(ecx7, cpuid_AVX512VBMI2) + X86.HasAVX512BITALG = isSet(ecx7, cpuid_AVX512BITALG) } - X86.HasAMXTile = isSet(24, edx7) - X86.HasAMXInt8 = isSet(25, edx7) - X86.HasAMXBF16 = isSet(22, edx7) + X86.HasAMXTile = isSet(edx7, cpuid_AMXTile) + X86.HasAMXInt8 = isSet(edx7, cpuid_AMXInt8) + X86.HasAMXBF16 = isSet(edx7, cpuid_AMXBF16) // These features depend on the second level of extended features. if eax7 >= 1 { eax71, _, _, edx71 := cpuid(7, 1) if X86.HasAVX512 { - X86.HasAVX512BF16 = isSet(5, eax71) + X86.HasAVX512BF16 = isSet(eax71, cpuid_AVX512BF16) } if X86.HasAVX { - X86.HasAVXIFMA = isSet(23, eax71) - X86.HasAVXVNNI = isSet(4, eax71) - X86.HasAVXVNNIInt8 = isSet(4, edx71) + X86.HasAVXIFMA = isSet(eax71, cpuid_AVXIFMA) + X86.HasAVXVNNI = isSet(eax71, cpuid_AVXVNNI) + X86.HasAVXVNNIInt8 = isSet(edx71, cpuid_AVXVNNIInt8) } } } -func isSet(bitpos uint, value uint32) bool { - return value&(1<