From 0be8618a0ded4922f5ae9ee679492a6e2007c633 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Wed, 20 May 2026 12:03:30 -0400 Subject: [PATCH 1/4] feat: add middleware system for wrapping agent stages --- strands-ts/src/agent/agent.ts | 541 ++++--- strands-ts/src/index.ts | 16 + .../__tests__/agent-middleware.test.ts | 1256 +++++++++++++++++ .../__tests__/custom-stages.test.ts | 151 ++ .../__tests__/middleware-interrupts.test.ts | 216 +++ .../src/middleware/__tests__/registry.test.ts | 628 +++++++++ strands-ts/src/middleware/index.ts | 16 + strands-ts/src/middleware/registry.ts | 73 + strands-ts/src/middleware/stages.ts | 115 ++ strands-ts/src/middleware/types.ts | 57 + strands-ts/src/types/agent.ts | 14 + 11 files changed, 2922 insertions(+), 161 deletions(-) create mode 100644 strands-ts/src/middleware/__tests__/agent-middleware.test.ts create mode 100644 strands-ts/src/middleware/__tests__/custom-stages.test.ts create mode 100644 strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts create mode 100644 strands-ts/src/middleware/__tests__/registry.test.ts create mode 100644 strands-ts/src/middleware/index.ts create mode 100644 strands-ts/src/middleware/registry.ts create mode 100644 strands-ts/src/middleware/stages.ts create mode 100644 strands-ts/src/middleware/types.ts diff --git a/strands-ts/src/agent/agent.ts b/strands-ts/src/agent/agent.ts index 7fb3d6484..8967ea381 100644 --- a/strands-ts/src/agent/agent.ts +++ b/strands-ts/src/agent/agent.ts @@ -45,6 +45,16 @@ import { SlidingWindowConversationManager } from '../conversation-manager/slidin import { NullConversationManager } from '../conversation-manager/null-conversation-manager.js' import { ConversationManager } from '../conversation-manager/conversation-manager.js' import { HookRegistryImplementation } from '../hooks/registry.js' +import { MiddlewareRegistry, InvokeModelStage, ExecuteToolStage, AgentStreamStage } from '../middleware/index.js' +import type { Stage, MiddlewareHandler, 
MiddlewareNext } from '../middleware/index.js' +import type { + InvokeModelContext, + InvokeModelResult, + ExecuteToolContext, + ExecuteToolResult, + AgentStreamContext, + AgentStreamResult, +} from '../middleware/index.js' import type { HookableEventConstructor, HookCallback, HookCallbackOptions, HookCleanup } from '../hooks/types.js' import { InitializedEvent, @@ -66,6 +76,7 @@ import { ToolStreamUpdateEvent, InterruptEvent, type ModelStopData, + type ToolUseData, } from '../hooks/events.js' import { StructuredOutputTool, STRUCTURED_OUTPUT_TOOL_NAME } from '../tools/structured-output-tool.js' import { AgentAsTool } from './agent-as-tool.js' @@ -83,7 +94,7 @@ import { CancelledError } from '../errors.js' import { DefaultModelRetryStrategy } from '../retry/default-model-retry-strategy.js' import type { RetryStrategy } from '../retry/retry-strategy.js' import { warnOnDuplicateRetryStrategyTypes } from '../retry/retry-strategy.js' -import { InterruptError, InterruptState, interruptFromAgent } from '../interrupt.js' +import { Interrupt, InterruptError, InterruptState, interruptFromAgent } from '../interrupt.js' import type { InterruptParams } from '../types/interrupt.js' import { isInterruptResponseContent, type InterruptResponseContent } from '../types/interrupt.js' import { takeSnapshot as takeSnapshotInternal, loadSnapshot as loadSnapshotInternal } from './snapshot.js' @@ -290,6 +301,7 @@ export class Agent implements LocalAgent, InvokableAgent { public readonly sessionManager?: SessionManager | undefined private readonly _hooksRegistry: HookRegistryImplementation + private readonly _middlewareRegistry: MiddlewareRegistry private readonly _pluginRegistry: PluginRegistry private readonly _interventionRegistry: InterventionRegistry private _toolRegistry: ToolRegistry @@ -353,6 +365,9 @@ export class Agent implements LocalAgent, InvokableAgent { this._interventionRegistry = new InterventionRegistry(config?.interventions ?? 
[], this._hooksRegistry) + // Initialize middleware registry + this._middlewareRegistry = new MiddlewareRegistry() + // `undefined` (omitted) → install the default; `null`/`[]` → explicit opt-out. const retryStrategies: RetryStrategy[] = config?.retryStrategy === null @@ -440,6 +455,20 @@ export class Agent implements LocalAgent, InvokableAgent { return this._hooksRegistry.addCallback(eventType, callback, options) } + /** + * Register a middleware handler for a given stage. + * Middleware wraps stage execution and can intercept, transform, or short-circuit operations. + * + * @param stage - The stage token identifying the interception point + * @param handler - The middleware handler function (async generator) + */ + addMiddleware( + stage: Stage, + handler: MiddlewareHandler, + ): void { + this._middlewareRegistry.add(stage, handler) + } + public async initialize(): Promise { if (this._initialized) { return @@ -711,92 +740,164 @@ export class Agent implements LocalAgent, InvokableAgent { try { await this.initialize() - let currentArgs: InvokeArgs = args - - // Outer loop: re-enters _stream when a hook sets AfterInvocationEvent.resume. - // One invocation lock spans the whole resume chain. - while (true) { - // Fresh AbortController per invocation iteration, composed with any external signal. - this._abortController = new AbortController() - this._abortSignal = options?.cancelSignal - ? 
AbortSignal.any([this._abortController.signal, options.cancelSignal]) - : this._abortController.signal - - const streamGenerator = this._stream(currentArgs, options) - let caughtError: Error | undefined - let lastAfterInvocation: AfterInvocationEvent | undefined - let iterationResult: IteratorResult - try { - iterationResult = await streamGenerator.next() + // Process interrupt responses before middleware runs so context.interrupt() can find them + const interruptResponses = this._extractInterruptResponses(args) + if (interruptResponses.length > 0) { + this._interruptState.resume(interruptResponses) + } - while (!iterationResult.done) { - try { - const processed = await this._invokeCallbacks(iterationResult.value) - if (processed instanceof AfterInvocationEvent) { - lastAfterInvocation = processed - } - yield processed - iterationResult = await streamGenerator.next() - } catch (error) { - // Throw interrupt errors back into _stream so executeTools can store the - // assistant message as pending execution state for resume. 
- if (error instanceof InterruptError) { - iterationResult = await streamGenerator.throw(error) - } else { - throw error - } - } + const context: AgentStreamContext = { + agent: this, + args, + ...(options !== undefined && { options }), + interrupt: (params: InterruptParams): T => { + const interruptId = `middleware:agentStream:${params.name}` + // Read-only: check if a response already exists (resume case or preemptive) + const existing = this._interruptState.interrupts[interruptId] + if (existing?.response !== undefined) { + return existing.response as T + } + if (params.response !== undefined) { + return params.response as T } + // No response available: create interrupt locally and throw (no state mutation) + const interrupt = new Interrupt({ + id: interruptId, + name: params.name, + ...(params.reason !== undefined && { reason: params.reason }), + }) + throw new InterruptError(interrupt) + }, + } - // Suppress AgentResultEvent for resumed iterations — only the final - // invocation in a resume chain reports an agent result. - if (lastAfterInvocation?.resume === undefined) { - yield await this._invokeCallbacks( - new AgentResultEvent({ - agent: this, - result: iterationResult.value, - invocationState: iterationResult.value.invocationState, - }) - ) + // eslint-disable-next-line @typescript-eslint/no-this-alias + const self = this + try { + const { result } = yield* this._middlewareRegistry.invoke( + AgentStreamStage, + context, + async function* (ctx: AgentStreamContext): AsyncGenerator { + const result = yield* self._streamWithResumeLoop(ctx.args, ctx.options) + return { result } + }, + ) + return result + } catch (error) { + if (error instanceof InterruptError) { + // Transfer interrupts from the error to agent state (deferred write). + // Middleware-thrown interrupts bypass _stream's own catch, so we mirror + // upstream's behavior here. 
+ for (const interrupt of error.interrupts) { + this._interruptState.getOrCreateInterrupt(interrupt.id, interrupt.name, interrupt.reason) } - } catch (error) { - caughtError = error as Error - throw error - } finally { - // Drain _stream() so cleanup hooks and printer still fire. - // Yield only on error (consumer may still be iterating); on a consumer - // break, yielding would suspend the generator and leak the lock. - let drainResult = await streamGenerator.return(undefined as never) - while (!drainResult.done) { - try { - if (caughtError) { - yield await this._invokeCallbacks(drainResult.value) - } else { - await this._invokeCallbacks(drainResult.value) - } - } catch (error) { - logger.warn( - `event_type=<${drainResult.value.type}>, error=<${error}> | error invoking callbacks during cleanup` - ) + this._interruptState.activate() + return new AgentResult({ + stopReason: 'interrupt', + lastMessage: new Message({ role: 'assistant', content: [] }), + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + interrupts: this._interruptState.getUnansweredInterrupts(), + invocationState: options?.invocationState ?? {}, + }) + } + throw error + } + } finally { + this._isInvoking = false + } + } + + /** + * Internal stream logic with the outer resume loop. + * Extracted to allow AgentStreamStage middleware to wrap it as a terminal function. + */ + private async *_streamWithResumeLoop( + args: InvokeArgs, + options?: InvokeOptions, + ): AsyncGenerator { + let currentArgs: InvokeArgs = args + + // Outer loop: re-enters _stream when a hook sets AfterInvocationEvent.resume. + // One invocation lock spans the whole resume chain. + while (true) { + // Fresh AbortController per invocation iteration, composed with any external signal. + this._abortController = new AbortController() + this._abortSignal = options?.cancelSignal + ? 
AbortSignal.any([this._abortController.signal, options.cancelSignal]) + : this._abortController.signal + + const streamGenerator = this._stream(currentArgs, options) + let caughtError: Error | undefined + let lastAfterInvocation: AfterInvocationEvent | undefined + let iterationResult: IteratorResult + try { + iterationResult = await streamGenerator.next() + + while (!iterationResult.done) { + try { + const processed = await this._invokeCallbacks(iterationResult.value) + if (processed instanceof AfterInvocationEvent) { + lastAfterInvocation = processed + } + yield processed + iterationResult = await streamGenerator.next() + } catch (error) { + // Throw interrupt errors back into _stream so executeTools can store the + // assistant message as pending execution state for resume. + if (error instanceof InterruptError) { + iterationResult = await streamGenerator.throw(error) + } else { + throw error } - drainResult = await streamGenerator.next() } - - // Reset controller and signal for next iteration / invocation - this._abortController = new AbortController() - this._abortSignal = this._abortController.signal } - // Resume only on a clean invocation — errors propagate above. - if (lastAfterInvocation?.resume !== undefined) { - currentArgs = lastAfterInvocation.resume - continue + // Suppress AgentResultEvent for resumed iterations — only the final + // invocation in a resume chain reports an agent result. + if (lastAfterInvocation?.resume === undefined) { + yield await this._invokeCallbacks( + new AgentResultEvent({ + agent: this, + result: iterationResult.value, + invocationState: iterationResult.value.invocationState, + }) + ) + } + } catch (error) { + caughtError = error as Error + throw error + } finally { + // Drain _stream() so cleanup hooks and printer still fire. + // Yield only on error (consumer may still be iterating); on a consumer + // break, yielding would suspend the generator and leak the lock. 
+ let drainResult = await streamGenerator.return(undefined as never) + while (!drainResult.done) { + try { + if (caughtError) { + yield await this._invokeCallbacks(drainResult.value) + } else { + await this._invokeCallbacks(drainResult.value) + } + } catch (error) { + logger.warn( + `event_type=<${drainResult.value.type}>, error=<${error}> | error invoking callbacks during cleanup` + ) + } + drainResult = await streamGenerator.next() } - return iterationResult.value + // Reset controller and signal for next iteration / invocation + this._abortController = new AbortController() + this._abortSignal = this._abortController.signal } - } finally { - this._isInvoking = false + + // Resume only on a clean invocation — errors propagate above. + if (lastAfterInvocation?.resume !== undefined) { + currentArgs = lastAfterInvocation.resume + continue + } + + return iterationResult.value } } @@ -1199,6 +1300,14 @@ export class Agent implements LocalAgent, InvokableAgent { return result } if (error instanceof InterruptError) { + // Transfer interrupts from the error to agent state. Tool/hook-thrown + // interrupts already populate state; middleware-thrown interrupts + // (from the public stream() interrupt fn) defer the write, so this + // call is the seam that captures both. getOrCreateInterrupt is + // idempotent for already-known IDs. + for (const interrupt of error.interrupts) { + this._interruptState.getOrCreateInterrupt(interrupt.id, interrupt.name, interrupt.reason) + } // Fan out one event per interrupt. Each event exposes `interrupt.source` so // consumers can filter by origin (tool callback vs hook callback) without // subscribing to separate event types. 
@@ -1444,7 +1553,7 @@ export class Agent implements LocalAgent, InvokableAgent { }) try { - const result = yield* this._streamFromModel(this.messages, streamOptions, invocationState) + const result = yield* this._invokeModelWithMiddleware(invocationState, toolChoice) // Accumulate token usage and model latency metrics this._meter.updateCycle(result.metadata) @@ -1528,6 +1637,54 @@ export class Agent implements LocalAgent, InvokableAgent { } } + /** + * Invokes the model through the InvokeModelStage middleware chain. + * Builds an InvokeModelContext from current agent state and composes the + * middleware chain with a terminal function that calls _streamFromModel + * using context fields directly (not re-derived from the agent). + * + * @param invocationState - Per-invocation state shared across hooks and tools + * @param toolChoice - Optional tool choice to force specific tool usage + * @returns StreamAggregatedResult from the model (or middleware short-circuit) + */ + private async *_invokeModelWithMiddleware( + invocationState: InvocationState, + toolChoice?: ToolChoice + ): AsyncGenerator { + const context: InvokeModelContext = { + agent: this, + messages: this.messages, + ...(this.systemPrompt !== undefined && { systemPrompt: this.systemPrompt }), + toolSpecs: this._toolRegistry.list().map((tool) => tool.toolSpec), + ...(toolChoice !== undefined && { toolChoice }), + modelState: this.modelState, + invocationState, + } + + // eslint-disable-next-line @typescript-eslint/no-this-alias + const self = this + const middlewareResult = yield* this._middlewareRegistry.invoke( + InvokeModelStage, + context, + async function* (ctx: InvokeModelContext): AsyncGenerator { + const streamOptions: StreamOptions = { + toolSpecs: ctx.toolSpecs, + modelState: ctx.modelState, + ...(ctx.systemPrompt !== undefined && { systemPrompt: ctx.systemPrompt }), + ...(ctx.toolChoice && { toolChoice: ctx.toolChoice }), + } + const gen = self._streamFromModel(ctx.messages, streamOptions, 
ctx.invocationState) + let iterResult = await gen.next() + while (!iterResult.done) { + yield iterResult.value + iterResult = await gen.next() + } + return { result: iterResult.value } + }, + ) + return middlewareResult.result + } + /** * Streams events from the model and dispatches appropriate events for each. * @@ -1976,90 +2133,14 @@ export class Agent implements LocalAgent, InvokableAgent { return afterToolCallEvent.result } - // Start tool span within loop span context - const toolSpan = this._tracer.startToolCallSpan({ - tool: toolUse, - }) - - // Track tool execution time for metrics - const toolStartTime = Date.now() - - let toolResult: ToolResultBlock - let error: Error | undefined - - if (!effectiveTool) { - // Tool not found - toolResult = new ToolResultBlock({ - toolUseId: toolUse.toolUseId, - status: 'error', - content: [new TextBlock(`Tool '${toolUse.name}' not found in registry`)], - }) - } else { - // Execute tool within the tool span context - const toolContext: ToolContext = { - toolUse: { - name: toolUse.name, - toolUseId: toolUse.toolUseId, - input: toolUse.input, - }, - agent: this, - invocationState, - interrupt: (params: InterruptParams): T => { - return interruptFromAgent(this, `tool:${toolUseBlock.toolUseId}:${params.name}`, params, 'tool') - }, - } - - try { - // Manually iterate tool stream to wrap each ToolStreamEvent in ToolStreamUpdateEvent. - // This keeps the tool authoring interface unchanged — tools construct ToolStreamEvent - // without knowledge of agents or hooks, and we wrap at the boundary. 
- // Tool execution is ran within the tool span's context so that - // downstream calls (e.g., MCP clients) can propagate trace context - const toolGenerator = this._tracer.withSpanContext(toolSpan, () => effectiveTool.stream(toolContext)) - let toolNext = await this._tracer.withSpanContext(toolSpan, () => toolGenerator.next()) - while (!toolNext.done) { - yield new ToolStreamUpdateEvent({ agent: this, event: toolNext.value, invocationState }) - toolNext = await this._tracer.withSpanContext(toolSpan, () => toolGenerator.next()) - } - const result = toolNext.value - - if (!result) { - // Tool didn't return a result - toolResult = new ToolResultBlock({ - toolUseId: toolUse.toolUseId, - status: 'error', - content: [new TextBlock(`Tool '${toolUse.name}' did not return a result`)], - }) - } else { - toolResult = result - error = result.error - } - } catch (e) { - // Re-throw InterruptError to allow interrupt handling - if (e instanceof InterruptError) { - throw e - } - // Tool execution failed with error - error = normalizeError(e) - toolResult = new ToolResultBlock({ - toolUseId: toolUse.toolUseId, - status: 'error', - content: [new TextBlock(error.message)], - error, - }) - } - } - - // End tool span with the raw tool result — telemetry reflects what the - // tool actually returned, independent of AfterToolCallEvent mutations. 
- this._tracer.endToolCallSpan(toolSpan, { toolResult, ...(error && { error }) }) - - // End tool metrics tracking - this._meter.endToolCall({ - tool: toolUse, - duration: Date.now() - toolStartTime, - success: toolResult.status === 'success', - }) + // Execute tool core logic through middleware chain + const middlewareResult = yield* this._executeToolWithMiddleware( + effectiveTool, + toolUse, + invocationState, + ) + const toolResult = middlewareResult.result + const error = toolResult.error // Single point for AfterToolCallEvent const afterToolCallEvent = new AfterToolCallEvent({ @@ -2082,6 +2163,144 @@ export class Agent implements LocalAgent, InvokableAgent { } } + private async *_executeToolWithMiddleware( + tool: Tool | undefined, + toolUse: ToolUseData, + invocationState: InvocationState, + ): AsyncGenerator { + const context: ExecuteToolContext = { + agent: this, + tool, + toolUse: { + name: toolUse.name, + toolUseId: toolUse.toolUseId, + input: toolUse.input, + }, + invocationState, + interrupt: (params: InterruptParams): T => { + const interruptId = `middleware:executeTool:${toolUse.toolUseId}:${params.name}` + // Read-only: check if a response already exists (resume case or preemptive) + const existing = this._interruptState.interrupts[interruptId] + if (existing?.response !== undefined) { + return existing.response as T + } + if (params.response !== undefined) { + return params.response as T + } + // No response available: create interrupt locally and throw (no state mutation) + const interrupt = new Interrupt({ + id: interruptId, + name: params.name, + ...(params.reason !== undefined && { reason: params.reason }), + }) + throw new InterruptError(interrupt) + }, + } + + // eslint-disable-next-line @typescript-eslint/no-this-alias + const self = this + return yield* this._middlewareRegistry.invoke( + ExecuteToolStage, + context, + async function* (ctx: ExecuteToolContext): AsyncGenerator { + return yield* self._executeToolCore(ctx.tool, ctx.toolUse, 
ctx.invocationState) + }, + ) + } + + private async *_executeToolCore( + effectiveTool: Tool | undefined, + toolUse: ToolUseData, + invocationState: InvocationState, + ): AsyncGenerator { + // Start tool span within loop span context + const toolSpan = this._tracer.startToolCallSpan({ + tool: toolUse, + }) + + // Track tool execution time for metrics + const toolStartTime = Date.now() + + let toolResult: ToolResultBlock + let error: Error | undefined + + if (!effectiveTool) { + // Tool not found + toolResult = new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(`Tool '${toolUse.name}' not found in registry`)], + }) + } else { + // Execute tool within the tool span context + const toolContext: ToolContext = { + toolUse: { + name: toolUse.name, + toolUseId: toolUse.toolUseId, + input: toolUse.input, + }, + agent: this, + invocationState, + interrupt: (params: InterruptParams): T => { + return interruptFromAgent(this, `tool:${toolUse.toolUseId}:${params.name}`, params, 'tool') + }, + } + + try { + // Manually iterate tool stream to wrap each ToolStreamEvent in ToolStreamUpdateEvent. + // This keeps the tool authoring interface unchanged — tools construct ToolStreamEvent + // without knowledge of agents or hooks, and we wrap at the boundary. 
+ // Tool execution is ran within the tool span's context so that + // downstream calls (e.g., MCP clients) can propagate trace context + const toolGenerator = this._tracer.withSpanContext(toolSpan, () => effectiveTool.stream(toolContext)) + let toolNext = await this._tracer.withSpanContext(toolSpan, () => toolGenerator.next()) + while (!toolNext.done) { + yield new ToolStreamUpdateEvent({ agent: this, event: toolNext.value, invocationState }) + toolNext = await this._tracer.withSpanContext(toolSpan, () => toolGenerator.next()) + } + const result = toolNext.value + + if (!result) { + // Tool didn't return a result + toolResult = new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(`Tool '${toolUse.name}' did not return a result`)], + }) + } else { + toolResult = result + error = result.error + } + } catch (e) { + // Re-throw InterruptError to allow interrupt handling + if (e instanceof InterruptError) { + throw e + } + // Tool execution failed with error + error = normalizeError(e) + toolResult = new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(error.message)], + error, + }) + } + } + + // End tool span with the raw tool result — telemetry reflects what the + // tool actually returned, independent of AfterToolCallEvent mutations. + this._tracer.endToolCallSpan(toolSpan, { toolResult, ...(error && { error }) }) + + // End tool metrics tracking + this._meter.endToolCall({ + tool: toolUse, + duration: Date.now() - toolStartTime, + success: toolResult.status === 'success', + }) + + return { result: toolResult } + } + /** * Redacts the last message in the conversation history. * Called when guardrails block user input and redaction is enabled. 
diff --git a/strands-ts/src/index.ts b/strands-ts/src/index.ts index 46b1022c0..762288d67 100644 --- a/strands-ts/src/index.ts +++ b/strands-ts/src/index.ts @@ -307,6 +307,22 @@ export { Sandbox, type ExecuteOptions } from './sandbox/base.js' export { PosixShellSandbox } from './sandbox/posix-shell.js' export type { StreamType, StreamChunk, FileInfo, OutputFile, ExecutionResult } from './sandbox/types.js' +// Middleware system +export { createStage, InvokeModelStage, ExecuteToolStage, AgentStreamStage } from './middleware/index.js' +export type { + Stage, + MiddlewareHandler, + MiddlewareNext, + HandlerOf, + NextOf, + InvokeModelContext, + InvokeModelResult, + ExecuteToolContext, + ExecuteToolResult, + AgentStreamContext, + AgentStreamResult, +} from './middleware/index.js' + // Multi-agent orchestration export { Graph } from './multiagent/index.js' export { Swarm } from './multiagent/index.js' diff --git a/strands-ts/src/middleware/__tests__/agent-middleware.test.ts b/strands-ts/src/middleware/__tests__/agent-middleware.test.ts new file mode 100644 index 000000000..7a37c6494 --- /dev/null +++ b/strands-ts/src/middleware/__tests__/agent-middleware.test.ts @@ -0,0 +1,1256 @@ +import { describe, expect, it, vi } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { AgentStreamStage, ExecuteToolStage, InvokeModelStage } from '../stages.js' +import type { + AgentStreamContext, + AgentStreamResult, + ExecuteToolContext, + ExecuteToolResult, + InvokeModelContext, +} from '../stages.js' +import type { MiddlewareHandler, HandlerOf } from '../types.js' +import type { AgentStreamEvent, LocalAgent } from '../../types/agent.js' +import type { Plugin } from '../../plugins/plugin.js' +import { TextBlock, ToolResultBlock, Message } from 
'../../types/messages.js' +import { AfterToolCallEvent, BeforeModelCallEvent, AfterModelCallEvent, BeforeToolCallEvent, ContentBlockEvent } from '../../hooks/events.js' +import type { ToolContext } from '../../tools/tool.js' + +type ExecuteToolMiddleware = MiddlewareHandler +type AgentStreamMiddleware = MiddlewareHandler + +describe('Agent middleware integration — InvokeModelStage', () => { + describe('addMiddleware registers handler and it executes on model call', () => { + it('middleware handler is invoked during model call', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const middlewareCalled = vi.fn() + + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + middlewareCalled() + return yield* next(context) + }) + + await agent.invoke('Test prompt') + + expect(middlewareCalled).toHaveBeenCalledOnce() + }) + + it('middleware receives InvokeModelContext with correct fields', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('ok')], + }), + ) + const agent = new Agent({ model, tools: [tool], printer: false, systemPrompt: 'Be helpful' }) + + let receivedContext: InvokeModelContext | undefined + + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + receivedContext = context + return yield* next(context) + }) + + await agent.invoke('Test prompt') + + expect(receivedContext).toMatchObject({ + agent, + systemPrompt: 'Be helpful', + messages: expect.arrayContaining([expect.any(Message)]), + toolSpecs: expect.arrayContaining([expect.objectContaining({ name: 'testTool' })]), + modelState: expect.anything(), + invocationState: expect.anything(), + }) + }) + + it('middleware result is used as the model call result', async () 
=> { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + const result = yield* next(context) + return result + }) + + const result = await agent.invoke('Test prompt') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content).toEqual([new TextBlock('Hello')]) + }) + }) + + describe('middleware can short-circuit model call with synthetic result', () => { + it('returns synthetic result without calling the model', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Real response' }) + const agent = new Agent({ model, printer: false }) + + agent.addMiddleware( + InvokeModelStage, + // eslint-disable-next-line require-yield + async function* () { + return { + result: { + message: new Message({ role: 'assistant', content: [new TextBlock('Cached response')] }), + stopReason: 'endTurn' as const, + }, + } + }, + ) + + const result = await agent.invoke('Test prompt') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content).toEqual([new TextBlock('Cached response')]) + }) + + it('model is not called when middleware short-circuits', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Real response' }) + const agent = new Agent({ model, printer: false }) + + const streamSpy = vi.spyOn(model, 'stream') + + agent.addMiddleware( + InvokeModelStage, + // eslint-disable-next-line require-yield + async function* () { + return { + result: { + message: new Message({ role: 'assistant', content: [new TextBlock('Cached')] }), + stopReason: 'endTurn' as const, + }, + } + }, + ) + + await agent.invoke('Test prompt') + + expect(streamSpy).not.toHaveBeenCalled() + }) + }) + + describe('middleware can transform context (messages, toolSpecs) before model call', () => { + it('modified messages are passed to the 
model', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + agent.addMiddleware( + InvokeModelStage, + async function* (context, next) { + const modifiedContext: InvokeModelContext = { + ...context, + messages: [ + ...context.messages, + new Message({ role: 'user', content: [new TextBlock('Injected message')] }), + ], + } + return yield* next(modifiedContext) + }, + ) + + const streamSpy = vi.spyOn(model, 'stream') + + await agent.invoke('Test prompt') + + expect(streamSpy).toHaveBeenCalled() + const calledMessages = streamSpy.mock.calls[0]![0] + expect(calledMessages.length).toBeGreaterThan(1) + }) + + it('modified toolSpecs are passed to the model', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + agent.addMiddleware( + InvokeModelStage, + async function* (context, next) { + const modifiedContext: InvokeModelContext = { + ...context, + toolSpecs: [], + } + return yield* next(modifiedContext) + }, + ) + + const streamSpy = vi.spyOn(model, 'stream') + + await agent.invoke('Test prompt') + + expect(streamSpy).toHaveBeenCalled() + const calledOptions = streamSpy.mock.calls[0]![1] + expect(calledOptions?.toolSpecs).toStrictEqual([]) + }) + }) + + describe('hooks fire around middleware', () => { + it('BeforeModelCallEvent fires before middleware executes', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const order: string[] = [] + + agent.addHook(BeforeModelCallEvent, () => { + order.push('beforeModelCall') + }) + + agent.addMiddleware( + InvokeModelStage, + async function* (context, next) { + order.push('middleware') + return yield* next(context) + }, + ) + + await agent.invoke('Test prompt') + + 
expect(order.indexOf('beforeModelCall')).toBeLessThan(order.indexOf('middleware')) + }) + + it('AfterModelCallEvent fires after middleware completes', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const order: string[] = [] + + agent.addHook(AfterModelCallEvent, () => { + order.push('afterModelCall') + }) + + agent.addMiddleware( + InvokeModelStage, + async function* (context, next) { + order.push('middleware-start') + const result = yield* next(context) + order.push('middleware-end') + return result + }, + ) + + await agent.invoke('Test prompt') + + expect(order.indexOf('middleware-end')).toBeLessThan(order.indexOf('afterModelCall')) + }) + + it('Before/After hooks fire even when middleware short-circuits', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const beforeCalled = vi.fn() + const afterCalled = vi.fn() + + agent.addHook(BeforeModelCallEvent, beforeCalled) + agent.addHook(AfterModelCallEvent, afterCalled) + + agent.addMiddleware( + InvokeModelStage, + // eslint-disable-next-line require-yield + async function* () { + return { + result: { + message: new Message({ role: 'assistant', content: [new TextBlock('Cached')] }), + stopReason: 'endTurn' as const, + }, + } + }, + ) + + await agent.invoke('Test prompt') + + expect(beforeCalled).toHaveBeenCalled() + expect(afterCalled).toHaveBeenCalled() + }) + }) + + describe('no middleware registered', () => { + it('agent works correctly without any middleware', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const result = await agent.invoke('Test prompt') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content).toEqual([new TextBlock('Hello')]) + }) + + it('agent with 
tools works correctly without middleware', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('Tool executed')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const result = await agent.invoke('Use the tool') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content).toEqual([new TextBlock('Done')]) + }) + + it('stream works correctly without middleware', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const { result } = await collectGenerator(agent.stream('Test prompt')) + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content).toEqual([new TextBlock('Hello')]) + }) + }) +}) + +describe('AgentStreamStage integration', () => { + describe('middleware wraps the entire agent stream', () => { + it('middleware executes around the full agent stream', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const callOrder: string[] = [] + + const middleware: AgentStreamMiddleware = async function* (context, next) { + callOrder.push('middleware-before') + const result = yield* next(context) + callOrder.push('middleware-after') + return result + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { result } = await collectGenerator(agent.stream('Test prompt')) + + expect(callOrder).toStrictEqual(['middleware-before', 'middleware-after']) + expect(result.stopReason).toBe('endTurn') + }) + + it('middleware receives AgentStreamContext with agent, args, and options', async () => { + 
const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + let receivedContext: AgentStreamContext | undefined + + const middleware: AgentStreamMiddleware = async function* (context, next) { + receivedContext = context + return yield* next(context) + } + + agent.addMiddleware(AgentStreamStage, middleware) + + await collectGenerator(agent.stream('Test prompt')) + + expect(receivedContext).toBeDefined() + expect(receivedContext!.agent).toBe(agent) + expect(receivedContext!.args).toBe('Test prompt') + }) + + it('middleware receives options when provided', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + let receivedContext: AgentStreamContext | undefined + const options = { invocationState: { key: 'value' } } + + const middleware: AgentStreamMiddleware = async function* (context, next) { + receivedContext = context + return yield* next(context) + } + + agent.addMiddleware(AgentStreamStage, middleware) + + await collectGenerator(agent.stream('Test prompt', options)) + + expect(receivedContext!.options).toBe(options) + }) + + it('middleware can short-circuit the entire agent stream', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Should not reach' }) + const agent = new Agent({ model, printer: false }) + + // eslint-disable-next-line require-yield + const middleware: AgentStreamMiddleware = async function* () { + return { + result: { + stopReason: 'endTurn', + lastMessage: { type: 'message', role: 'assistant', content: [] }, + metrics: { cycleCount: 0, accumulatedUsage: {}, accumulatedMetrics: {}, toolMetrics: {} }, + invocationState: {}, + }, + } as unknown as AgentStreamResult + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { items, result } = await collectGenerator(agent.stream('Test prompt')) + + 
expect(items).toStrictEqual([]) + expect(result.stopReason).toBe('endTurn') + }) + + it('multiple middleware execute in registration order', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const callOrder: string[] = [] + + const outer: AgentStreamMiddleware = async function* (context, next) { + callOrder.push('outer-before') + const result = yield* next(context) + callOrder.push('outer-after') + return result + } + + const inner: AgentStreamMiddleware = async function* (context, next) { + callOrder.push('inner-before') + const result = yield* next(context) + callOrder.push('inner-after') + return result + } + + agent.addMiddleware(AgentStreamStage, outer) + agent.addMiddleware(AgentStreamStage, inner) + + await collectGenerator(agent.stream('Test prompt')) + + expect(callOrder).toStrictEqual([ + 'outer-before', + 'inner-before', + 'inner-after', + 'outer-after', + ]) + }) + }) + + describe('middleware can filter events from the stream', () => { + it('filters out specific event types', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const middleware: AgentStreamMiddleware = async function* (context, next) { + const gen = next(context) + let iterResult = await gen.next() + while (!iterResult.done) { + const event = iterResult.value + // Filter out modelStreamUpdate events + if (event.type !== 'modelStreamUpdateEvent') { + yield event + } + iterResult = await gen.next() + } + return iterResult.value + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { items } = await collectGenerator(agent.stream('Test prompt')) + + const modelStreamEvents = items.filter((e: AgentStreamEvent) => e.type === 'modelStreamUpdateEvent') + expect(modelStreamEvents).toStrictEqual([]) + // Other events should still be present + 
expect(items.length).toBeGreaterThan(0) + }) + + it('preserves the result when filtering events', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const middleware: AgentStreamMiddleware = async function* (context, next) { + const gen = next(context) + let iterResult = await gen.next() + while (!iterResult.done) { + const event = iterResult.value + if (event.type !== 'contentBlockEvent') { + yield event + } + iterResult = await gen.next() + } + return iterResult.value + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { result } = await collectGenerator(agent.stream('Test prompt')) + + expect(result.stopReason).toBe('endTurn') + }) + + it('can suppress all events while still returning the result', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + // eslint-disable-next-line require-yield + const middleware: AgentStreamMiddleware = async function* (context, next) { + const gen = next(context) + let iterResult = await gen.next() + while (!iterResult.done) { + // Suppress all events — do not yield + iterResult = await gen.next() + } + return iterResult.value + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { items, result } = await collectGenerator(agent.stream('Test prompt')) + + expect(items).toStrictEqual([]) + expect(result.stopReason).toBe('endTurn') + }) + }) + + describe('middleware can inject synthetic events', () => { + it('injects events before the stream', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const syntheticEvent = { type: 'contentBlockEvent' } as unknown as AgentStreamEvent + + const middleware: AgentStreamMiddleware = async function* (context, next) { + yield syntheticEvent + return yield* 
next(context) + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { items } = await collectGenerator(agent.stream('Test prompt')) + + expect(items[0]).toBe(syntheticEvent) + expect(items.length).toBeGreaterThan(1) + }) + + it('injects events after the stream', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const syntheticEvent = { type: 'contentBlockEvent' } as unknown as AgentStreamEvent + + const middleware: AgentStreamMiddleware = async function* (context, next) { + const result = yield* next(context) + yield syntheticEvent + return result + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { items } = await collectGenerator(agent.stream('Test prompt')) + + expect(items[items.length - 1]).toBe(syntheticEvent) + }) + + it('injects events alongside real events via manual iteration', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const syntheticEvent = { type: 'contentBlockEvent', synthetic: true } as unknown as AgentStreamEvent + + const middleware: AgentStreamMiddleware = async function* (context, next) { + const gen = next(context) + let iterResult = await gen.next() + let injected = false + while (!iterResult.done) { + yield iterResult.value + if (!injected) { + yield syntheticEvent + injected = true + } + iterResult = await gen.next() + } + return iterResult.value + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { items } = await collectGenerator(agent.stream('Test prompt')) + + // The synthetic event should appear after the first real event + expect(items[1]).toBe(syntheticEvent) + // Total events should include exactly one synthetic event + expect(items.filter((e: AgentStreamEvent) => e === syntheticEvent)).toHaveLength(1) + }) + + it('can yield events without calling next (pure synthetic 
stream)', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Should not reach' }) + const agent = new Agent({ model, printer: false }) + + const syntheticEvent1 = { type: 'contentBlockEvent', id: 1 } as unknown as AgentStreamEvent + const syntheticEvent2 = { type: 'contentBlockEvent', id: 2 } as unknown as AgentStreamEvent + + const middleware: AgentStreamMiddleware = async function* () { + yield syntheticEvent1 + yield syntheticEvent2 + return { + result: { + stopReason: 'endTurn', + lastMessage: { type: 'message', role: 'assistant', content: [] }, + metrics: { cycleCount: 0, accumulatedUsage: {}, accumulatedMetrics: {}, toolMetrics: {} }, + invocationState: {}, + }, + } as unknown as AgentStreamResult + } + + agent.addMiddleware(AgentStreamStage, middleware) + + const { items } = await collectGenerator(agent.stream('Test prompt')) + + expect(items).toStrictEqual([syntheticEvent1, syntheticEvent2]) + }) + }) + + describe('no AgentStreamStage middleware registered', () => { + it('agent streams directly without middleware overhead', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const { items, result } = await collectGenerator(agent.stream('Test prompt')) + + expect(result.stopReason).toBe('endTurn') + expect(items.length).toBeGreaterThan(0) + }) + + it('existing behavior is unchanged when no middleware is registered', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello world' }) + const agent = new Agent({ model, printer: false }) + + const { items, result } = await collectGenerator(agent.stream('Test prompt')) + + const beforeInvocation = items.find((e: AgentStreamEvent) => e.type === 'beforeInvocationEvent') + const afterInvocation = items.find((e: AgentStreamEvent) => e.type === 'afterInvocationEvent') + expect(beforeInvocation).toBeDefined() + 
expect(afterInvocation).toBeDefined() + expect(result.stopReason).toBe('endTurn') + }) + + it('middleware on other stages does not affect AgentStreamStage', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + // Register middleware on InvokeModelStage only + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + return yield* next(context) + }) + + const { result } = await collectGenerator(agent.stream('Test prompt')) + + expect(result.stopReason).toBe('endTurn') + }) + }) +}) + +describe('ExecuteToolStage integration', () => { + describe('middleware executes around tool calls', () => { + it('middleware handler is invoked during tool execution', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: { x: 1 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('executed')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const executionOrder: string[] = [] + + const middleware: ExecuteToolMiddleware = async function* (context, next) { + executionOrder.push('middleware:before') + const result = yield* next(context) + executionOrder.push('middleware:after') + return result + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + expect(executionOrder).toStrictEqual(['middleware:before', 'middleware:after']) + }) + + it('middleware receives ExecuteToolContext with correct fields', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: { key: 'val' } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new 
ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('ok')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + let receivedContext: ExecuteToolContext | undefined + + const middleware: ExecuteToolMiddleware = async function* (context, next) { + receivedContext = context + return yield* next(context) + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + expect(receivedContext).toBeDefined() + expect(receivedContext!.agent).toBe(agent) + expect(receivedContext!.tool).toBeDefined() + expect(receivedContext!.tool!.name).toBe('testTool') + expect(receivedContext!.toolUse).toStrictEqual({ + name: 'testTool', + toolUseId: 'tool-1', + input: { key: 'val' }, + }) + expect(receivedContext!.invocationState).toBeDefined() + }) + + it('multiple middleware execute in registration order', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('ok')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const callOrder: string[] = [] + + const outer: ExecuteToolMiddleware = async function* (context, next) { + callOrder.push('outer-before') + const result = yield* next(context) + callOrder.push('outer-after') + return result + } + + const inner: ExecuteToolMiddleware = async function* (context, next) { + callOrder.push('inner-before') + const result = yield* next(context) + callOrder.push('inner-after') + return result + } + + agent.addMiddleware(ExecuteToolStage, outer) + agent.addMiddleware(ExecuteToolStage, inner) + + await agent.invoke('Use the tool') + + expect(callOrder).toStrictEqual([ + 'outer-before', + 'inner-before', + 'inner-after', + 'outer-after', 
+ ]) + }) + }) + + describe('middleware can mock tool responses (short-circuit)', () => { + it('returns mock result without executing the real tool', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const toolFn = vi.fn( + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('real result')], + }), + ) + + const tool = createMockTool('testTool', toolFn) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // eslint-disable-next-line require-yield + const middleware: ExecuteToolMiddleware = async function* (context) { + return { + result: new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('mocked result')], + }), + } + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + const result = await agent.invoke('Use the tool') + + // The real tool function should NOT have been called + expect(toolFn).not.toHaveBeenCalled() + // The agent should still complete successfully + expect(result.stopReason).toBe('endTurn') + }) + + it('short-circuit result is used in the conversation', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Got the mocked data' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('real')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // eslint-disable-next-line require-yield + const middleware: ExecuteToolMiddleware = async function* (context) { + return { + result: new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('mocked data')], + }), + } + } + + 
agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + // The tool result message in conversation should contain the mocked result + const toolResultMessage = agent.messages.find( + (m: Message) => m.role === 'user' && m.content.some((c) => c.type === 'toolResultBlock'), + ) + expect(toolResultMessage).toBeDefined() + const toolResultBlock = toolResultMessage!.content.find((c: { type: string }) => c.type === 'toolResultBlock') + expect(toolResultBlock).toStrictEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('mocked data')], + }), + ) + }) + }) + + describe('middleware can transform tool input via context modification', () => { + it('modified input reaches the tool', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: { value: 'original' } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let receivedInput: unknown + + const tool = createMockTool('testTool', (context: ToolContext) => { + receivedInput = context.toolUse.input + return new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('ok')], + }) + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const middleware: ExecuteToolMiddleware = async function* (context, next) { + const modifiedContext: ExecuteToolContext = { + ...context, + toolUse: { + ...context.toolUse, + input: { value: 'transformed' }, + }, + } + return yield* next(modifiedContext) + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + expect(receivedInput).toStrictEqual({ value: 'transformed' }) + }) + + it('original context is not mutated', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: { value: 'original' } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + 
+ const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('ok')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + let originalInput: unknown + + const middleware: ExecuteToolMiddleware = async function* (context, next) { + originalInput = context.toolUse.input + const modifiedContext: ExecuteToolContext = { + ...context, + toolUse: { + ...context.toolUse, + input: { value: 'modified' }, + }, + } + return yield* next(modifiedContext) + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + // The original context input should remain unchanged + expect(originalInput).toStrictEqual({ value: 'original' }) + }) + }) + + describe('hooks fire around middleware for tool execution', () => { + it('BeforeToolCallEvent fires before middleware, AfterToolCallEvent fires after', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('executed')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const executionOrder: string[] = [] + + agent.addHook(BeforeToolCallEvent, () => { + executionOrder.push('hook:beforeToolCall') + }) + + agent.addHook(AfterToolCallEvent, () => { + executionOrder.push('hook:afterToolCall') + }) + + const middleware: ExecuteToolMiddleware = async function* (context, next) { + executionOrder.push('middleware:before') + const result = yield* next(context) + executionOrder.push('middleware:after') + return result + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + // Hooks fire OUTSIDE middleware: Before hook → middleware → After hook 
+ expect(executionOrder).toStrictEqual([ + 'hook:beforeToolCall', + 'middleware:before', + 'middleware:after', + 'hook:afterToolCall', + ]) + }) + + it('hooks fire even when middleware short-circuits', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('real')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const beforeCalled = vi.fn() + const afterCalled = vi.fn() + + agent.addHook(BeforeToolCallEvent, beforeCalled) + agent.addHook(AfterToolCallEvent, afterCalled) + + // eslint-disable-next-line require-yield + const middleware: ExecuteToolMiddleware = async function* (context) { + return { + result: new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('mocked')], + }), + } + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + expect(beforeCalled).toHaveBeenCalled() + expect(afterCalled).toHaveBeenCalled() + }) + + it('AfterToolCallEvent receives the middleware result when short-circuited', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('real')], + }), + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + let afterToolResult: ToolResultBlock | undefined + + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + afterToolResult = event.result + }) + + // eslint-disable-next-line require-yield + const middleware: 
ExecuteToolMiddleware = async function* (context) { + return { + result: new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('from middleware')], + }), + } + } + + agent.addMiddleware(ExecuteToolStage, middleware) + + await agent.invoke('Use the tool') + + expect(afterToolResult).toBeDefined() + expect(afterToolResult!.content).toStrictEqual([new TextBlock('from middleware')]) + }) + }) +}) + +describe('Middleware use cases', () => { + describe('caching tool results', () => { + class ToolResultCache implements Plugin { + name = 'tool-result-cache' + + private readonly _cache = new Map() + + initAgent(agent: LocalAgent): void { + const cache = this._cache + + // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { + const key = `${context.toolUse.name}:${JSON.stringify(context.toolUse.input)}` + const cached = cache.get(key) + if (cached) { + return { + result: new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: cached.status, + content: cached.content, + }), + } + } + const result = yield* next(context) + cache.set(key, result.result) + return result + }) + } + } + + it('returns cached result on second call, skipping real execution', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'expensiveApi', toolUseId: 'call-1', input: { query: 'weather' } }) + .addTurn({ type: 'textBlock', text: 'First done' }) + .addTurn({ type: 'toolUseBlock', name: 'expensiveApi', toolUseId: 'call-2', input: { query: 'weather' } }) + .addTurn({ type: 'textBlock', text: 'Second done' }) + + const realCallCount = vi.fn() + const tool = createMockTool('expensiveApi', (ctx: ToolContext) => { + realCallCount() + return new ToolResultBlock({ + toolUseId: ctx.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('sunny, 72°F')], + }) + }) + + const agent = new Agent({ model, tools: [tool], plugins: [new 
ToolResultCache()], printer: false }) + + await agent.invoke('What is the weather?') + expect(realCallCount).toHaveBeenCalledTimes(1) + + await agent.invoke('What is the weather?') + expect(realCallCount).toHaveBeenCalledTimes(1) // cache hit + }) + }) + + describe('auto-retrying model invocations', () => { + class RetryOnThrottle implements Plugin { + name = 'retry-on-throttle' + + private readonly _maxRetries: number + + constructor(maxRetries = 3) { + this._maxRetries = maxRetries + } + + initAgent(agent: LocalAgent): void { + const maxRetries = this._maxRetries + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + for (let attempt = 0; attempt < maxRetries; attempt++) { + try { + return yield* next(context) + } catch (e) { + const isRetryable = (e as Error).message.includes('ThrottlingException') + if (!isRetryable || attempt === maxRetries - 1) throw e + // In production: await sleep(backoff(attempt)) + } + } + throw new Error('exhausted retries') + }) + } + } + + it('retries on transient error and succeeds on second attempt', async () => { + let callCount = 0 + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Success after retry' }) + + const agent = new Agent({ model, plugins: [new RetryOnThrottle(3)], printer: false }) + + const originalStream = model.stream.bind(model) + vi.spyOn(model, 'stream').mockImplementation((...args) => { + callCount++ + if (callCount === 1) throw new Error('ThrottlingException: rate limit exceeded') + return originalStream(...args) + }) + + const result = await agent.invoke('Hello') + + expect(callCount).toBe(2) + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content).toEqual([new TextBlock('Success after retry')]) + }) + }) + + describe('stream final turn only (buffer intermediate turns)', () => { + class StreamFinalTurnOnly implements Plugin { + name = 'stream-final-turn-only' + + initAgent(agent: LocalAgent): void { + 
agent.addMiddleware(AgentStreamStage, (...args) => this._handler(...args)) + } + + private async *_handler( + ...[context, next]: Parameters> + ): ReturnType> { + let buffer: AgentStreamEvent[] = [] + const gen = next(context) + let iterResult = await gen.next() + + while (!iterResult.done) { + const event = iterResult.value + + if (event.type === 'contentBlockEvent' || event.type === 'modelStreamUpdateEvent') { + buffer.push(event) + } else if (event.type === 'afterModelCallEvent') { + const stopReason = (event as AfterModelCallEvent).stopData?.stopReason + if (stopReason === 'endTurn') { + for (const buffered of buffer) yield buffered + } + buffer = [] + yield event + } else { + yield event + } + + iterResult = await gen.next() + } + + for (const buffered of buffer) yield buffered + return iterResult.value + } + } + + it('suppresses content events from intermediate tool-use turns, emits only final turn', async () => { + const model = new MockMessageModel() + .addTurn([ + { type: 'textBlock', text: 'Let me check that for you' }, + { type: 'toolUseBlock', name: 'lookup', toolUseId: 'tool-1', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'The answer is 42' }) + + const tool = createMockTool('lookup', (ctx: ToolContext) => + new ToolResultBlock({ + toolUseId: ctx.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('42')], + }), + ) + + const agent = new Agent({ model, tools: [tool], plugins: [new StreamFinalTurnOnly()], printer: false }) + + const { items, result } = await collectGenerator(agent.stream('What is the meaning of life?')) + + const contentEvents = items.filter((e: AgentStreamEvent) => e.type === 'contentBlockEvent') + expect(contentEvents).toHaveLength(1) + expect((contentEvents[0] as ContentBlockEvent).contentBlock).toStrictEqual(new TextBlock('The answer is 42')) + expect(result.stopReason).toBe('endTurn') + }) + + it('passes through all events when there is only one turn (no tool use)', async () => { + const model = new 
MockMessageModel().addTurn({ type: 'textBlock', text: 'Simple answer' }) + const agent = new Agent({ model, plugins: [new StreamFinalTurnOnly()], printer: false }) + + const { items, result } = await collectGenerator(agent.stream('Hello')) + + const contentEvents = items.filter((e: AgentStreamEvent) => e.type === 'contentBlockEvent') + expect(contentEvents).toHaveLength(1) + expect((contentEvents[0] as ContentBlockEvent).contentBlock).toStrictEqual(new TextBlock('Simple answer')) + expect(result.stopReason).toBe('endTurn') + }) + }) +}) diff --git a/strands-ts/src/middleware/__tests__/custom-stages.test.ts b/strands-ts/src/middleware/__tests__/custom-stages.test.ts new file mode 100644 index 000000000..7191dc67d --- /dev/null +++ b/strands-ts/src/middleware/__tests__/custom-stages.test.ts @@ -0,0 +1,151 @@ +import { describe, expect, it } from 'vitest' +import { createStage, MiddlewareRegistry } from '../index.js' +import type { MiddlewareHandler, MiddlewareNext } from '../types.js' + +// Custom stage types for testing third-party extensibility +interface CustomContext { + readonly label: string + readonly count: number +} + +interface CustomEvent { + readonly kind: string +} + +interface CustomResult { + readonly summary: string +} + +/** + * Helper to collect all yielded events and the return value from an async generator. 
+ */ +async function collect( + gen: AsyncGenerator, +): Promise<{ events: TEvent[]; result: TResult }> { + const events: TEvent[] = [] + let iterResult = await gen.next() + while (!iterResult.done) { + events.push(iterResult.value) + iterResult = await gen.next() + } + return { events, result: iterResult.value } +} + +describe('Third-party custom stages', () => { + describe('createStage', () => { + it('returns a frozen object', () => { + const stage = createStage('myCustomStage') + expect(Object.isFrozen(stage)).toBe(true) + }) + + it('returns object with correct name property', () => { + const stage = createStage('myCustomStage') + expect(stage.name).toBe('myCustomStage') + }) + }) + + describe('custom stage with registry', () => { + it('works with registry.add and compose (handlers execute correctly)', async () => { + const CustomStage = createStage('custom') + const registry = new MiddlewareRegistry() + const callOrder: string[] = [] + + const handler: MiddlewareHandler = async function* (context, next) { + callOrder.push('middleware') + yield { kind: `pre-${context.label}` } + const result = yield* next(context) + callOrder.push('middleware-after') + return result + } + + registry.add(CustomStage, handler) + + const terminal: MiddlewareNext = async function* (ctx) { + callOrder.push('terminal') + yield { kind: `terminal-${ctx.label}` } + return { summary: `done-${ctx.count}` } + } + + const chain = registry.compose(CustomStage, terminal) + const { events, result } = await collect(chain({ label: 'test', count: 42 })) + + expect(callOrder).toStrictEqual(['middleware', 'terminal', 'middleware-after']) + expect(events).toStrictEqual([ + { kind: 'pre-test' }, + { kind: 'terminal-test' }, + ]) + expect(result).toStrictEqual({ summary: 'done-42' }) + }) + + it('two stages with the same name are distinct (reference identity)', async () => { + const StageA = createStage('shared-name') + const StageB = createStage('shared-name') + + const registry = new 
MiddlewareRegistry() + + const handlerA: MiddlewareHandler = async function* (context, next) { + yield { kind: 'from-A' } + return yield* next(context) + } + + const handlerB: MiddlewareHandler = async function* (context, next) { + yield { kind: 'from-B' } + return yield* next(context) + } + + registry.add(StageA, handlerA) + registry.add(StageB, handlerB) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + return { summary: 'terminal' } + } + + // Composing for StageA should only include handlerA + const chainA = registry.compose(StageA, terminal) + const resultA = await collect(chainA({ label: 'a', count: 1 })) + expect(resultA.events).toStrictEqual([{ kind: 'from-A' }]) + + // Composing for StageB should only include handlerB + const chainB = registry.compose(StageB, terminal) + const resultB = await collect(chainB({ label: 'b', count: 2 })) + expect(resultB.events).toStrictEqual([{ kind: 'from-B' }]) + }) + + it('custom stage middleware can be composed with a terminal and executes correctly', async () => { + const CustomStage = createStage('pipeline') + const registry = new MiddlewareRegistry() + + // Register multiple middleware for the custom stage + const logger: MiddlewareHandler = async function* (context, next) { + yield { kind: 'log-start' } + const result = yield* next(context) + yield { kind: 'log-end' } + return result + } + + const transformer: MiddlewareHandler = async function* (context, next) { + const modified = { ...context, count: context.count * 2 } + return yield* next(modified) + } + + registry.add(CustomStage, logger) + registry.add(CustomStage, transformer) + + const terminal: MiddlewareNext = async function* (ctx) { + yield { kind: `processed-${ctx.count}` } + return { summary: `result-${ctx.label}-${ctx.count}` } + } + + const chain = registry.compose(CustomStage, terminal) + const { events, result } = await collect(chain({ label: 'item', count: 5 })) + + expect(events).toStrictEqual([ 
+ { kind: 'log-start' }, + { kind: 'processed-10' }, + { kind: 'log-end' }, + ]) + expect(result).toStrictEqual({ summary: 'result-item-10' }) + }) + }) +}) diff --git a/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts b/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts new file mode 100644 index 000000000..2dacd0cab --- /dev/null +++ b/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts @@ -0,0 +1,216 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { ExecuteToolStage, AgentStreamStage } from '../stages.js' +import { TextBlock, ToolResultBlock } from '../../types/messages.js' +import { InterruptResponseContent } from '../../types/interrupt.js' +import type { ToolContext } from '../../tools/tool.js' + +describe('Middleware interrupts', () => { + describe('ExecuteToolStage', () => { + it('middleware can raise an interrupt (agent stops with stopReason interrupt)', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'dangerousTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Should not reach' }) + + const tool = createMockTool('dangerousTool', () => 'executed') + const agent = new Agent({ model, tools: [tool], printer: false }) + + // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { + context.interrupt({ name: 'approve_tool', reason: 'Confirm execution?' }) + return yield* next(context) + }) + + const result = await agent.invoke('Do the dangerous thing') + + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toEqual([ + expect.objectContaining({ name: 'approve_tool', reason: 'Confirm execution?' 
}), + ]) + }) + + it('middleware gets response on resume and continues execution', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'dangerousTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('dangerousTool', () => { + toolExecuted = true + return 'executed' + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { + const approval = context.interrupt({ name: 'approve_tool', reason: 'Confirm?' }) + if (approval !== 'yes') { + return { + result: new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'error', + content: [new TextBlock('Denied by user')], + }), + } + } + return yield* next(context) + }) + + // First invocation: interrupt fires + const interruptResult = await agent.invoke('Do it') + expect(interruptResult.stopReason).toBe('interrupt') + expect(toolExecuted).toBe(false) + + // Resume with approval + const finalResult = await agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }), + ]) + + expect(finalResult.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + }) + + it('interrupt ID includes toolUseId for disambiguation', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'unique-tool-id', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('myTool', () => 'ok') + const agent = new Agent({ model, tools: [tool], printer: false }) + + // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { + context.interrupt({ name: 'check' }) + return yield* next(context) + }) + + const result = await agent.invoke('Test') + + 
expect(result.interrupts![0]!.id).toContain('unique-tool-id') + expect(result.interrupts![0]!.id).toContain('check') + }) + + it('preemptive response skips the interrupt (no halt)', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('myTool', () => { + toolExecuted = true + return 'ok' + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { + // Preemptive response: returns immediately without halting + const approval = context.interrupt({ name: 'check', response: 'pre-approved' }) + expect(approval).toBe('pre-approved') + return yield* next(context) + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + }) + + it('context spread preserves interrupt function', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: { x: 1 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('myTool', () => 'ok') + const agent = new Agent({ model, tools: [tool], printer: false }) + + // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { + // Spread context to modify toolUse, interrupt should still work + const modified = { ...context, toolUse: { ...context.toolUse, input: { x: 2 } } } + modified.interrupt({ name: 'after_spread' }) + return yield* next(modified) + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts![0]!.name).toBe('after_spread') + }) + }) + + describe('AgentStreamStage', () => { + it('middleware can raise an interrupt (agent 
stops with stopReason interrupt)', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + // eslint-disable-next-line require-yield + agent.addMiddleware(AgentStreamStage, async function* (context) { + context.interrupt({ name: 'confirm_stream', reason: 'Are you sure?' }) + // unreachable — interrupt() throws + return undefined as never + }) + + const { result } = await collectGenerator(agent.stream('Test')) + + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toEqual([ + expect.objectContaining({ name: 'confirm_stream', reason: 'Are you sure?' }), + ]) + }) + + it('middleware gets response on resume and continues', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + agent.addMiddleware(AgentStreamStage, async function* (context, next) { + const approval = context.interrupt({ name: 'gate', reason: 'Proceed?' 
}) + if (approval !== 'go') { + // eslint-disable-next-line require-yield + return { result: { stopReason: 'endTurn' } } as never + } + return yield* next(context) + }) + + // First: interrupt + const { result: interruptResult } = await collectGenerator(agent.stream('Test')) + expect(interruptResult.stopReason).toBe('interrupt') + + // Resume + const { result: finalResult } = await collectGenerator( + agent.stream([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'go', + }), + ]), + ) + + expect(finalResult.stopReason).toBe('endTurn') + }) + + it('interrupt ID uses agentStream namespace', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + // eslint-disable-next-line require-yield + agent.addMiddleware(AgentStreamStage, async function* (context) { + context.interrupt({ name: 'my_gate' }) + return undefined as never + }) + + const { result } = await collectGenerator(agent.stream('Test')) + + expect(result.interrupts![0]!.id).toContain('agentStream') + expect(result.interrupts![0]!.id).toContain('my_gate') + }) + }) +}) diff --git a/strands-ts/src/middleware/__tests__/registry.test.ts b/strands-ts/src/middleware/__tests__/registry.test.ts new file mode 100644 index 000000000..a88e51fee --- /dev/null +++ b/strands-ts/src/middleware/__tests__/registry.test.ts @@ -0,0 +1,628 @@ +import { describe, expect, it } from 'vitest' +import { MiddlewareRegistry, createStage } from '../index.js' +import type { MiddlewareHandler, MiddlewareNext } from '../types.js' +import { CancelledError } from '../../errors.js' +import { InterruptError, Interrupt } from '../../interrupt.js' + +// Simple test types +interface TestContext { + readonly value: string +} + +interface TestEvent { + readonly data: string +} + +interface TestResult { + readonly output: string +} + +const TestStage = createStage('test') + +/** + * Helper to collect all 
/**
 * Drains an async generator, gathering every yielded event together with the
 * value the generator finally returns.
 *
 * @param gen - The async generator to exhaust
 * @returns The yielded events in emission order, plus the generator's return value
 */
async function collect<TEvent, TResult>(
  gen: AsyncGenerator<TEvent, TResult>,
): Promise<{ events: TEvent[]; result: TResult }> {
  const events: TEvent[] = []
  // Step manually instead of using `for await`: the loop form discards the value
  // carried by the final `{ done: true }` step, which is the result we need.
  let step = await gen.next()
  while (!step.done) {
    events.push(step.value)
    step = await gen.next()
  }
  return { events, result: step.value }
}
+ + const outer: MiddlewareHandler = async function* (context, next) { + callOrder.push('outer-before') + const result = yield* next(context) + callOrder.push('outer-after') + return result + } + + const inner: MiddlewareHandler = async function* (context, next) { + callOrder.push('inner-before') + const result = yield* next(context) + callOrder.push('inner-after') + return result + } + + registry.add(TestStage, outer) + registry.add(TestStage, inner) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + callOrder.push('terminal') + return { output: 'done' } + } + + const chain = registry.compose(TestStage, terminal) + await collect(chain({ value: 'test' })) + + expect(callOrder).toStrictEqual([ + 'outer-before', + 'inner-before', + 'terminal', + 'inner-after', + 'outer-after', + ]) + }) + }) + + describe('short-circuit behavior', () => { + it('does not call next handlers or terminal when middleware does not call next', async () => { + const registry = new MiddlewareRegistry() + let terminalCalled = false + let innerCalled = false + + const shortCircuit: MiddlewareHandler = async function* () { + yield { data: 'synthetic' } + return { output: 'short-circuited' } + } + + const inner: MiddlewareHandler = async function* (context, next) { + innerCalled = true + return yield* next(context) + } + + registry.add(TestStage, shortCircuit) + registry.add(TestStage, inner) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + terminalCalled = true + return { output: 'terminal' } + } + + const chain = registry.compose(TestStage, terminal) + const { events, result } = await collect(chain({ value: 'test' })) + + expect(terminalCalled).toBe(false) + expect(innerCalled).toBe(false) + expect(events).toStrictEqual([{ data: 'synthetic' }]) + expect(result).toStrictEqual({ output: 'short-circuited' }) + }) + }) + + describe('event pass-through via yield* next(context)', () => { + 
it('forwards all events from terminal and returns terminal result', async () => { + const registry = new MiddlewareRegistry() + + const passThrough: MiddlewareHandler = async function* (context, next) { + return yield* next(context) + } + + registry.add(TestStage, passThrough) + + const terminal: MiddlewareNext = async function* () { + yield { data: 'event-1' } + yield { data: 'event-2' } + yield { data: 'event-3' } + return { output: 'terminal-result' } + } + + const chain = registry.compose(TestStage, terminal) + const { events, result } = await collect(chain({ value: 'test' })) + + expect(events).toStrictEqual([ + { data: 'event-1' }, + { data: 'event-2' }, + { data: 'event-3' }, + ]) + expect(result).toStrictEqual({ output: 'terminal-result' }) + }) + }) + + describe('event filtering via manual iteration', () => { + it('only yields events matching a predicate', async () => { + const registry = new MiddlewareRegistry() + + const filter: MiddlewareHandler = async function* (context, next) { + const gen = next(context) + let iterResult = await gen.next() + while (!iterResult.done) { + const event = iterResult.value + // Only forward events containing 'keep' + if (event.data.includes('keep')) { + yield event + } + iterResult = await gen.next() + } + return iterResult.value + } + + registry.add(TestStage, filter) + + const terminal: MiddlewareNext = async function* () { + yield { data: 'keep-1' } + yield { data: 'drop-1' } + yield { data: 'keep-2' } + yield { data: 'drop-2' } + return { output: 'done' } + } + + const chain = registry.compose(TestStage, terminal) + const { events, result } = await collect(chain({ value: 'test' })) + + expect(events).toStrictEqual([ + { data: 'keep-1' }, + { data: 'keep-2' }, + ]) + expect(result).toStrictEqual({ output: 'done' }) + }) + }) + + describe('context modification flows to terminal', () => { + it('terminal receives modified context', async () => { + const registry = new MiddlewareRegistry() + let receivedContext: 
TestContext | undefined + + const modifier: MiddlewareHandler = async function* (context, next) { + const modified = { ...context, value: 'modified' } + return yield* next(modified) + } + + registry.add(TestStage, modifier) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* (ctx) { + receivedContext = ctx + return { output: ctx.value } + } + + const chain = registry.compose(TestStage, terminal) + const { result } = await collect(chain({ value: 'original' })) + + expect(receivedContext).toStrictEqual({ value: 'modified' }) + expect(result).toStrictEqual({ output: 'modified' }) + }) + + it('each middleware can further modify context', async () => { + const registry = new MiddlewareRegistry() + let receivedContext: TestContext | undefined + + const first: MiddlewareHandler = async function* (context, next) { + return yield* next({ ...context, value: context.value + '-first' }) + } + + const second: MiddlewareHandler = async function* (context, next) { + return yield* next({ ...context, value: context.value + '-second' }) + } + + registry.add(TestStage, first) + registry.add(TestStage, second) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* (ctx) { + receivedContext = ctx + return { output: ctx.value } + } + + const chain = registry.compose(TestStage, terminal) + const { result } = await collect(chain({ value: 'start' })) + + expect(receivedContext).toStrictEqual({ value: 'start-first-second' }) + expect(result).toStrictEqual({ output: 'start-first-second' }) + }) + }) + + describe('error propagation through chain', () => { + it('errors from terminal reach middleware', async () => { + const registry = new MiddlewareRegistry() + let caughtError: Error | undefined + + const catcher: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } catch (error) { + caughtError = error as Error + return { output: 'recovered' } + } + } + + 
registry.add(TestStage, catcher) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + throw new Error('terminal error') + } + + const chain = registry.compose(TestStage, terminal) + const { result } = await collect(chain({ value: 'test' })) + + expect(caughtError).toBeInstanceOf(Error) + expect(caughtError!.message).toBe('terminal error') + expect(result).toStrictEqual({ output: 'recovered' }) + }) + + it('errors from middleware reach caller', async () => { + const registry = new MiddlewareRegistry() + + // eslint-disable-next-line require-yield + const thrower: MiddlewareHandler = async function* () { + throw new Error('middleware error') + } + + registry.add(TestStage, thrower) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + return { output: 'done' } + } + + const chain = registry.compose(TestStage, terminal) + + await expect(collect(chain({ value: 'test' }))).rejects.toThrow('middleware error') + }) + + it('middleware can transform errors from next', async () => { + const registry = new MiddlewareRegistry() + + const transformer: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } catch { + throw new Error('transformed error') + } + } + + registry.add(TestStage, transformer) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + throw new Error('original error') + } + + const chain = registry.compose(TestStage, terminal) + + await expect(collect(chain({ value: 'test' }))).rejects.toThrow('transformed error') + }) + }) + + describe('CancelledError and InterruptError propagation', () => { + it('CancelledError propagates without being swallowed', async () => { + const registry = new MiddlewareRegistry() + + const passThrough: MiddlewareHandler = async function* (context, next) { + return yield* next(context) + } + + registry.add(TestStage, passThrough) + + // 
eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + throw new CancelledError() + } + + const chain = registry.compose(TestStage, terminal) + + await expect(collect(chain({ value: 'test' }))).rejects.toThrow(CancelledError) + }) + + it('InterruptError propagates without being swallowed', async () => { + const registry = new MiddlewareRegistry() + + const passThrough: MiddlewareHandler = async function* (context, next) { + return yield* next(context) + } + + registry.add(TestStage, passThrough) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + throw new InterruptError(new Interrupt({ id: 'int-1', name: 'test_interrupt' })) + } + + const chain = registry.compose(TestStage, terminal) + + await expect(collect(chain({ value: 'test' }))).rejects.toThrow(InterruptError) + }) + + it('CancelledError propagates through multiple middleware layers', async () => { + const registry = new MiddlewareRegistry() + + const outer: MiddlewareHandler = async function* (context, next) { + return yield* next(context) + } + + const inner: MiddlewareHandler = async function* (context, next) { + return yield* next(context) + } + + registry.add(TestStage, outer) + registry.add(TestStage, inner) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + throw new CancelledError() + } + + const chain = registry.compose(TestStage, terminal) + + await expect(collect(chain({ value: 'test' }))).rejects.toThrow(CancelledError) + }) + + it('InterruptError propagates through multiple middleware layers', async () => { + const registry = new MiddlewareRegistry() + + const outer: MiddlewareHandler = async function* (context, next) { + return yield* next(context) + } + + const inner: MiddlewareHandler = async function* (context, next) { + return yield* next(context) + } + + registry.add(TestStage, outer) + registry.add(TestStage, inner) + + // 
eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + throw new InterruptError(new Interrupt({ id: 'int-1', name: 'test_interrupt' })) + } + + const chain = registry.compose(TestStage, terminal) + + await expect(collect(chain({ value: 'test' }))).rejects.toThrow(InterruptError) + }) + }) + + describe('try-finally guarantees', () => { + it('outer finally runs when inner middleware throws', async () => { + const registry = new MiddlewareRegistry() + const order: string[] = [] + + const outer: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } finally { + order.push('outer-finally') + } + } + + // eslint-disable-next-line require-yield + const inner: MiddlewareHandler = async function* () { + order.push('inner-throw') + throw new Error('inner exploded') + } + + registry.add(TestStage, outer) + registry.add(TestStage, inner) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + return { output: 'unreachable' } + } + + const chain = registry.compose(TestStage, terminal) + await expect(collect(chain({ value: 'test' }))).rejects.toThrow('inner exploded') + + expect(order).toStrictEqual(['inner-throw', 'outer-finally']) + }) + + it('inner finally runs when terminal throws', async () => { + const registry = new MiddlewareRegistry() + const order: string[] = [] + + const outer: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } finally { + order.push('outer-finally') + } + } + + const inner: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } finally { + order.push('inner-finally') + } + } + + registry.add(TestStage, outer) + registry.add(TestStage, inner) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + order.push('terminal-throw') + throw new Error('terminal exploded') + } + + const chain = 
registry.compose(TestStage, terminal) + await expect(collect(chain({ value: 'test' }))).rejects.toThrow('terminal exploded') + + expect(order).toStrictEqual(['terminal-throw', 'inner-finally', 'outer-finally']) + }) + + it('all finally blocks run in reverse order when terminal throws', async () => { + const registry = new MiddlewareRegistry() + const order: string[] = [] + + const a: MiddlewareHandler = async function* (context, next) { + try { + order.push('a-enter') + return yield* next(context) + } finally { + order.push('a-finally') + } + } + + const b: MiddlewareHandler = async function* (context, next) { + try { + order.push('b-enter') + return yield* next(context) + } finally { + order.push('b-finally') + } + } + + const c: MiddlewareHandler = async function* (context, next) { + try { + order.push('c-enter') + return yield* next(context) + } finally { + order.push('c-finally') + } + } + + registry.add(TestStage, a) + registry.add(TestStage, b) + registry.add(TestStage, c) + + // eslint-disable-next-line require-yield + const terminal: MiddlewareNext = async function* () { + throw new Error('boom') + } + + const chain = registry.compose(TestStage, terminal) + await expect(collect(chain({ value: 'test' }))).rejects.toThrow('boom') + + expect(order).toStrictEqual(['a-enter', 'b-enter', 'c-enter', 'c-finally', 'b-finally', 'a-finally']) + }) + + it('finally runs even when caller abandons the generator mid-stream', async () => { + const registry = new MiddlewareRegistry() + const order: string[] = [] + + const middleware: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } finally { + order.push('middleware-finally') + } + } + + registry.add(TestStage, middleware) + + const terminal: MiddlewareNext = async function* () { + yield { data: 'event-1' } + yield { data: 'event-2' } + yield { data: 'event-3' } + return { output: 'done' } + } + + const chain = registry.compose(TestStage, terminal) + const gen = chain({ value: 
'test' }) + + // Only consume one event then abandon (call return to close the generator) + await gen.next() + await gen.return({ output: 'abandoned' }) + + expect(order).toStrictEqual(['middleware-finally']) + }) + + it('finally runs in both middleware when caller abandons mid-stream', async () => { + const registry = new MiddlewareRegistry() + const order: string[] = [] + + const outer: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } finally { + order.push('outer-finally') + } + } + + const inner: MiddlewareHandler = async function* (context, next) { + try { + return yield* next(context) + } finally { + order.push('inner-finally') + } + } + + registry.add(TestStage, outer) + registry.add(TestStage, inner) + + const terminal: MiddlewareNext = async function* () { + yield { data: 'event-1' } + yield { data: 'event-2' } + return { output: 'done' } + } + + const chain = registry.compose(TestStage, terminal) + const gen = chain({ value: 'test' }) + + await gen.next() + await gen.return({ output: 'abandoned' }) + + expect(order).toStrictEqual(['inner-finally', 'outer-finally']) + }) + }) +}) diff --git a/strands-ts/src/middleware/index.ts b/strands-ts/src/middleware/index.ts new file mode 100644 index 000000000..9290de85c --- /dev/null +++ b/strands-ts/src/middleware/index.ts @@ -0,0 +1,16 @@ +export type { Stage, MiddlewareNext, MiddlewareHandler, HandlerOf, NextOf } from './types.js' +export { + createStage, + InvokeModelStage, + ExecuteToolStage, + AgentStreamStage, +} from './stages.js' +export type { + InvokeModelContext, + InvokeModelResult, + ExecuteToolContext, + ExecuteToolResult, + AgentStreamContext, + AgentStreamResult, +} from './stages.js' +export { MiddlewareRegistry } from './registry.js' diff --git a/strands-ts/src/middleware/registry.ts b/strands-ts/src/middleware/registry.ts new file mode 100644 index 000000000..ca441a147 --- /dev/null +++ b/strands-ts/src/middleware/registry.ts @@ -0,0 +1,73 @@ +import 
/**
 * Registry that stores middleware handlers keyed by stage tokens
 * and composes them into execution chains.
 *
 * Stage tokens are used directly as `Map` keys, so handlers are looked up by
 * object identity: two distinct stage objects that happen to share a `name`
 * never see each other's handlers.
 */
