@@ -18,6 +18,11 @@ import (
1818 "google.golang.org/adk/tool/mcptoolset"
1919)
2020
21+ // DynamicHeaderProvider is a function that returns headers to inject into MCP requests.
22+ // It receives the context and should return a map of headers.
23+ // This is used for dynamic token injection (e.g., STS tokens) per session.
24+ type DynamicHeaderProvider func (ctx context.Context ) map [string ]string
25+
2126const (
2227 // Default timeout matching Python KAGENT_REMOTE_AGENT_TIMEOUT
2328 defaultTimeout = 30 * time .Minute
@@ -62,9 +67,10 @@ func allowedRequestHeaders(ctx context.Context, allowed []string) map[string]str
6267type mcpServerParams struct {
6368 URL string
6469 Headers map [string ]string
65- AllowedHeaders []string // header names to forward from incoming request
66- PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
67- ServerType string // "http" or "sse"
70+ AllowedHeaders []string // header names to forward from incoming request
71+ PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
72+ HeaderProvider DynamicHeaderProvider // optional per-request headers derived from invocation context (e.g., STS exchanged access tokens)
73+ ServerType string // "http" or "sse"
6874 Timeout * float64
6975 SseReadTimeout * float64
7076 TLSInsecureSkipVerify * bool
@@ -79,7 +85,16 @@ type mcpServerParams struct {
7985// When propagateToken is true, Authorization is forwarded to every MCP server
8086// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin
8187// behaviour triggered by KAGENT_PROPAGATE_TOKEN.
82- func CreateToolsets (ctx context.Context , httpTools []adk.HttpMcpServerConfig , sseTools []adk.SseMcpServerConfig , propagateToken bool ) []tool.Toolset {
88+ //
89+ // Optional headerProvider can be used to inject per-request headers
90+ // derived from invocation context (e.g., STS exchanged access tokens).
91+ func CreateToolsets (
92+ ctx context.Context ,
93+ httpTools []adk.HttpMcpServerConfig ,
94+ sseTools []adk.SseMcpServerConfig ,
95+ propagateToken bool ,
96+ headerProvider DynamicHeaderProvider ,
97+ ) []tool.Toolset {
8398 log := logr .FromContextOrDiscard (ctx )
8499 var toolsets []tool.Toolset
85100
@@ -90,6 +105,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
90105 Headers : httpTool .Params .Headers ,
91106 AllowedHeaders : httpTool .AllowedHeaders ,
92107 PropagateToken : propagateToken ,
108+ HeaderProvider : headerProvider ,
93109 ServerType : "http" ,
94110 Timeout : httpTool .Params .Timeout ,
95111 SseReadTimeout : httpTool .Params .SseReadTimeout ,
@@ -111,6 +127,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
111127 Headers : sseTool .Params .Headers ,
112128 AllowedHeaders : sseTool .AllowedHeaders ,
113129 PropagateToken : propagateToken ,
130+ HeaderProvider : headerProvider ,
114131 ServerType : "sse" ,
115132 Timeout : sseTool .Params .Timeout ,
116133 SseReadTimeout : sseTool .Params .SseReadTimeout ,
@@ -208,12 +225,13 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
208225 }
209226
210227 var httpTransport http.RoundTripper = baseTransport
211- if len (params .Headers ) > 0 || len (params .AllowedHeaders ) > 0 || params .PropagateToken {
228+ if len (params .Headers ) > 0 || len (params .AllowedHeaders ) > 0 || params .PropagateToken || params . HeaderProvider != nil {
212229 httpTransport = & headerRoundTripper {
213230 base : baseTransport ,
214231 headers : params .Headers ,
215232 allowedHeaders : params .AllowedHeaders ,
216233 propagateToken : params .PropagateToken ,
234+ headerProvider : params .HeaderProvider ,
217235 }
218236 }
219237
@@ -239,18 +257,20 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
239257}
240258
241259// headerRoundTripper wraps an http.RoundTripper to add custom headers to all
242- // requests. It supports three sources of headers, applied in this order so that
260+ // requests. It supports four sources of headers, applied in this order so that
243261// higher-priority sources win on collision:
244262// 1. propagateToken: when true, Authorization is read from the incoming A2A
245263// CallContext and forwarded unconditionally (independent of allowedHeaders).
246264// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext.
247- // 3. headers: static key/value pairs configured on the MCP server spec (highest
265+ // 3. headerProvider: runtime headers derived from ADK context, such as STS tokens.
266+ // 4. headers: static key/value pairs configured on the MCP server spec (highest
248267// priority — always wins).
249268type headerRoundTripper struct {
250269 base http.RoundTripper
251270 headers map [string ]string
252271 allowedHeaders []string // header names (case-insensitive) to forward from A2A context
253272 propagateToken bool // when true, Authorization is forwarded independently
273+ headerProvider DynamicHeaderProvider
254274}
255275
256276func (rt * headerRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
@@ -273,6 +293,13 @@ func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
273293 req .Header .Set (k , v )
274294 }
275295
296+ // Dynamic headers (e.g., STS access tokens) override propagated/allowed headers.
297+ if rt .headerProvider != nil {
298+ for key , value := range rt .headerProvider (req .Context ()) {
299+ req .Header .Set (key , value )
300+ }
301+ }
302+
276303 // Apply static headers last — they take precedence over all dynamic sources.
277304 for key , value := range rt .headers {
278305 req .Header .Set (key , value )
0 commit comments