Skip to content

Commit 1301dbc

Browse files
committed
Bring composite tools into session abstraction
Composite tool workflow engines were previously relying on the discovery middleware to inject DiscoveredCapabilities into the request context so that the shared stateless router could route backend tool calls within workflows. This created an implicit coupling between the middleware and composite tool execution that made unit-testing harder and was a source of integration bugs. Affected components: pkg/vmcp/router, pkg/vmcp/composer, pkg/vmcp/server, pkg/vmcp/discovery Related-to: #3872
1 parent ca7f127 commit 1301dbc

10 files changed

Lines changed: 912 additions & 82 deletions

pkg/vmcp/composer/workflow_engine.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"github.com/stacklok/toolhive/pkg/audit"
1818
"github.com/stacklok/toolhive/pkg/vmcp"
1919
"github.com/stacklok/toolhive/pkg/vmcp/conversion"
20-
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
2120
"github.com/stacklok/toolhive/pkg/vmcp/router"
2221
"github.com/stacklok/toolhive/pkg/vmcp/schema"
2322
)
@@ -46,6 +45,10 @@ type workflowEngine struct {
4645
// backendClient makes calls to backend MCP servers.
4746
backendClient vmcp.BackendClient
4847

48+
// tools is the resolved tool list for the session, used by getToolInputSchema
49+
// for argument type coercion. Set via NewSessionWorkflowEngine.
50+
tools []vmcp.Tool
51+
4952
// templateExpander handles template expansion.
5053
templateExpander TemplateExpander
5154

@@ -93,6 +96,30 @@ func NewWorkflowEngine(
9396
}
9497
}
9598