export class MiddlewareRegistry {
  // One map holds the handlers of every stage, so the per-stage type parameters are
  // erased here and re-established by the generic signatures of `add` and `compose`.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private readonly _handlers: Map<Stage<any, any, any>, MiddlewareHandler<any, any, any>[]>

  constructor() {
    this._handlers = new Map()
  }

  /**
   * Register a middleware handler for a given stage.
   * Handlers are stored in registration order (first registered = outermost).
   *
   * @param stage - The stage token to register the handler for
   * @param handler - The middleware handler function
   */
  add<TContext, TEvent, TResult>(
    stage: Stage<TContext, TEvent, TResult>,
    handler: MiddlewareHandler<TContext, TEvent, TResult>
  ): void {
    const registered = this._handlers.get(stage)
    if (registered === undefined) {
      this._handlers.set(stage, [handler])
    } else {
      registered.push(handler)
    }
  }

  /**
   * Compose all registered handlers for a stage into a single middleware chain.
   * The chain executes handlers in registration order (first registered = outermost)
   * with the terminal function as the innermost layer.
   *
   * The handler list is read once, when `compose` is called: handlers added afterwards
   * are not part of the returned chain. When no handler is registered for the stage,
   * `terminal` itself is returned.
   *
   * @param stage - The stage token to compose handlers for
   * @param terminal - The innermost function that performs actual stage execution
   * @returns A single function representing the full middleware chain
   */
  compose<TContext, TEvent, TResult>(
    stage: Stage<TContext, TEvent, TResult>,
    terminal: MiddlewareNext<TContext, TEvent, TResult>
  ): MiddlewareNext<TContext, TEvent, TResult> {
    const handlers: MiddlewareHandler<TContext, TEvent, TResult>[] = this._handlers.get(stage) ?? []

    // Fold from the last-registered handler inwards-out: each step wraps the chain built
    // so far, so the first-registered handler ends up as the outermost layer.
    return handlers.reduceRight<MiddlewareNext<TContext, TEvent, TResult>>(
      (next, handler) => (context) => handler(context, next),
      terminal
    )
  }

