Skip to content

Commit 031616d

Browse files
committed
Merge remote-tracking branch 'origin/main' into iplay88keys/mcp-server-ca-cert
Signed-off-by: Jeremy Alvis <jeremy.alvis@solo.io>
2 parents 9d90c17 + f95167b commit 031616d

67 files changed

Lines changed: 3478 additions & 674 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.devcontainer/devcontainer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"build": {
44
"dockerfile": "Dockerfile",
55
"args": {
6-
"TOOLS_GO_VERSION": "1.26.1",
6+
"TOOLS_GO_VERSION": "1.26.2",
77
"TOOLS_NODE_VERSION": "24.13.0",
88
"TOOLS_UV_VERSION": "0.10.4",
99
"TOOLS_K9S_VERSION": "0.50.4",

.github/workflows/ci.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ jobs:
203203
run: |
204204
helm unittest helm/kagent
205205
helm unittest helm/tools/querydoc
206+
helm unittest helm/tools/grafana-mcp
206207
207208
ui-tests:
208209
runs-on: ubuntu-latest

Makefile

Lines changed: 135 additions & 93 deletions
Large diffs are not rendered by default.

go/adk/pkg/agent/agent.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/go-logr/logr"
1111
"github.com/kagent-dev/kagent/go/adk/pkg/mcp"
1212
"github.com/kagent-dev/kagent/go/adk/pkg/models"
13+
"github.com/kagent-dev/kagent/go/adk/pkg/sts"
1314
"github.com/kagent-dev/kagent/go/adk/pkg/tools"
1415
"github.com/kagent-dev/kagent/go/api/adk"
1516
"google.golang.org/adk/agent"
@@ -33,23 +34,28 @@ const (
3334
// agentName is used as the ADK agent identity (appears in event Author field).
3435
// extraTools are appended to the agent's tool list (e.g. save_memory).
3536
func CreateGoogleADKAgent(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, extraTools ...tool.Tool) (agent.Agent, error) {
36-
a, _, err := CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentName, extraTools...)
37+
a, _, err := CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentName, nil, extraTools...)
3738
return a, err
3839
}
3940

4041
// CreateGoogleADKAgentWithSubagentSessionIDs creates a Google ADK agent and a
4142
// map of remote-subagent tool name → A2A context session ID (for stamping
4243
// outbound A2A events). Callers that only need the agent can use
4344
// CreateGoogleADKAgent.
44-
func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, extraTools ...tool.Tool) (agent.Agent, map[string]string, error) {
45+
// Optional stsPlugin can be provided for token propagation to MCP tools.
46+
func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, stsPlugin *sts.TokenPropagationPlugin, extraTools ...tool.Tool) (agent.Agent, map[string]string, error) {
4547
log := logr.FromContextOrDiscard(ctx)
4648

4749
if agentConfig == nil {
4850
return nil, nil, fmt.Errorf("agent config is required")
4951
}
5052

5153
propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true"
52-
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken)
54+
var dynamicHeaderProvider mcp.DynamicHeaderProvider
55+
if stsPlugin != nil {
56+
dynamicHeaderProvider = stsPlugin.HeaderProvider
57+
}
58+
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken, dynamicHeaderProvider)
5359
subagentSessionIDs := make(map[string]string)
5460

