Skip to content

Commit 332da63

Browse files
committed
refactor(vmcp): unify composite tools and optimizer as session decorators
Both composite tools and the optimizer now implement the MultiSession decorator pattern (same as hijackPreventionDecorator) rather than having bespoke SDK wiring in handleSessionRegistrationImpl. New session decorators: - session/compositetools: appends composite tools to Tools(), routes their CallTool invocations to per-session workflow executors - session/optimizerdec: replaces Tools() with [find_tool, call_tool]; find_tool routes through the optimizer, call_tool delegates to the underlying session for normal backend routing sessionmanager.Manager gains DecorateSession() to swap in a wrapped session after creation. handleSessionRegistrationImpl becomes a flat decoration sequence (apply compositetools → apply optimizer → register tools) with no branching on optimizer vs non-optimizer paths. adapter.WorkflowExecutor/WorkflowResult become type aliases for the compositetools package types so the two layers share a single definition. adapter.CreateOptimizerTools is deleted; its logic lives in optimizerdec.
1 parent 1301dbc commit 332da63

20 files changed

Lines changed: 1056 additions & 311 deletions

pkg/vmcp/composer/testhelpers_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ func newTestEngine(t *testing.T) *testEngine {
3030
t.Cleanup(ctrl.Finish)
3131

3232
mockRouter := routermocks.NewMockRouter(ctrl)
33+
// ResolveToolName is called by getToolInputSchema on every tool step.
34+
// For tests that use NewWorkflowEngine (no tools list), the result is
35+
// always nil, so a pass-through AnyTimes expectation is sufficient.
36+
mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()).
37+
DoAndReturn(func(_ context.Context, name string) (string, error) { return name, nil }).
38+
AnyTimes()
3339
mockBackend := mocks.NewMockBackendClient(ctrl)
3440
engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil) // nil elicitationHandler, stateStore, and auditor for simple tests
3541