  /**
   * Compose and invoke the middleware chain for a stage in one call.
   * Equivalent to `compose(stage, terminal)(context)` but reads more clearly at call sites.
   *
   * @param stage - The stage token to invoke
   * @param context - The context to pass into the chain
   * @param terminal - The innermost function that performs actual stage execution
   * @returns An async generator yielding events and returning the stage result
   */
  invoke<TContext, TEvent, TResult>(
    stage: Stage<TContext, TEvent, TResult>,
    context: TContext,
    terminal: MiddlewareNext<TContext, TEvent, TResult>
  ): ReturnType<MiddlewareNext<TContext, TEvent, TResult>> {
    return this.compose(stage, terminal)(context)
  }
}
*/ + readonly messages: Message[] + /** System prompt to guide the model's behavior. */ + readonly systemPrompt?: SystemPrompt + /** Tool specifications available to the model. */ + readonly toolSpecs: ToolSpec[] + /** Controls how the model selects tools. */ + readonly toolChoice?: ToolChoice + /** Runtime state for stateful model providers. */ + readonly modelState: StateStore + /** Per-invocation state shared across hooks and tools. */ + readonly invocationState: InvocationState +} + +/** + * Result from model-stage middleware. + * The return value of the async generator. + */ +export interface InvokeModelResult { + /** The aggregated result from the model stream. */ + readonly result: StreamAggregatedResult +} + +/** + * Context passed to tool-stage middleware. + * Contains everything needed to understand and potentially modify the tool call. + */ +export interface ExecuteToolContext extends Interruptible { + /** The agent instance (escape hatch for advanced use cases). */ + readonly agent: LocalAgent + /** The resolved tool implementation, or undefined if not found. */ + readonly tool: Tool | undefined + /** The tool use request (name, toolUseId, input). */ + readonly toolUse: ToolUseData + /** Per-invocation state shared across hooks and tools. */ + readonly invocationState: InvocationState +} + +/** + * Result from tool-stage middleware. + * The return value of the async generator. + */ +export interface ExecuteToolResult { + /** The tool result block from execution. */ + readonly result: ToolResultBlock +} + +/** + * Context passed to agent-stream-stage middleware. + * Wraps the entire agent output stream at the outermost interception point. + */ +export interface AgentStreamContext extends Interruptible { + /** The agent instance (escape hatch for advanced use cases). */ + readonly agent: LocalAgent + /** The invocation arguments passed to agent.stream(). */ + readonly args: InvokeArgs + /** Per-invocation options (cancel signal, structured output, etc.). 
*/ + readonly options?: InvokeOptions +} + +/** + * Result from agent-stream-stage middleware. + * The return value of the async generator. + */ +export interface AgentStreamResult { + /** The final agent result from the stream. */ + readonly result: AgentResult +} + +/** + * Built-in stage wrapping core model invocation. + * Middleware registered for this stage can rate-limit, cache, or transform model inputs. + */ +export const InvokeModelStage = createStage('invokeModel') + +/** + * Built-in stage wrapping individual tool execution. + * Middleware registered for this stage can add telemetry, validate inputs, or mock responses. + */ +export const ExecuteToolStage = createStage('executeTool') + +/** + * Built-in stage wrapping the entire agent output stream. + * Middleware registered for this stage can filter, transform, or inject events. + */ +export const AgentStreamStage = createStage('agentStream') diff --git a/strands-ts/src/middleware/types.ts b/strands-ts/src/middleware/types.ts new file mode 100644 index 000000000..1b54590d7 --- /dev/null +++ b/strands-ts/src/middleware/types.ts @@ -0,0 +1,57 @@ +/** + * A stage token that identifies a middleware interception point. + * Stages are created via `createStage()` and carry their Context/Event/Result types + * as generics, enabling full type inference at registration sites. + * + * Third parties can create custom stages — the SDK does not maintain a closed set. + */ +export interface Stage { + /** Human-readable name for debugging and logging. */ + readonly name: string + /** @internal Phantom field for type inference. Never accessed at runtime. */ + readonly _types?: { context: TContext; event: TEvent; result: TResult } +} + +/** + * The `next` function passed to middleware. + * Returns an async generator that yields events of type TEvent and returns the stage result. + * Middleware can choose not to call `next` to short-circuit execution. 
+ */ +export type MiddlewareNext = ( + context: TContext +) => AsyncGenerator + +/** + * A middleware handler function. + * Receives the context and a `next` function to call the next layer. + * Must be an async generator that yields TEvent and returns TResult. + * Middleware can yield its own events, forward events from next, or suppress them. + */ +export type MiddlewareHandler = ( + context: TContext, + next: MiddlewareNext +) => AsyncGenerator + +/** + * Extracts the `MiddlewareHandler` type from a stage token. + * Use this to type middleware methods or properties without repeating the generic parameters. + * + * @example + * ```typescript + * class MyPlugin implements Plugin { + * private _handler: HandlerOf = async function* (context, next) { ... } + * } + * ``` + */ +export type HandlerOf = S extends Stage ? MiddlewareHandler : never + +/** + * Extracts the `MiddlewareNext` type from a stage token. + * Use this to type the `next` parameter in standalone middleware methods. + * + * @example + * ```typescript + * private async *_handler(context: ..., next: NextOf) { ... } + * ``` + */ +export type NextOf = S extends Stage ? MiddlewareNext : never diff --git a/strands-ts/src/types/agent.ts b/strands-ts/src/types/agent.ts index a352e4c8a..914ae68b4 100644 --- a/strands-ts/src/types/agent.ts +++ b/strands-ts/src/types/agent.ts @@ -26,6 +26,7 @@ import type { StreamEvent, } from '../hooks/events.js' import type { HookCallback, HookableEventConstructor, HookCallbackOptions, HookCleanup } from '../hooks/types.js' +import type { Stage, MiddlewareHandler } from '../middleware/types.js' import type { ToolRegistry } from '../registry/tool-registry.js' import type { Model } from '../models/model.js' import type { z } from 'zod' @@ -325,6 +326,7 @@ export interface LocalAgent { options?: HookCallbackOptions ): HookCleanup + /** /** * Captures a point-in-time snapshot of the agent's current state. 
* @@ -341,6 +343,18 @@ export interface LocalAgent { * @param snapshot - The snapshot to restore from */ loadSnapshot(snapshot: Snapshot): void + + /** + * Register a middleware handler for a given stage. + * Middleware wraps stage execution and can intercept, transform, or short-circuit operations. + * + * @param stage - The stage token identifying the interception point + * @param handler - The middleware handler function (async generator) + */ + addMiddleware( + stage: Stage, + handler: MiddlewareHandler, + ): void } /** From c7aa7eac27c7cd58eac6a63379dc846532319876 Mon Sep 17 00:00:00 2001 From: jackypc Date: Fri, 29 May 2026 15:20:13 -0400 Subject: [PATCH 2/4] fix: clean up lint warnings on rebased PR Resolves CI lint failures introduced after rebasing onto upstream/main: - Remove unused MiddlewareNext import in agent.ts - Remove unused ToolContext import in middleware-interrupts.test.ts - Remove unused require-yield eslint-disable directives in middleware tests --- strands-ts/src/agent/agent.ts | 2 +- .../middleware/__tests__/agent-middleware.test.ts | 2 +- .../__tests__/middleware-interrupts.test.ts | 13 ++++++------- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/strands-ts/src/agent/agent.ts b/strands-ts/src/agent/agent.ts index 8967ea381..d55d4ebdb 100644 --- a/strands-ts/src/agent/agent.ts +++ b/strands-ts/src/agent/agent.ts @@ -46,7 +46,7 @@ import { NullConversationManager } from '../conversation-manager/null-conversati import { ConversationManager } from '../conversation-manager/conversation-manager.js' import { HookRegistryImplementation } from '../hooks/registry.js' import { MiddlewareRegistry, InvokeModelStage, ExecuteToolStage, AgentStreamStage } from '../middleware/index.js' -import type { Stage, MiddlewareHandler, MiddlewareNext } from '../middleware/index.js' +import type { Stage, MiddlewareHandler } from '../middleware/index.js' import type { InvokeModelContext, InvokeModelResult, diff --git 
a/strands-ts/src/middleware/__tests__/agent-middleware.test.ts b/strands-ts/src/middleware/__tests__/agent-middleware.test.ts index 7a37c6494..8e66faab0 100644 --- a/strands-ts/src/middleware/__tests__/agent-middleware.test.ts +++ b/strands-ts/src/middleware/__tests__/agent-middleware.test.ts @@ -1080,7 +1080,7 @@ describe('Middleware use cases', () => { initAgent(agent: LocalAgent): void { const cache = this._cache - // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { const key = `${context.toolUse.name}:${JSON.stringify(context.toolUse.input)}` const cached = cache.get(key) diff --git a/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts b/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts index 2dacd0cab..edb061097 100644 --- a/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts +++ b/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts @@ -6,7 +6,6 @@ import { createMockTool } from '../../__fixtures__/tool-helpers.js' import { ExecuteToolStage, AgentStreamStage } from '../stages.js' import { TextBlock, ToolResultBlock } from '../../types/messages.js' import { InterruptResponseContent } from '../../types/interrupt.js' -import type { ToolContext } from '../../tools/tool.js' describe('Middleware interrupts', () => { describe('ExecuteToolStage', () => { @@ -18,7 +17,7 @@ describe('Middleware interrupts', () => { const tool = createMockTool('dangerousTool', () => 'executed') const agent = new Agent({ model, tools: [tool], printer: false }) - // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { context.interrupt({ name: 'approve_tool', reason: 'Confirm execution?' 
}) return yield* next(context) @@ -45,7 +44,7 @@ describe('Middleware interrupts', () => { const agent = new Agent({ model, tools: [tool], printer: false }) - // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { const approval = context.interrupt({ name: 'approve_tool', reason: 'Confirm?' }) if (approval !== 'yes') { @@ -85,7 +84,7 @@ describe('Middleware interrupts', () => { const tool = createMockTool('myTool', () => 'ok') const agent = new Agent({ model, tools: [tool], printer: false }) - // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { context.interrupt({ name: 'check' }) return yield* next(context) @@ -110,7 +109,7 @@ describe('Middleware interrupts', () => { const agent = new Agent({ model, tools: [tool], printer: false }) - // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { // Preemptive response: returns immediately without halting const approval = context.interrupt({ name: 'check', response: 'pre-approved' }) @@ -132,7 +131,7 @@ describe('Middleware interrupts', () => { const tool = createMockTool('myTool', () => 'ok') const agent = new Agent({ model, tools: [tool], printer: false }) - // eslint-disable-next-line require-yield + agent.addMiddleware(ExecuteToolStage, async function* (context, next) { // Spread context to modify toolUse, interrupt should still work const modified = { ...context, toolUse: { ...context.toolUse, input: { x: 2 } } } @@ -174,7 +173,7 @@ describe('Middleware interrupts', () => { agent.addMiddleware(AgentStreamStage, async function* (context, next) { const approval = context.interrupt({ name: 'gate', reason: 'Proceed?' 
}) if (approval !== 'go') { - // eslint-disable-next-line require-yield + return { result: { stopReason: 'endTurn' } } as never } return yield* next(context) From ca095877696184133c7a615cc31eb45947f0102e Mon Sep 17 00:00:00 2001 From: jackypc Date: Fri, 29 May 2026 15:40:37 -0400 Subject: [PATCH 3/4] fix: resolve CI failures on rebased PR - Apply prettier formatting to middleware module + agent.ts + types/agent.ts - Stop deep-matching the agent instance in InvokeModelContext test; the new ToolCaller Proxy on Agent breaks toMatchObject's recursive comparison. Use referential equality for agent and shape-match the rest of the context. --- strands-ts/src/agent/agent.ts | 20 ++- .../__tests__/agent-middleware.test.ts | 134 ++++++++---------- .../__tests__/custom-stages.test.ts | 13 +- .../__tests__/middleware-interrupts.test.ts | 12 +- .../src/middleware/__tests__/registry.test.ts | 21 +-- strands-ts/src/middleware/index.ts | 7 +- strands-ts/src/middleware/registry.ts | 6 +- strands-ts/src/middleware/stages.ts | 9 +- strands-ts/src/types/agent.ts | 2 +- 9 files changed, 88 insertions(+), 136 deletions(-) diff --git a/strands-ts/src/agent/agent.ts b/strands-ts/src/agent/agent.ts index d55d4ebdb..194e7e95e 100644 --- a/strands-ts/src/agent/agent.ts +++ b/strands-ts/src/agent/agent.ts @@ -464,7 +464,7 @@ export class Agent implements LocalAgent, InvokableAgent { */ addMiddleware( stage: Stage, - handler: MiddlewareHandler, + handler: MiddlewareHandler ): void { this._middlewareRegistry.add(stage, handler) } @@ -779,7 +779,7 @@ export class Agent implements LocalAgent, InvokableAgent { async function* (ctx: AgentStreamContext): AsyncGenerator { const result = yield* self._streamWithResumeLoop(ctx.args, ctx.options) return { result } - }, + } ) return result } catch (error) { @@ -813,7 +813,7 @@ export class Agent implements LocalAgent, InvokableAgent { */ private async *_streamWithResumeLoop( args: InvokeArgs, - options?: InvokeOptions, + options?: InvokeOptions ): 
AsyncGenerator { let currentArgs: InvokeArgs = args @@ -1680,7 +1680,7 @@ export class Agent implements LocalAgent, InvokableAgent { iterResult = await gen.next() } return { result: iterResult.value } - }, + } ) return middlewareResult.result } @@ -2134,11 +2134,7 @@ export class Agent implements LocalAgent, InvokableAgent { } // Execute tool core logic through middleware chain - const middlewareResult = yield* this._executeToolWithMiddleware( - effectiveTool, - toolUse, - invocationState, - ) + const middlewareResult = yield* this._executeToolWithMiddleware(effectiveTool, toolUse, invocationState) const toolResult = middlewareResult.result const error = toolResult.error @@ -2166,7 +2162,7 @@ export class Agent implements LocalAgent, InvokableAgent { private async *_executeToolWithMiddleware( tool: Tool | undefined, toolUse: ToolUseData, - invocationState: InvocationState, + invocationState: InvocationState ): AsyncGenerator { const context: ExecuteToolContext = { agent: this, @@ -2204,14 +2200,14 @@ export class Agent implements LocalAgent, InvokableAgent { context, async function* (ctx: ExecuteToolContext): AsyncGenerator { return yield* self._executeToolCore(ctx.tool, ctx.toolUse, ctx.invocationState) - }, + } ) } private async *_executeToolCore( effectiveTool: Tool | undefined, toolUse: ToolUseData, - invocationState: InvocationState, + invocationState: InvocationState ): AsyncGenerator { // Start tool span within loop span context const toolSpan = this._tracer.startToolCallSpan({ diff --git a/strands-ts/src/middleware/__tests__/agent-middleware.test.ts b/strands-ts/src/middleware/__tests__/agent-middleware.test.ts index 8e66faab0..ee7ec5dcc 100644 --- a/strands-ts/src/middleware/__tests__/agent-middleware.test.ts +++ b/strands-ts/src/middleware/__tests__/agent-middleware.test.ts @@ -15,7 +15,13 @@ import type { MiddlewareHandler, HandlerOf } from '../types.js' import type { AgentStreamEvent, LocalAgent } from '../../types/agent.js' import type { Plugin } from 
'../../plugins/plugin.js' import { TextBlock, ToolResultBlock, Message } from '../../types/messages.js' -import { AfterToolCallEvent, BeforeModelCallEvent, AfterModelCallEvent, BeforeToolCallEvent, ContentBlockEvent } from '../../hooks/events.js' +import { + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, + BeforeToolCallEvent, + ContentBlockEvent, +} from '../../hooks/events.js' import type { ToolContext } from '../../tools/tool.js' type ExecuteToolMiddleware = MiddlewareHandler @@ -48,7 +54,7 @@ describe('Agent middleware integration — InvokeModelStage', () => { toolUseId: 'tool-1', status: 'success' as const, content: [new TextBlock('ok')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false, systemPrompt: 'Be helpful' }) @@ -61,8 +67,8 @@ describe('Agent middleware integration — InvokeModelStage', () => { await agent.invoke('Test prompt') + expect(receivedContext?.agent).toBe(agent) expect(receivedContext).toMatchObject({ - agent, systemPrompt: 'Be helpful', messages: expect.arrayContaining([expect.any(Message)]), toolSpecs: expect.arrayContaining([expect.objectContaining({ name: 'testTool' })]), @@ -102,7 +108,7 @@ describe('Agent middleware integration — InvokeModelStage', () => { stopReason: 'endTurn' as const, }, } - }, + } ) const result = await agent.invoke('Test prompt') @@ -127,7 +133,7 @@ describe('Agent middleware integration — InvokeModelStage', () => { stopReason: 'endTurn' as const, }, } - }, + } ) await agent.invoke('Test prompt') @@ -141,19 +147,13 @@ describe('Agent middleware integration — InvokeModelStage', () => { const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) const agent = new Agent({ model, printer: false }) - agent.addMiddleware( - InvokeModelStage, - async function* (context, next) { - const modifiedContext: InvokeModelContext = { - ...context, - messages: [ - ...context.messages, - new Message({ role: 'user', content: [new TextBlock('Injected message')] }), - 
], - } - return yield* next(modifiedContext) - }, - ) + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + const modifiedContext: InvokeModelContext = { + ...context, + messages: [...context.messages, new Message({ role: 'user', content: [new TextBlock('Injected message')] })], + } + return yield* next(modifiedContext) + }) const streamSpy = vi.spyOn(model, 'stream') @@ -168,16 +168,13 @@ describe('Agent middleware integration — InvokeModelStage', () => { const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) const agent = new Agent({ model, printer: false }) - agent.addMiddleware( - InvokeModelStage, - async function* (context, next) { - const modifiedContext: InvokeModelContext = { - ...context, - toolSpecs: [], - } - return yield* next(modifiedContext) - }, - ) + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + const modifiedContext: InvokeModelContext = { + ...context, + toolSpecs: [], + } + return yield* next(modifiedContext) + }) const streamSpy = vi.spyOn(model, 'stream') @@ -200,13 +197,10 @@ describe('Agent middleware integration — InvokeModelStage', () => { order.push('beforeModelCall') }) - agent.addMiddleware( - InvokeModelStage, - async function* (context, next) { - order.push('middleware') - return yield* next(context) - }, - ) + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + order.push('middleware') + return yield* next(context) + }) await agent.invoke('Test prompt') @@ -223,15 +217,12 @@ describe('Agent middleware integration — InvokeModelStage', () => { order.push('afterModelCall') }) - agent.addMiddleware( - InvokeModelStage, - async function* (context, next) { - order.push('middleware-start') - const result = yield* next(context) - order.push('middleware-end') - return result - }, - ) + agent.addMiddleware(InvokeModelStage, async function* (context, next) { + order.push('middleware-start') + const result = yield* next(context) + 
order.push('middleware-end') + return result + }) await agent.invoke('Test prompt') @@ -258,7 +249,7 @@ describe('Agent middleware integration — InvokeModelStage', () => { stopReason: 'endTurn' as const, }, } - }, + } ) await agent.invoke('Test prompt') @@ -291,7 +282,7 @@ describe('Agent middleware integration — InvokeModelStage', () => { toolUseId: 'tool-1', status: 'success' as const, content: [new TextBlock('Tool executed')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -425,12 +416,7 @@ describe('AgentStreamStage integration', () => { await collectGenerator(agent.stream('Test prompt')) - expect(callOrder).toStrictEqual([ - 'outer-before', - 'inner-before', - 'inner-after', - 'outer-after', - ]) + expect(callOrder).toStrictEqual(['outer-before', 'inner-before', 'inner-after', 'outer-after']) }) }) @@ -663,7 +649,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('executed')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -696,7 +682,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('ok')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -736,7 +722,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('ok')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -762,12 +748,7 @@ describe('ExecuteToolStage integration', () => { await agent.invoke('Use the tool') - expect(callOrder).toStrictEqual([ - 'outer-before', - 'inner-before', - 'inner-after', - 'outer-after', - ]) + expect(callOrder).toStrictEqual(['outer-before', 'inner-before', 'inner-after', 'outer-after']) }) }) @@ -783,7 +764,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('real result')], - }), + }) ) 
const tool = createMockTool('testTool', toolFn) @@ -823,7 +804,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('real')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -845,7 +826,7 @@ describe('ExecuteToolStage integration', () => { // The tool result message in conversation should contain the mocked result const toolResultMessage = agent.messages.find( - (m: Message) => m.role === 'user' && m.content.some((c) => c.type === 'toolResultBlock'), + (m: Message) => m.role === 'user' && m.content.some((c) => c.type === 'toolResultBlock') ) expect(toolResultMessage).toBeDefined() const toolResultBlock = toolResultMessage!.content.find((c: { type: string }) => c.type === 'toolResultBlock') @@ -854,7 +835,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('mocked data')], - }), + }) ) }) }) @@ -908,7 +889,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('ok')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -949,7 +930,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('executed')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -996,7 +977,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('real')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -1038,7 +1019,7 @@ describe('ExecuteToolStage integration', () => { toolUseId: 'tool-1', status: 'success', content: [new TextBlock('real')], - }), + }) ) const agent = new Agent({ model, tools: [tool], printer: false }) @@ -1080,7 +1061,6 @@ describe('Middleware use cases', () => { initAgent(agent: LocalAgent): void { const cache = this._cache - 
agent.addMiddleware(ExecuteToolStage, async function* (context, next) { const key = `${context.toolUse.name}:${JSON.stringify(context.toolUse.input)}` const cached = cache.get(key) @@ -1223,12 +1203,14 @@ describe('Middleware use cases', () => { ]) .addTurn({ type: 'textBlock', text: 'The answer is 42' }) - const tool = createMockTool('lookup', (ctx: ToolContext) => - new ToolResultBlock({ - toolUseId: ctx.toolUse.toolUseId, - status: 'success', - content: [new TextBlock('42')], - }), + const tool = createMockTool( + 'lookup', + (ctx: ToolContext) => + new ToolResultBlock({ + toolUseId: ctx.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('42')], + }) ) const agent = new Agent({ model, tools: [tool], plugins: [new StreamFinalTurnOnly()], printer: false }) diff --git a/strands-ts/src/middleware/__tests__/custom-stages.test.ts b/strands-ts/src/middleware/__tests__/custom-stages.test.ts index 7191dc67d..a95a31bd1 100644 --- a/strands-ts/src/middleware/__tests__/custom-stages.test.ts +++ b/strands-ts/src/middleware/__tests__/custom-stages.test.ts @@ -20,7 +20,7 @@ interface CustomResult { * Helper to collect all yielded events and the return value from an async generator. 
*/ async function collect( - gen: AsyncGenerator, + gen: AsyncGenerator ): Promise<{ events: TEvent[]; result: TResult }> { const events: TEvent[] = [] let iterResult = await gen.next() @@ -70,10 +70,7 @@ describe('Third-party custom stages', () => { const { events, result } = await collect(chain({ label: 'test', count: 42 })) expect(callOrder).toStrictEqual(['middleware', 'terminal', 'middleware-after']) - expect(events).toStrictEqual([ - { kind: 'pre-test' }, - { kind: 'terminal-test' }, - ]) + expect(events).toStrictEqual([{ kind: 'pre-test' }, { kind: 'terminal-test' }]) expect(result).toStrictEqual({ summary: 'done-42' }) }) @@ -140,11 +137,7 @@ describe('Third-party custom stages', () => { const chain = registry.compose(CustomStage, terminal) const { events, result } = await collect(chain({ label: 'item', count: 5 })) - expect(events).toStrictEqual([ - { kind: 'log-start' }, - { kind: 'processed-10' }, - { kind: 'log-end' }, - ]) + expect(events).toStrictEqual([{ kind: 'log-start' }, { kind: 'processed-10' }, { kind: 'log-end' }]) expect(result).toStrictEqual({ summary: 'result-item-10' }) }) }) diff --git a/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts b/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts index edb061097..7585116a0 100644 --- a/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts +++ b/strands-ts/src/middleware/__tests__/middleware-interrupts.test.ts @@ -17,7 +17,6 @@ describe('Middleware interrupts', () => { const tool = createMockTool('dangerousTool', () => 'executed') const agent = new Agent({ model, tools: [tool], printer: false }) - agent.addMiddleware(ExecuteToolStage, async function* (context, next) { context.interrupt({ name: 'approve_tool', reason: 'Confirm execution?' 
}) return yield* next(context) @@ -44,7 +43,6 @@ describe('Middleware interrupts', () => { const agent = new Agent({ model, tools: [tool], printer: false }) - agent.addMiddleware(ExecuteToolStage, async function* (context, next) { const approval = context.interrupt({ name: 'approve_tool', reason: 'Confirm?' }) if (approval !== 'yes') { @@ -84,7 +82,6 @@ describe('Middleware interrupts', () => { const tool = createMockTool('myTool', () => 'ok') const agent = new Agent({ model, tools: [tool], printer: false }) - agent.addMiddleware(ExecuteToolStage, async function* (context, next) { context.interrupt({ name: 'check' }) return yield* next(context) @@ -109,7 +106,6 @@ describe('Middleware interrupts', () => { const agent = new Agent({ model, tools: [tool], printer: false }) - agent.addMiddleware(ExecuteToolStage, async function* (context, next) { // Preemptive response: returns immediately without halting const approval = context.interrupt({ name: 'check', response: 'pre-approved' }) @@ -131,7 +127,6 @@ describe('Middleware interrupts', () => { const tool = createMockTool('myTool', () => 'ok') const agent = new Agent({ model, tools: [tool], printer: false }) - agent.addMiddleware(ExecuteToolStage, async function* (context, next) { // Spread context to modify toolUse, interrupt should still work const modified = { ...context, toolUse: { ...context.toolUse, input: { x: 2 } } } @@ -161,9 +156,7 @@ describe('Middleware interrupts', () => { const { result } = await collectGenerator(agent.stream('Test')) expect(result.stopReason).toBe('interrupt') - expect(result.interrupts).toEqual([ - expect.objectContaining({ name: 'confirm_stream', reason: 'Are you sure?' }), - ]) + expect(result.interrupts).toEqual([expect.objectContaining({ name: 'confirm_stream', reason: 'Are you sure?' 
})]) }) it('middleware gets response on resume and continues', async () => { @@ -173,7 +166,6 @@ describe('Middleware interrupts', () => { agent.addMiddleware(AgentStreamStage, async function* (context, next) { const approval = context.interrupt({ name: 'gate', reason: 'Proceed?' }) if (approval !== 'go') { - return { result: { stopReason: 'endTurn' } } as never } return yield* next(context) @@ -190,7 +182,7 @@ describe('Middleware interrupts', () => { interruptId: interruptResult.interrupts![0]!.id, response: 'go', }), - ]), + ]) ) expect(finalResult.stopReason).toBe('endTurn') diff --git a/strands-ts/src/middleware/__tests__/registry.test.ts b/strands-ts/src/middleware/__tests__/registry.test.ts index a88e51fee..b7df780cd 100644 --- a/strands-ts/src/middleware/__tests__/registry.test.ts +++ b/strands-ts/src/middleware/__tests__/registry.test.ts @@ -23,7 +23,7 @@ const TestStage = createStage('test') * Helper to collect all yielded events and the return value from an async generator. */ async function collect( - gen: AsyncGenerator, + gen: AsyncGenerator ): Promise<{ events: TEvent[]; result: TResult }> { const events: TEvent[] = [] let iterResult = await gen.next() @@ -114,13 +114,7 @@ describe('MiddlewareRegistry', () => { const chain = registry.compose(TestStage, terminal) await collect(chain({ value: 'test' })) - expect(callOrder).toStrictEqual([ - 'outer-before', - 'inner-before', - 'terminal', - 'inner-after', - 'outer-after', - ]) + expect(callOrder).toStrictEqual(['outer-before', 'inner-before', 'terminal', 'inner-after', 'outer-after']) }) }) @@ -179,11 +173,7 @@ describe('MiddlewareRegistry', () => { const chain = registry.compose(TestStage, terminal) const { events, result } = await collect(chain({ value: 'test' })) - expect(events).toStrictEqual([ - { data: 'event-1' }, - { data: 'event-2' }, - { data: 'event-3' }, - ]) + expect(events).toStrictEqual([{ data: 'event-1' }, { data: 'event-2' }, { data: 'event-3' }]) expect(result).toStrictEqual({ output: 
'terminal-result' }) }) }) @@ -219,10 +209,7 @@ describe('MiddlewareRegistry', () => { const chain = registry.compose(TestStage, terminal) const { events, result } = await collect(chain({ value: 'test' })) - expect(events).toStrictEqual([ - { data: 'keep-1' }, - { data: 'keep-2' }, - ]) + expect(events).toStrictEqual([{ data: 'keep-1' }, { data: 'keep-2' }]) expect(result).toStrictEqual({ output: 'done' }) }) }) diff --git a/strands-ts/src/middleware/index.ts b/strands-ts/src/middleware/index.ts index 9290de85c..29860ea5c 100644 --- a/strands-ts/src/middleware/index.ts +++ b/strands-ts/src/middleware/index.ts @@ -1,10 +1,5 @@ export type { Stage, MiddlewareNext, MiddlewareHandler, HandlerOf, NextOf } from './types.js' -export { - createStage, - InvokeModelStage, - ExecuteToolStage, - AgentStreamStage, -} from './stages.js' +export { createStage, InvokeModelStage, ExecuteToolStage, AgentStreamStage } from './stages.js' export type { InvokeModelContext, InvokeModelResult, diff --git a/strands-ts/src/middleware/registry.ts b/strands-ts/src/middleware/registry.ts index ca441a147..cd426417b 100644 --- a/strands-ts/src/middleware/registry.ts +++ b/strands-ts/src/middleware/registry.ts @@ -21,7 +21,7 @@ export class MiddlewareRegistry { */ add( stage: Stage, - handler: MiddlewareHandler, + handler: MiddlewareHandler ): void { const handlers = this._handlers.get(stage) ?? [] handlers.push(handler) @@ -39,7 +39,7 @@ export class MiddlewareRegistry { */ compose( stage: Stage, - terminal: MiddlewareNext, + terminal: MiddlewareNext ): MiddlewareNext { const handlers = (this._handlers.get(stage) ?? 
[]) as MiddlewareHandler[] @@ -65,7 +65,7 @@ export class MiddlewareRegistry { invoke( stage: Stage, context: TContext, - terminal: MiddlewareNext, + terminal: MiddlewareNext ): AsyncGenerator { const chain = this.compose(stage, terminal) return chain(context) diff --git a/strands-ts/src/middleware/stages.ts b/strands-ts/src/middleware/stages.ts index a6bfa996f..0f68dc911 100644 --- a/strands-ts/src/middleware/stages.ts +++ b/strands-ts/src/middleware/stages.ts @@ -1,5 +1,12 @@ import type { Stage } from './types.js' -import type { LocalAgent, AgentStreamEvent, InvocationState, InvokeArgs, InvokeOptions, AgentResult } from '../types/agent.js' +import type { + LocalAgent, + AgentStreamEvent, + InvocationState, + InvokeArgs, + InvokeOptions, + AgentResult, +} from '../types/agent.js' import type { Message, SystemPrompt, ToolResultBlock } from '../types/messages.js' import type { ToolSpec, ToolChoice } from '../tools/types.js' import type { StateStore } from '../state-store.js' diff --git a/strands-ts/src/types/agent.ts b/strands-ts/src/types/agent.ts index 914ae68b4..e55a38dd6 100644 --- a/strands-ts/src/types/agent.ts +++ b/strands-ts/src/types/agent.ts @@ -353,7 +353,7 @@ export interface LocalAgent { */ addMiddleware( stage: Stage, - handler: MiddlewareHandler, + handler: MiddlewareHandler ): void } From 5a1179188e52a8cf0ba23593257fef3d1ae0a241 Mon Sep 17 00:00:00 2001 From: jackypc Date: Fri, 29 May 2026 16:16:32 -0400 Subject: [PATCH 4/4] fix: clean up rebase artifacts in agent.ts and types/agent.ts Two issues surfaced by the review agent: - Duplicate _interruptState.resume() call: stream() already processes interrupt responses before middleware runs (so context.interrupt() can see them); _stream() was still doing the same work afterwards. Drop the resume() call inside _stream() but keep the extraction since the interrupted-state guard depends on its length. 
- Stray /** in types/agent.ts at line 329, left over from the rebase merge between upstream's takeSnapshot/loadSnapshot block and the PR's addMiddleware block. Remove the duplicate doc-comment opener. --- strands-ts/src/agent/agent.ts | 7 +++---- strands-ts/src/types/agent.ts | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/strands-ts/src/agent/agent.ts b/strands-ts/src/agent/agent.ts index 194e7e95e..99dd1198a 100644 --- a/strands-ts/src/agent/agent.ts +++ b/strands-ts/src/agent/agent.ts @@ -1035,11 +1035,10 @@ export class Agent implements LocalAgent, InvokableAgent { // agent loop cycles within this invocation. const invocationState: InvocationState = options?.invocationState ?? {} - // Handle interrupt responses if present in input + // Interrupt responses are already consumed in stream() before middleware + // runs (so middleware-level interrupt() can find them). Re-extract here + // to gate the "non-interrupt input while interrupted" check below. const interruptResponses = this._extractInterruptResponses(args) - if (interruptResponses.length > 0) { - this._interruptState.resume(interruptResponses) - } // Reject non-interrupt input while in interrupted state if (this._interruptState.activated && interruptResponses.length === 0) { diff --git a/strands-ts/src/types/agent.ts b/strands-ts/src/types/agent.ts index e55a38dd6..7a4adf6a7 100644 --- a/strands-ts/src/types/agent.ts +++ b/strands-ts/src/types/agent.ts @@ -326,7 +326,6 @@ export interface LocalAgent { options?: HookCallbackOptions ): HookCleanup - /** /** * Captures a point-in-time snapshot of the agent's current state. *