5561
var remoteAgentTools []tool.Tool
@@ -104,6 +110,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig
104110
beforeToolCallbacks := []llmagent.BeforeToolCallback{}
105111
// Strip synthetic HITL tool messages from the model request to avoid unnecessary token usage.
106112
beforeModelCallbacks := []llmagent.BeforeModelCallback{}
113+
107114
if len(approvalSet) > 0 {
108115
log.Info("Wiring approval callback", "toolCount", len(approvalSet))
109116
beforeToolCallbacks = append(beforeToolCallbacks, MakeApprovalCallback(approvalSet))

go/adk/pkg/mcp/registry.go

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ import (
1818
"google.golang.org/adk/tool/mcptoolset"
1919
)
2020

21+
// DynamicHeaderProvider is a function that returns headers to inject into MCP requests.
22+
// It receives the context and should return a map of headers.
23+
// This is used for dynamic token injection (e.g., STS tokens) per session.
24+
type DynamicHeaderProvider func(ctx context.Context) map[string]string
25+
2126
const (
2227
// Default timeout matching Python KAGENT_REMOTE_AGENT_TIMEOUT
2328
defaultTimeout = 30 * time.Minute
@@ -62,9 +67,10 @@ func allowedRequestHeaders(ctx context.Context, allowed []string) map[string]str
6267
type mcpServerParams struct {
6368
URL string
6469
Headers map[string]string
65-
AllowedHeaders []string // header names to forward from incoming request
66-
PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
67-
ServerType string // "http" or "sse"
70+
AllowedHeaders []string // header names to forward from incoming request
71+
PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
72+
HeaderProvider DynamicHeaderProvider // optional per-request headers derived from invocation context (e.g., STS exchanged access tokens)
73+
ServerType string // "http" or "sse"
6874
Timeout *float64
6975
SseReadTimeout *float64
7076
TLSInsecureSkipVerify *bool
@@ -79,7 +85,16 @@ type mcpServerParams struct {
7985
// When propagateToken is true, Authorization is forwarded to every MCP server
8086
// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin
8187
// behaviour triggered by KAGENT_PROPAGATE_TOKEN.
82-
func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool) []tool.Toolset {
88+
//
89+
// Optional headerProvider can be used to inject per-request headers
90+
// derived from invocation context (e.g., STS exchanged access tokens).
91+
func CreateToolsets(
92+
ctx context.Context,
93+
httpTools []adk.HttpMcpServerConfig,
94+
sseTools []adk.SseMcpServerConfig,
95+
propagateToken bool,
96+
headerProvider DynamicHeaderProvider,
97+
) []tool.Toolset {
8398
log := logr.FromContextOrDiscard(ctx)
8499
var toolsets []tool.Toolset
85100

@@ -90,6 +105,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
90105
Headers: httpTool.Params.Headers,
91106
AllowedHeaders: httpTool.AllowedHeaders,
92107
PropagateToken: propagateToken,
108+
HeaderProvider: headerProvider,
93109
ServerType: "http",
94110
Timeout: httpTool.Params.Timeout,
95111
SseReadTimeout: httpTool.Params.SseReadTimeout,
@@ -111,6 +127,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
111127
Headers: sseTool.Params.Headers,
112128
AllowedHeaders: sseTool.AllowedHeaders,
113129
PropagateToken: propagateToken,
130+
HeaderProvider: headerProvider,
114131
ServerType: "sse",
115132
Timeout: sseTool.Params.Timeout,
116133
SseReadTimeout: sseTool.Params.SseReadTimeout,
@@ -208,12 +225,13 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
208225
}
209226

210227
var httpTransport http.RoundTripper = baseTransport
211-
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken {
228+
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken || params.HeaderProvider != nil {
212229
httpTransport = &headerRoundTripper{
213230
base: baseTransport,
214231
headers: params.Headers,
215232
allowedHeaders: params.AllowedHeaders,
216233
propagateToken: params.PropagateToken,
234+
headerProvider: params.HeaderProvider,
217235
}
218236
}
219237

@@ -239,18 +257,20 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
239257
}
240258

241259
// headerRoundTripper wraps an http.RoundTripper to add custom headers to all
242-
// requests. It supports three sources of headers, applied in this order so that
260+
// requests. It supports four sources of headers, applied in this order so that
243261
// higher-priority sources win on collision:
244262
// 1. propagateToken: when true, Authorization is read from the incoming A2A
245263
// CallContext and forwarded unconditionally (independent of allowedHeaders).
246264
// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext.
247-
// 3. headers: static key/value pairs configured on the MCP server spec (highest
265+
// 3. headerProvider: runtime headers derived from ADK context, such as STS tokens.
266+
// 4. headers: static key/value pairs configured on the MCP server spec (highest
248267
// priority — always wins).
249268
type headerRoundTripper struct {
250269
base http.RoundTripper
251270
headers map[string]string
252271
allowedHeaders []string // header names (case-insensitive) to forward from A2A context
253272
propagateToken bool // when true, Authorization is forwarded independently
273+
headerProvider DynamicHeaderProvider
254274
}
255275

256276
func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -273,6 +293,13 @@ func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
273293
req.Header.Set(k, v)
274294
}
275295

296+
// Dynamic headers (e.g., STS access tokens) override propagated/allowed headers.
297+
if rt.headerProvider != nil {
298+
for key, value := range rt.headerProvider(req.Context()) {
299+
req.Header.Set(key, value)
300+
}
301+
}
302+
276303
// Apply static headers last — they take precedence over all dynamic sources.
277304
for key, value := range rt.headers {
278305
req.Header.Set(key, value)

go/adk/pkg/mcp/registry_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,80 @@ func TestAllowedRequestHeaders_ReturnsNilWhenNoMatches(t *testing.T) {
319319
t.Errorf("expected nil when no allowed headers are present, got %v", got)
320320
}
321321
}
322+
323+
// TestDynamicHeaders_OverridePropagatedAndAllowedHeaders verifies dynamic headers
324+
// take precedence over propagated and allowed request headers.
325+
func TestDynamicHeaders_OverridePropagatedAndAllowedHeaders(t *testing.T) {
326+
t.Parallel()
327+
var capturedAuth, capturedCustom string
328+
329+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
330+
capturedAuth = r.Header.Get("Authorization")
331+
capturedCustom = r.Header.Get("X-Custom")
332+
w.WriteHeader(http.StatusOK)
333+
}))
334+
defer srv.Close()
335+
336+
ctx := a2aCtx(map[string][]string{
337+
"Authorization": {"Bearer incoming"},
338+
"X-Custom": {"custom-from-request"},
339+
})
340+
341+
rt := &headerRoundTripper{
342+
base: http.DefaultTransport,
343+
propagateToken: true,
344+
allowedHeaders: []string{"Authorization", "X-Custom"},
345+
headerProvider: func(context.Context) map[string]string {
346+
return map[string]string{
347+
"Authorization": "Bearer sts-exchanged",
348+
"X-Custom": "custom-from-dynamic",
349+
}
350+
},
351+
}
352+
353+
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
354+
resp, err := rt.RoundTrip(req)
355+
if err != nil {
356+
t.Fatalf("RoundTrip failed: %v", err)
357+
}
358+
resp.Body.Close()
359+
360+
if capturedAuth != "Bearer sts-exchanged" {
361+
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer sts-exchanged")
362+
}
363+
if capturedCustom != "custom-from-dynamic" {
364+
t.Errorf("X-Custom: got %q, want %q", capturedCustom, "custom-from-dynamic")
365+
}
366+
}
367+
368+
// TestStaticHeaders_OverrideDynamic verifies static configured headers remain
369+
// the highest-precedence source.
370+
func TestStaticHeaders_OverrideDynamic(t *testing.T) {
371+
t.Parallel()
372+
var capturedAuth string
373+
374+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
375+
capturedAuth = r.Header.Get("Authorization")
376+
w.WriteHeader(http.StatusOK)
377+
}))
378+
defer srv.Close()
379+
380+
rt := &headerRoundTripper{
381+
base: http.DefaultTransport,
382+
headers: map[string]string{"Authorization": "Bearer static"},
383+
headerProvider: func(context.Context) map[string]string {
384+
return map[string]string{"Authorization": "Bearer dynamic"}
385+
},
386+
}
387+
388+
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
389+
resp, err := rt.RoundTrip(req)
390+
if err != nil {
391+
t.Fatalf("RoundTrip failed: %v", err)
392+
}
393+
resp.Body.Close()
394+
395+
if capturedAuth != "Bearer static" {
396+
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer static")
397+
}
398+
}