pkg/vmcp/composer/workflow_engine.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ func (e *workflowEngine) executeToolStep(
434434
// Coerce expanded arguments to expected types based on backend tool schema.
435435
// Template expansion returns strings, but backend tools expect typed values
436436
// (integer, boolean, number) as defined in their InputSchema.
437-
rawSchema := e.getToolInputSchema(step.Tool)
437+
rawSchema := e.getToolInputSchema(ctx, step.Tool)
438438
s := schema.MakeSchema(rawSchema)
439439
if coerced, ok := s.TryCoerce(expandedArgs).(map[string]any); ok {
440440
expandedArgs = coerced
@@ -1250,11 +1250,19 @@ func (e *workflowEngine) auditStepSkipped(
12501250
}
12511251
}
12521252

1253-
// getToolInputSchema looks up a tool's InputSchema from the session-bound tools list.
1254-
// Returns nil if the engine has no tools list or the tool is not found.
1255-
func (e *workflowEngine) getToolInputSchema(toolName string) map[string]any {
1253+
// getToolInputSchema looks up a tool's InputSchema from the session-bound tools
1254+
// list. If toolName uses the dot convention "{workloadID}.{originalCapabilityName}",
1255+
// ResolveToolName is called to translate it to the conflict-resolved key before
1256+
// lookup. Returns nil if the engine has no tools list or the tool is not found.
1257+
func (e *workflowEngine) getToolInputSchema(ctx context.Context, toolName string) map[string]any {
1258+
resolved := toolName
1259+
if e.router != nil {
1260+
if r, err := e.router.ResolveToolName(ctx, toolName); err == nil {
1261+
resolved = r
1262+
}
1263+
}
12561264
for i := range e.tools {
1257-
if e.tools[i].Name == toolName {
1265+
if e.tools[i].Name == resolved {
12581266
return e.tools[i].InputSchema
12591267
}
12601268
}

pkg/vmcp/composer/workflow_engine_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,9 @@ func TestWorkflowEngine_ParallelExecution(t *testing.T) {
390390
defer ctrl.Finish()
391391

392392
mockRouter := routermocks.NewMockRouter(ctrl)
393+
mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()).
394+
DoAndReturn(func(_ context.Context, name string) (string, error) { return name, nil }).
395+
AnyTimes()
393396
mockBackend := mocks.NewMockBackendClient(ctrl)
394397
stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour)
395398
engine := NewWorkflowEngine(mockRouter, mockBackend, nil, stateStore, nil)
@@ -752,6 +755,9 @@ func TestWorkflowEngine_SessionEngine_CoercesTemplateStringToTypedArg(t *testing
752755
t.Cleanup(ctrl.Finish)
753756

754757
mockRouter := routermocks.NewMockRouter(ctrl)
758+
mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()).
759+
DoAndReturn(func(_ context.Context, name string) (string, error) { return name, nil }).
760+
AnyTimes()
755761
mockBackend := mocks.NewMockBackendClient(ctrl)
756762

757763
tools := []vmcp.Tool{
@@ -810,6 +816,9 @@ func TestWorkflowEngine_SessionEngine_ToolNotInList_ReturnsNilSchema(t *testing.
810816
t.Cleanup(ctrl.Finish)
811817

812818
mockRouter := routermocks.NewMockRouter(ctrl)
819+
mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()).
820+
DoAndReturn(func(_ context.Context, name string) (string, error) { return name, nil }).
821+
AnyTimes()
813822
mockBackend := mocks.NewMockBackendClient(ctrl)
814823

815824
// Tools list does not include "other_tool".

pkg/vmcp/router/default_router.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ func (*defaultRouter) RouteTool(ctx context.Context, toolName string) (*vmcp.Bac
8989
)
9090
}
9191

92+
// ResolveToolName returns toolName unchanged. The defaultRouter has no static
93+
// routing table, so dot-convention resolution is not available; the caller
94+
// should already be using resolved names when working with this router.
95+
func (*defaultRouter) ResolveToolName(_ context.Context, toolName string) (string, error) {
96+
return toolName, nil
97+
}
98+
9299
// RouteResource resolves a resource URI to its backend target.
93100
// With lazy discovery, this method gets capabilities from the request context
94101
// instead of using a cached routing table.

pkg/vmcp/router/mocks/mock_router.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/vmcp/router/router.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ type Router interface {
2828
// Returns ErrToolNotFound if the tool doesn't exist in any backend.
2929
RouteTool(ctx context.Context, toolName string) (*vmcp.BackendTarget, error)
3030

31+
// ResolveToolName translates a tool name (which may use the dot-convention
32+
// "{workloadID}.{originalCapabilityName}") to the conflict-resolved routing
33+
// table key used in the session tools list. Returns the input unchanged if
34+
// it already matches exactly or if the router has no routing table.
35+
// Returns ErrToolNotFound if the name cannot be resolved.
36+
ResolveToolName(ctx context.Context, toolName string) (string, error)
37+
3138
// RouteResource resolves a resource URI to its backend target.
3239
// Returns ErrResourceNotFound if the resource doesn't exist in any backend.
3340
RouteResource(ctx context.Context, uri string) (*vmcp.BackendTarget, error)

pkg/vmcp/router/session_router.go

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package router
66
import (
77
"context"
88
"fmt"
9+
"strings"
910

1011
"github.com/stacklok/toolhive/pkg/vmcp"
1112
)
@@ -28,15 +29,74 @@ func NewSessionRouter(rt *vmcp.RoutingTable) Router {
2829

2930
// RouteTool resolves a tool name to its backend target using the session's
3031
// routing table directly.
32+
//
33+
// Two naming conventions are supported:
34+
//
35+
// 1. Exact key: the resolved/conflict-resolved name stored in the routing
36+
// table (e.g. "my-backend_echo" after prefix conflict resolution).
37+
//
38+
// 2. Dot convention "{workloadID}.{toolName}": the tool name is the original
39+
// backend capability name and the workload ID is the prefix. This mirrors
40+
// the isToolStepAccessible logic used when registering composite tools and
41+
// lets workflow step definitions remain stable regardless of the conflict
42+
// resolution strategy in use.
43+
//
44+
// The dot convention is necessary because composite workflow steps reference
45+
// tools by their pre-conflict-resolution name (e.g. "my-backend.echo"), while
46+
// the routing table may store them under a prefixed key ("my-backend_echo").
3147
func (r *sessionRouter) RouteTool(_ context.Context, toolName string) (*vmcp.BackendTarget, error) {
3248
if r.routingTable == nil || r.routingTable.Tools == nil {
3349
return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName)
3450
}
35-
target, exists := r.routingTable.Tools[toolName]
36-
if !exists {
37-
return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName)
51+
52+
// Fast path: exact key match.
53+
if target, exists := r.routingTable.Tools[toolName]; exists {
54+
return target, nil
3855
}
39-
return target, nil
56+
57+
// Fallback: dot convention "{workloadID}.{toolName}".
58+
// Workload IDs are Kubernetes resource names and cannot contain dots,
59+
// so the first dot unambiguously separates the workload ID from the
60+
// original backend capability name.
61+
if dotIdx := strings.Index(toolName, "."); dotIdx > 0 {
62+
workloadID := toolName[:dotIdx]
63+
capName := toolName[dotIdx+1:]
64+
for resolvedName, target := range r.routingTable.Tools {
65+
if target.WorkloadID == workloadID && target.GetBackendCapabilityName(resolvedName) == capName {
66+
return target, nil
67+
}
68+
}
69+
}
70+
71+
return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName)
72+
}
73+
74+
// ResolveToolName returns the routing table key (conflict-resolved name) for
75+
// toolName. If toolName is an exact key it is returned unchanged. If it uses
76+
// the dot convention "{workloadID}.{originalCapabilityName}", the matching
77+
// routing table key is returned. Returns ErrToolNotFound if unresolvable.
78+
func (r *sessionRouter) ResolveToolName(_ context.Context, toolName string) (string, error) {
79+
if r.routingTable == nil || r.routingTable.Tools == nil {
80+
return "", fmt.Errorf("%w: %s", ErrToolNotFound, toolName)
81+
}
82+
83+
// Fast path: exact key match.
84+
if _, exists := r.routingTable.Tools[toolName]; exists {
85+
return toolName, nil
86+
}
87+
88+
// Fallback: dot convention "{workloadID}.{toolName}".
89+
if dotIdx := strings.Index(toolName, "."); dotIdx > 0 {
90+
workloadID := toolName[:dotIdx]
91+
capName := toolName[dotIdx+1:]
92+
for resolvedName, target := range r.routingTable.Tools {
93+
if target.WorkloadID == workloadID && target.GetBackendCapabilityName(resolvedName) == capName {
94+
return resolvedName, nil
95+
}
96+
}
97+
}
98+
99+
return "", fmt.Errorf("%w: %s", ErrToolNotFound, toolName)
40100
}
41101

42102
// RouteResource resolves a resource URI to its backend target using the

pkg/vmcp/router/session_router_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,54 @@ func TestSessionRouter_RouteTool(t *testing.T) {
6565
expectError: true,
6666
errorContains: "tool not found",
6767
},
68+
{
69+
// Composite workflow steps use "{workloadID}.{toolName}" where toolName
70+
// is the original backend capability name. With prefix conflict resolution
71+
// the routing table key is "{workloadID}_toolName", so an exact match
72+
// fails. The dot-convention fallback must resolve it correctly.
73+
name: "dot convention resolved via workload ID and original capability name",
74+
routingTable: &vmcp.RoutingTable{
75+
Tools: map[string]*vmcp.BackendTarget{
76+
"my-backend_echo": {
77+
WorkloadID: "my-backend",
78+
WorkloadName: "My Backend",
79+
BaseURL: "http://my-backend:8080",
80+
OriginalCapabilityName: "echo",
81+
},
82+
},
83+
},
84+
toolName: "my-backend.echo",
85+
expectedID: "my-backend",
86+
expectError: false,
87+
},
88+
{
89+
name: "dot convention: workload not in session",
90+
routingTable: &vmcp.RoutingTable{
91+
Tools: map[string]*vmcp.BackendTarget{
92+
"other-backend_echo": {
93+
WorkloadID: "other-backend",
94+
OriginalCapabilityName: "echo",
95+
},
96+
},
97+
},
98+
toolName: "my-backend.echo",
99+
expectError: true,
100+
errorContains: "tool not found",
101+
},
102+
{
103+
name: "dot convention: capability name mismatch",
104+
routingTable: &vmcp.RoutingTable{
105+
Tools: map[string]*vmcp.BackendTarget{
106+
"my-backend_echo": {
107+
WorkloadID: "my-backend",
108+
OriginalCapabilityName: "echo",
109+
},
110+
},
111+
},
112+
toolName: "my-backend.fetch",
113+
expectError: true,
114+
errorContains: "tool not found",
115+
},
68116
}
69117

70118
for _, tt := range tests {
@@ -87,6 +135,74 @@ func TestSessionRouter_RouteTool(t *testing.T) {
87135
}
88136
}
89137

138+
func TestSessionRouter_ResolveToolName(t *testing.T) {
139+
t.Parallel()
140+
141+
tests := []struct {
142+
name string
143+
routingTable *vmcp.RoutingTable
144+
toolName string
145+
expectedName string
146+
expectError bool
147+
errorContains string
148+
}{
149+
{
150+
name: "exact key returned unchanged",
151+
routingTable: &vmcp.RoutingTable{
152+
Tools: map[string]*vmcp.BackendTarget{
153+
"my-backend_echo": {WorkloadID: "my-backend", OriginalCapabilityName: "echo"},
154+
},
155+
},
156+
toolName: "my-backend_echo",
157+
expectedName: "my-backend_echo",
158+
},
159+
{
160+
name: "dot convention resolves to routing table key",
161+
routingTable: &vmcp.RoutingTable{
162+
Tools: map[string]*vmcp.BackendTarget{
163+
"my-backend_echo": {WorkloadID: "my-backend", OriginalCapabilityName: "echo"},
164+
},
165+
},
166+
toolName: "my-backend.echo",
167+
expectedName: "my-backend_echo",
168+
},
169+
{
170+
name: "not found returns error",
171+
routingTable: &vmcp.RoutingTable{
172+
Tools: make(map[string]*vmcp.BackendTarget),
173+
},
174+
toolName: "missing_tool",
175+
expectError: true,
176+
errorContains: "tool not found",
177+
},
178+
{
179+
name: "nil routing table returns error",
180+
routingTable: nil,
181+
toolName: "any_tool",
182+
expectError: true,
183+
errorContains: "tool not found",
184+
},
185+
}
186+
187+
for _, tt := range tests {
188+
t.Run(tt.name, func(t *testing.T) {
189+
t.Parallel()
190+
191+
r := router.NewSessionRouter(tt.routingTable)
192+
resolved, err := r.ResolveToolName(context.Background(), tt.toolName)
193+
194+
if tt.expectError {
195+
require.Error(t, err)
196+
assert.Contains(t, err.Error(), tt.errorContains)
197+
assert.Empty(t, resolved)
198+
} else {
199+
require.NoError(t, err)
200+
assert.Equal(t, tt.expectedName, resolved)
201+
}
202+
})
203+
}
204+
}
205+
90206
func TestSessionRouter_RouteResource(t *testing.T) {
91207
t.Parallel()
92208

pkg/vmcp/server/adapter/handler_factory.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/stacklok/toolhive/pkg/vmcp/conversion"
2020
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
2121
"github.com/stacklok/toolhive/pkg/vmcp/router"
22+
"github.com/stacklok/toolhive/pkg/vmcp/session/compositetools"
2223
)
2324

2425
//go:generate mockgen -destination=mocks/mock_handler_factory.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/server/adapter HandlerFactory
@@ -43,20 +44,13 @@ type HandlerFactory interface {
4344
}
4445

4546
// WorkflowExecutor executes composite tool workflows.
46-
// This interface abstracts the composer to enable testing without full composer setup.
47-
type WorkflowExecutor interface {
48-
// ExecuteWorkflow executes the workflow with the given parameters.
49-
ExecuteWorkflow(ctx context.Context, params map[string]any) (*WorkflowResult, error)
50-
}
47+
// Type alias for compositetools.WorkflowExecutor so that adapter consumers and
48+
// the session decorator share a single interface definition.
49+
type WorkflowExecutor = compositetools.WorkflowExecutor
5150

5251
// WorkflowResult represents the result of a workflow execution.
53-
type WorkflowResult struct {
54-
// Output contains the workflow output data (typically from the last step).
55-
Output map[string]any
56-
57-
// Error contains error information if the workflow failed.
58-
Error error
59-
}
52+
// Type alias for compositetools.WorkflowResult.
53+
type WorkflowResult = compositetools.WorkflowResult
6054

6155
// DefaultHandlerFactory creates MCP request handlers that route to backend workloads.
6256
type DefaultHandlerFactory struct {

0 commit comments

Comments
 (0)