Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.

Commit 74f8728

Browse files
committed
Issue #22: Add Non-Streaming Mode Support to BedrockModelProvider
1 parent 9c845a3 commit 74f8728

2 files changed

Lines changed: 244 additions & 36 deletions

File tree

src/models/__tests__/bedrock.test.ts

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,49 @@ function setupMockSend(streamGenerator: () => AsyncGenerator<unknown>): void {
3131
}
3232

3333
// Mock the AWS SDK
34-
vi.mock('@aws-sdk/client-bedrock-runtime', () => {
35-
const mockSend = vi.fn(
36-
async (): Promise<{ stream: AsyncIterable<unknown> }> => ({
37-
stream: (async function* (): AsyncGenerator<unknown> {
38-
yield { messageStart: { role: 'assistant' } }
39-
yield { contentBlockStart: { contentBlockIndex: 0 } }
40-
yield { contentBlockDelta: { delta: { text: 'Hello' }, contentBlockIndex: 0 } }
41-
yield { contentBlockStop: { contentBlockIndex: 0 } }
42-
yield { messageStop: { stopReason: 'end_turn' } }
43-
yield {
44-
metadata: {
45-
usage: {
46-
inputTokens: 10,
47-
outputTokens: 5,
48-
totalTokens: 15,
49-
},
50-
metrics: {
51-
latencyMs: 100,
34+
vi.mock('@aws-sdk/client-bedrock-runtime', async () => {
35+
// Mock command classes that the code under test will instantiate
36+
const ConverseStreamCommand = vi.fn()
37+
const ConverseCommand = vi.fn()
38+
39+
const mockSend = vi.fn(async (command: unknown) => {
40+
// Check which constructor was used to create the command object
41+
if (command instanceof ConverseStreamCommand) {
42+
// Return a streaming response
43+
return {
44+
stream: (async function* (): AsyncGenerator<unknown> {
45+
yield { messageStart: { role: 'assistant' } }
46+
yield { contentBlockStart: { contentBlockIndex: 0 } }
47+
yield { contentBlockDelta: { delta: { text: 'Hello' }, contentBlockIndex: 0 } }
48+
yield { contentBlockStop: { contentBlockIndex: 0 } }
49+
yield { messageStop: { stopReason: 'end_turn' } }
50+
yield {
51+
metadata: {
52+
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
53+
metrics: { latencyMs: 100 },
5254
},
55+
}
56+
})(),
57+
}
58+
}
59+
60+
if (command instanceof ConverseCommand) {
61+
// Return a non-streaming (full) response for the non-streaming API
62+
return {
63+
output: {
64+
message: {
65+
role: 'assistant',
66+
content: [{ text: 'Hello' }],
5367
},
54-
}
55-
})(),
56-
})
57-
)
68+
},
69+
stopReason: 'end_turn',
70+
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 },
71+
metrics: { latencyMs: 100 },
72+
}
73+
}
74+
75+
throw new Error('Unhandled command type in mock')
76+
})
5877

5978
// Create a mock ValidationException class
6079
class MockValidationException extends Error {
@@ -68,7 +87,8 @@ vi.mock('@aws-sdk/client-bedrock-runtime', () => {
6887
BedrockRuntimeClient: vi.fn().mockImplementation(() => ({
6988
send: mockSend,
7089
})),
71-
ConverseStreamCommand: vi.fn(),
90+
ConverseStreamCommand,
91+
ConverseCommand,
7292
ValidationException: MockValidationException,
7393
}
7494
})
@@ -400,7 +420,49 @@ describe('BedrockModel', () => {
400420
})
401421

402422
describe('stream', () => {
403-
it('yields and validate events', async () => {
423+
it('yields and validates nonstreaming events', async () => {
424+
const provider = new BedrockModel({ stream: false })
425+
const messages: Message[] = [{ role: 'user', content: [{ type: 'textBlock', text: 'Hello' }] }]
426+
427+
const events = await collectEvents(provider.stream(messages))
428+
429+
expect(events).toStrictEqual([
430+
{
431+
role: 'assistant',
432+
type: 'modelMessageStartEvent',
433+
},
434+
{
435+
type: 'modelContentBlockStartEvent',
436+
},
437+
{
438+
delta: {
439+
text: 'Hello',
440+
type: 'textDelta',
441+
},
442+
type: 'modelContentBlockDeltaEvent',
443+
},
444+
{
445+
type: 'modelContentBlockStopEvent',
446+
},
447+
{
448+
stopReason: 'endTurn',
449+
type: 'modelMessageStopEvent',
450+
},
451+
{
452+
metrics: {
453+
latencyMs: 100,
454+
},
455+
type: 'modelMetadataEvent',
456+
usage: {
457+
inputTokens: 10,
458+
outputTokens: 5,
459+
totalTokens: 15,
460+
},
461+
},
462+
])
463+
})
464+
465+
it('yields and validates streaming events', async () => {
404466
const provider = new BedrockModel()
405467
const messages: Message[] = [{ role: 'user', content: [{ type: 'textBlock', text: 'Hello' }] }]
406468

src/models/bedrock.ts

Lines changed: 158 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import {
2525
type ConverseStreamMetadataEvent as BedrockConverseStreamMetadataEvent,
2626
ContentBlockDelta,
2727
type ToolConfiguration,
28+
ConverseCommand,
29+
type ConverseCommandOutput,
2830
} from '@aws-sdk/client-bedrock-runtime'
2931
import type { Model, BaseModelConfig, StreamOptions } from '../models/model'
3032
import type { Message, ContentBlock } from '../types/messages'
@@ -136,6 +138,16 @@ export interface BedrockModelConfig extends BaseModelConfig {
136138
* @see https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/client/bedrock-runtime/command/ConverseStreamCommand/
137139
*/
138140
additionalArgs?: JSONValue
141+
142+
/**
143+
* Whether or not to stream responses from the model.
144+
*
145+
* This will use the ConverseStream API instead of the Converse API.
146+
*
147+
* @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
148+
* @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html
149+
*/
150+
stream?: boolean
139151
}
140152

141153
/**
@@ -308,19 +320,27 @@ export class BedrockModel implements Model<BedrockModelConfig, BedrockRuntimeCli
308320
// Format the request for Bedrock
309321
const request = this._formatRequest(messages, options)
310322

311-
// Create and send the command
312-
const command = new ConverseStreamCommand(request)
313-
const response = await this._client.send(command)
314-
315-
// Stream the response
316-
if (response.stream) {
317-
for await (const chunk of response.stream) {
318-
// Map Bedrock events to SDK events
319-
const events = this._mapBedrockEventToSDKEvents(chunk)
320-
for (const event of events) {
321-
yield event
323+
if (this._config.stream !== false) {
324+
// Create and send the command
325+
const command = new ConverseStreamCommand(request)
326+
const response = await this._client.send(command)
327+
328+
// Stream the response
329+
if (response.stream) {
330+
for await (const chunk of response.stream) {
331+
// Map Bedrock events to SDK events
332+
const events = this._mapStreamedBedrockEventToSDKEvent(chunk)
333+
for (const event of events) {
334+
yield event
335+
}
322336
}
323337
}
338+
} else {
339+
const command = new ConverseCommand(request)
340+
const response = await this._client.send(command)
341+
for (const event of this._mapBedrockEventToSDKEvent(response)) {
342+
yield event
343+
}
324344
}
325345
} catch (error) {
326346
const err = error as Error
@@ -497,13 +517,139 @@ export class BedrockModel implements Model<BedrockModelConfig, BedrockRuntimeCli
497517
}
498518
}
499519

520+
private _mapBedrockEventToSDKEvent(event: ConverseCommandOutput): ModelStreamEvent[] {
521+
const events: ModelStreamEvent[] = []
522+
523+
// Message start
524+
const output = ensureDefined(event.output, 'event.output')
525+
const message = ensureDefined(output.message, 'output.message')
526+
const role = ensureDefined(message.role, 'message.role')
527+
events.push({
528+
type: 'modelMessageStartEvent',
529+
role,
530+
})
531+
532+
// Match on content blocks
533+
const content = ensureDefined(message.content, 'message.content')
534+
content.forEach((block, index) => {
535+
if (block.text) {
536+
events.push({
537+
type: 'modelContentBlockStartEvent',
538+
})
539+
540+
events.push({
541+
type: 'modelContentBlockDeltaEvent',
542+
delta: {
543+
type: 'textDelta',
544+
text: block.text,
545+
},
546+
})
547+
548+
events.push({
549+
type: 'modelContentBlockStopEvent',
550+
})
551+
} else if (block.toolUse) {
552+
events.push({
553+
type: 'modelContentBlockStartEvent',
554+
contentBlockIndex: index,
555+
start: {
556+
type: 'toolUseStart',
557+
name: ensureDefined(block.toolUse.name, 'toolUse.name'),
558+
toolUseId: ensureDefined(block.toolUse.toolUseId, 'toolUse.toolUseId'),
559+
},
560+
})
561+
562+
events.push({
563+
type: 'modelContentBlockDeltaEvent',
564+
contentBlockIndex: index,
565+
delta: {
566+
type: 'toolUseInputDelta',
567+
input: JSON.stringify(ensureDefined(block.toolUse.input, 'toolUse.input')),
568+
},
569+
})
570+
571+
events.push({
572+
type: 'modelContentBlockStopEvent',
573+
contentBlockIndex: index,
574+
})
575+
} else if (block.reasoningContent) {
576+
const reasoningText = ensureDefined(block.reasoningContent.reasoningText, 'reasoningContent.reasoningText')
577+
events.push({
578+
type: 'modelContentBlockDeltaEvent',
579+
contentBlockIndex: index,
580+
delta: {
581+
type: 'reasoningDelta',
582+
text: ensureDefined(reasoningText.text, 'reasoningText.text'),
583+
},
584+
})
585+
586+
if (reasoningText.signature) {
587+
events.push({
588+
type: 'modelContentBlockDeltaEvent',
589+
contentBlockIndex: index,
590+
delta: {
591+
type: 'reasoningDelta',
592+
signature: reasoningText.signature,
593+
},
594+
})
595+
}
596+
597+
events.push({
598+
type: 'modelContentBlockStopEvent',
599+
contentBlockIndex: index,
600+
})
601+
}
602+
})
603+
604+
const stopReasonRaw = ensureDefined(event.stopReason, 'event.stopReason') as string
605+
let mappedStopReason: string
606+
607+
if (stopReasonRaw in STOP_REASON_MAP) {
608+
mappedStopReason = STOP_REASON_MAP[stopReasonRaw as keyof typeof STOP_REASON_MAP]
609+
} else {
610+
console.warn(`Unknown stop reason: "${stopReasonRaw}". Converting to camelCase: "${snakeToCamel(stopReasonRaw)}"`)
611+
mappedStopReason = snakeToCamel(stopReasonRaw) // Assumes snakeToCamel utility exists
612+
}
613+
614+
// Adjust for tool_use, which is sometimes incorrectly reported as end_turn
615+
if (mappedStopReason === 'endTurn' && event.output?.message?.content?.some((block) => 'toolUse' in block)) {
616+
mappedStopReason = 'toolUse'
617+
console.warn(`Adjusting stop reason from 'end_turn' to 'tool_use' due to tool use in content blocks.`)
618+
}
619+
620+
events.push({
621+
type: 'modelMessageStopEvent',
622+
stopReason: mappedStopReason,
623+
})
624+
625+
const usage = ensureDefined(event.usage, 'output.usage')
626+
const metadataEvent: ModelStreamEvent = {
627+
type: 'modelMetadataEvent',
628+
usage: {
629+
inputTokens: ensureDefined(usage.inputTokens, 'usage.inputTokens'),
630+
outputTokens: ensureDefined(usage.outputTokens, 'usage.outputTokens'),
631+
totalTokens: ensureDefined(usage.totalTokens, 'usage.totalTokens'),
632+
},
633+
}
634+
635+
if (event.metrics) {
636+
metadataEvent.metrics = {
637+
latencyMs: ensureDefined(event.metrics.latencyMs, 'metrics.latencyMs'),
638+
}
639+
}
640+
641+
events.push(metadataEvent)
642+
643+
return events
644+
}
645+
500646
/**
501647
* Maps a Bedrock event to SDK streaming events.
502648
*
503649
* @param chunk - Bedrock event chunk
504650
* @returns Array of SDK streaming events
505651
*/
506-
private _mapBedrockEventToSDKEvents(chunk: ConverseStreamOutput): ModelStreamEvent[] {
652+
private _mapStreamedBedrockEventToSDKEvent(chunk: ConverseStreamOutput): ModelStreamEvent[] {
507653
const events: ModelStreamEvent[] = []
508654

509655
// Extract the event type key

0 commit comments

Comments
 (0)