99+
// NewSessionWorkflowEngine creates a per-session workflow engine bound to a resolved tool list.
100+
// tools is required: it enables argument type coercion against the session's tool schemas.
101+
// Use this when creating per-session engines via router.NewSessionRouter.
102+
func NewSessionWorkflowEngine(
103+
rtr router.Router,
104+
backendClient vmcp.BackendClient,
105+
elicitationHandler ElicitationProtocolHandler,
106+
stateStore WorkflowStateStore,
107+
auditor *audit.WorkflowAuditor,
108+
tools []vmcp.Tool,
109+
) Composer {
110+
return &workflowEngine{
111+
router: rtr,
112+
backendClient: backendClient,
113+
templateExpander: NewTemplateExpander(),
114+
contextManager: newWorkflowContextManager(),
115+
elicitationHandler: elicitationHandler,
116+
dagExecutor: newDAGExecutor(defaultMaxParallelSteps),
117+
stateStore: stateStore,
118+
auditor: auditor,
119+
tools: tools,
120+
}
121+
}
122+
96123
// ExecuteWorkflow executes a composite tool workflow.
97124
//
98125
// TODO(rate-limiting): Add rate limiting per user/session to prevent workflow execution DoS.
@@ -407,7 +434,7 @@ func (e *workflowEngine) executeToolStep(
407434
// Coerce expanded arguments to expected types based on backend tool schema.
408435
// Template expansion returns strings, but backend tools expect typed values
409436
// (integer, boolean, number) as defined in their InputSchema.
410-
rawSchema := e.getToolInputSchema(ctx, step.Tool)
437+
rawSchema := e.getToolInputSchema(step.Tool)
411438
s := schema.MakeSchema(rawSchema)
412439
if coerced, ok := s.TryCoerce(expandedArgs).(map[string]any); ok {
413440
expandedArgs = coerced
@@ -1223,20 +1250,13 @@ func (e *workflowEngine) auditStepSkipped(
12231250
}
12241251
}
12251252

1226-
// getToolInputSchema looks up a tool's InputSchema from discovered capabilities.
1227-
// Returns nil if the tool is not found or capabilities are not in context.
1228-
func (*workflowEngine) getToolInputSchema(ctx context.Context, toolName string) map[string]any {
1229-
caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx)
1230-
if !ok || caps == nil {
1231-
return nil
1232-
}
1233-
1234-
// Search in backend tools
1235-
for i := range caps.Tools {
1236-
if caps.Tools[i].Name == toolName {
1237-
return caps.Tools[i].InputSchema
1253+
// getToolInputSchema looks up a tool's InputSchema from the session-bound tools list.
1254+
// Returns nil if the engine has no tools list or the tool is not found.
1255+
func (e *workflowEngine) getToolInputSchema(toolName string) map[string]any {
1256+
for i := range e.tools {
1257+
if e.tools[i].Name == toolName {
1258+
return e.tools[i].InputSchema
12381259
}
12391260
}
1240-
12411261
return nil
12421262
}

pkg/vmcp/composer/workflow_engine_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,3 +741,98 @@ func TestWorkflowEngine_WorkflowMetadataAvailableInTemplates(t *testing.T) {
741741
assert.Equal(t, WorkflowStatusCompleted, result.Status)
742742
assert.Len(t, result.Steps, 2)
743743
}
744+
745+
func TestWorkflowEngine_SessionEngine_CoercesTemplateStringToTypedArg(t *testing.T) {
746+
t.Parallel()
747+
748+
// Template expansion always produces strings. When the engine is created
749+
// with NewSessionWorkflowEngine, getToolInputSchema resolves the target tool's InputSchema
750+
// and the schema coercion layer converts "42" → 42 before calling the backend.
751+
ctrl := gomock.NewController(t)
752+
t.Cleanup(ctrl.Finish)
753+
754+
mockRouter := routermocks.NewMockRouter(ctrl)
755+
mockBackend := mocks.NewMockBackendClient(ctrl)
756+
757+
tools := []vmcp.Tool{
758+
{
759+
Name: "count_items",
760+
InputSchema: map[string]any{
761+
"type": "object",
762+
"properties": map[string]any{
763+
"limit": map[string]any{"type": "integer"},
764+
},
765+
},
766+
},
767+
}
768+
769+
engine := NewSessionWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools)
770+
771+
target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"}
772+
mockRouter.EXPECT().RouteTool(gomock.Any(), "count_items").Return(target, nil)
773+
774+
// Expect the backend to receive the coerced integer, not the string "42".
775+
coercedArgs := map[string]any{"limit": int64(42)}
776+
mockBackend.EXPECT().
777+
CallTool(gomock.Any(), target, "count_items", coercedArgs, gomock.Any()).
778+
Return(&vmcp.ToolCallResult{StructuredContent: map[string]any{"items": []any{}}, Content: []vmcp.Content{}}, nil)
779+
780+
workflow := &WorkflowDefinition{
781+
Name: "coerce_test",
782+
Parameters: map[string]any{
783+
"type": "object",
784+
"properties": map[string]any{
785+
"n": map[string]any{"type": "string"},
786+
},
787+
},
788+
Steps: []WorkflowStep{
789+
{
790+
ID: "step1",
791+
Type: StepTypeTool,
792+
Tool: "count_items",
793+
// Template expansion produces a string; coercion must convert it to int.
794+
Arguments: map[string]any{"limit": "{{.params.n}}"},
795+
},
796+
},
797+
}
798+
799+
result, err := engine.ExecuteWorkflow(context.Background(), workflow, map[string]any{"n": "42"})
800+
require.NoError(t, err)
801+
assert.Equal(t, WorkflowStatusCompleted, result.Status)
802+
}
803+
804+
func TestWorkflowEngine_SessionEngine_ToolNotInList_ReturnsNilSchema(t *testing.T) {
805+
t.Parallel()
806+
807+
// When NewSessionWorkflowEngine is used but the requested tool is not in the list,
808+
// getToolInputSchema returns nil and coercion is a no-op.
809+
ctrl := gomock.NewController(t)
810+
t.Cleanup(ctrl.Finish)
811+
812+
mockRouter := routermocks.NewMockRouter(ctrl)
813+
mockBackend := mocks.NewMockBackendClient(ctrl)
814+
815+
// Tools list does not include "other_tool".
816+
tools := []vmcp.Tool{{Name: "known_tool", InputSchema: map[string]any{"type": "object"}}}
817+
engine := NewSessionWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools)
818+
819+
target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"}
820+
mockRouter.EXPECT().RouteTool(gomock.Any(), "other_tool").Return(target, nil)
821+
822+
// Args pass through unmodified (string stays a string).
823+
rawArgs := map[string]any{"value": "hello"}
824+
mockBackend.EXPECT().
825+
CallTool(gomock.Any(), target, "other_tool", rawArgs, gomock.Any()).
826+
Return(&vmcp.ToolCallResult{StructuredContent: map[string]any{"ok": true}, Content: []vmcp.Content{}}, nil)
827+
828+
workflow := &WorkflowDefinition{
829+
Name: "no_schema_test",
830+
Steps: []WorkflowStep{
831+
{ID: "s1", Type: StepTypeTool, Tool: "other_tool", Arguments: rawArgs},
832+
},
833+
}
834+
835+
result, err := engine.ExecuteWorkflow(context.Background(), workflow, nil)
836+
require.NoError(t, err)
837+
assert.Equal(t, WorkflowStatusCompleted, result.Status)
838+
}

pkg/vmcp/discovery/middleware.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,12 @@ func handleSubsequentRequest(
281281
return ctx, fmt.Errorf("session not found: %s", sessionID)
282282
}
283283

284-
// Backend tool calls are routed by session-scoped handlers registered with the SDK.
285-
// However, composite tool workflow steps go through the shared router which requires
286-
// DiscoveredCapabilities in the context. Inject capabilities built from the session's
287-
// routing table so composite workflows can route backend tool calls correctly.
284+
// Backend tool handlers (created by DefaultHandlerFactory) resolve their backend
285+
// target by calling router.RouteTool(ctx, name), which reads DiscoveredCapabilities
286+
// from the request context. Inject capabilities built from the session's routing
287+
// table so these handlers can route correctly on subsequent requests.
288+
// Note: composite tool workflow engines are created per-session and route via
289+
// SessionRouter directly, so they no longer depend on this context value.
288290
multiSess, isMulti := rawSess.(vmcpsession.MultiSession)
289291
if !isMulti {
290292
// The session is still a StreamableSession placeholder — Phase 2

pkg/vmcp/router/session_router.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package router
5+
6+
import (
7+
"context"
8+
"fmt"
9+
10+
"github.com/stacklok/toolhive/pkg/vmcp"
11+
)
12+
13+
// sessionRouter is a Router implementation backed directly by a RoutingTable,
14+
// requiring no request context to resolve capabilities. It is used by
15+
// per-session workflow engines so that composite tool execution does not depend
16+
// on the discovery middleware injecting DiscoveredCapabilities into the context.
17+
type sessionRouter struct {
18+
routingTable *vmcp.RoutingTable
19+
}
20+
21+
// NewSessionRouter creates a Router that routes from the provided RoutingTable
22+
// without reading the request context. This is the preferred router for
23+
// composite tool workflow engines because it couples routing to the session
24+
// rather than to middleware-managed context values.
25+
func NewSessionRouter(rt *vmcp.RoutingTable) Router {
26+
return &sessionRouter{routingTable: rt}
27+
}
28+
29+
// RouteTool resolves a tool name to its backend target using the session's
30+
// routing table directly.
31+
func (r *sessionRouter) RouteTool(_ context.Context, toolName string) (*vmcp.BackendTarget, error) {
32+
if r.routingTable == nil || r.routingTable.Tools == nil {
33+
return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName)
34+
}
35+
target, exists := r.routingTable.Tools[toolName]
36+
if !exists {
37+
return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName)
38+
}
39+
return target, nil
40+
}
41+
42+
// RouteResource resolves a resource URI to its backend target using the
43+
// session's routing table directly.
44+
func (r *sessionRouter) RouteResource(_ context.Context, uri string) (*vmcp.BackendTarget, error) {
45+
if r.routingTable == nil || r.routingTable.Resources == nil {
46+
return nil, fmt.Errorf("%w: %s", ErrResourceNotFound, uri)
47+
}
48+
target, exists := r.routingTable.Resources[uri]
49+
if !exists {
50+
return nil, fmt.Errorf("%w: %s", ErrResourceNotFound, uri)
51+
}
52+
return target, nil
53+
}
54+
55+
// RoutePrompt resolves a prompt name to its backend target using the session's
56+
// routing table directly.
57+
func (r *sessionRouter) RoutePrompt(_ context.Context, name string) (*vmcp.BackendTarget, error) {
58+
if r.routingTable == nil || r.routingTable.Prompts == nil {
59+
return nil, fmt.Errorf("%w: %s", ErrPromptNotFound, name)
60+
}
61+
target, exists := r.routingTable.Prompts[name]
62+
if !exists {
63+
return nil, fmt.Errorf("%w: %s", ErrPromptNotFound, name)
64+
}
65+
return target, nil
66+
}

0 commit comments

Comments
 (0)