go/adk/pkg/runner/adapter.go

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@ package runner
33
import (
44
"context"
55
"fmt"
6+
"os"
67
"strings"
78

9+
"github.com/go-logr/logr"
810
"github.com/kagent-dev/kagent/go/adk/pkg/agent"
911
kagentmemory "github.com/kagent-dev/kagent/go/adk/pkg/memory"
1012
"github.com/kagent-dev/kagent/go/adk/pkg/session"
13+
"github.com/kagent-dev/kagent/go/adk/pkg/sts"
1114
"github.com/kagent-dev/kagent/go/api/adk"
1215
adkmemory "google.golang.org/adk/memory"
16+
adkplugin "google.golang.org/adk/plugin"
1317
"google.golang.org/adk/runner"
1418
adksession "google.golang.org/adk/session"
1519
adktool "google.golang.org/adk/tool"
@@ -31,6 +35,8 @@ func CreateRunnerConfig(
3135
appName string,
3236
memoryService *kagentmemory.KagentMemoryService,
3337
) (runner.Config, map[string]string, error) {
38+
log := logr.FromContextOrDiscard(ctx)
39+
3440
var extraTools []adktool.Tool
3541
if memoryService != nil {
3642
saveTool, err := kagentmemory.NewSaveMemoryTool(memoryService)
@@ -40,7 +46,12 @@ func CreateRunnerConfig(
4046
extraTools = append(extraTools, saveTool)
4147
}
4248

43-
adkAgent, subagentSessionIDs, err := agent.CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentNameFromAppName(appName), extraTools...)
49+
stsPlugin, err := buildTokenPropagationPlugin(ctx, log)
50+
if err != nil {
51+
return runner.Config{}, nil, err
52+
}
53+
54+
adkAgent, subagentSessionIDs, err := agent.CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentNameFromAppName(appName), stsPlugin, extraTools...)
4455
if err != nil {
4556
return runner.Config{}, nil, fmt.Errorf("failed to create agent: %w", err)
4657
}
@@ -61,11 +72,57 @@ func CreateRunnerConfig(
6172
runnerMemory = memoryService
6273
}
6374

75+
var adkPlugins []*adkplugin.Plugin
76+
if stsPlugin != nil {
77+
p, err := stsPlugin.ADKPlugin()
78+
if err != nil {
79+
return runner.Config{}, nil, fmt.Errorf("failed to create STS ADK plugin: %w", err)
80+
}
81+
if p != nil {
82+
adkPlugins = append(adkPlugins, p)
83+
}
84+
}
85+
6486
cfg := runner.Config{
6587
AppName: appName,
6688
Agent: adkAgent,
6789
SessionService: adkSessionService,
6890
MemoryService: runnerMemory,
91+
PluginConfig: runner.PluginConfig{
92+
Plugins: adkPlugins,
93+
},
6994
}
95+
7096
return cfg, subagentSessionIDs, nil
7197
}
98+
99+
func buildTokenPropagationPlugin(ctx context.Context, log logr.Logger) (*sts.TokenPropagationPlugin, error) {
100+
propagateToken := strings.EqualFold(strings.TrimSpace(os.Getenv("KAGENT_PROPAGATE_TOKEN")), "true")
101+
stsWellKnownURI := strings.TrimSpace(os.Getenv("STS_WELL_KNOWN_URI"))
102+
if !propagateToken && stsWellKnownURI == "" {
103+
return nil, nil
104+
}
105+
106+
// Propagate-only mode: keep parity with Python by enabling plugin without STS exchange.
107+
if stsWellKnownURI == "" {
108+
log.Info("Enabling token propagation plugin without STS exchange")
109+
return sts.NewTokenPropagationPlugin(nil, log), nil
110+
}
111+
defaultSTSConfig := sts.DefaultSTSConfig(stsWellKnownURI)
112+
113+
integration, err := sts.NewSTSIntegration(
114+
stsWellKnownURI,
115+
"",
116+
nil, // fetchActorToken
117+
nil, // getSubjectToken
118+
defaultSTSConfig.Timeout,
119+
*defaultSTSConfig.VerifySSL,
120+
defaultSTSConfig.UseIssuerHost,
121+
)
122+
if err != nil {
123+
return nil, fmt.Errorf("failed to initialize STS integration: %w", err)
124+
}
125+
126+
log.Info("Enabling STS token propagation plugin", "wellKnownURI", stsWellKnownURI)
127+
return sts.NewTokenPropagationPlugin(integration, log), nil
128+
}

0 commit comments

Comments
 (0)