diff --git a/AGENTS.md b/AGENTS.md index 67a970348..2d0bc1b4d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -78,10 +78,14 @@ sdk-typescript/ │ │ │ ├── anthropic.ts # Anthropic Claude │ │ │ ├── bedrock.ts # AWS Bedrock │ │ │ ├── vercel.ts # Vercel AI SDK +│ │ │ ├── webllm/ # WebLLM (on-device via WebGPU) provider +│ │ │ │ ├── __tests__/ # model.test.ts, cache.test.node.ts, browser.test.browser.ts +│ │ │ │ ├── cache.ts # Download / list / check / evict helpers + errors +│ │ │ │ ├── model.ts # WebLLMModel (runtime streaming + message formatting) +│ │ │ │ └── index.ts │ │ │ ├── defaults.ts # Centralized model defaults + warning messages │ │ │ ├── model.ts # Base model interface │ │ │ └── streaming.ts # Streaming event types -│ │ │ │ │ ├── multiagent/ # Multi-agent orchestration │ │ │ ├── __tests__/ │ │ │ ├── graph.ts # Graph orchestrator (DAG) @@ -275,7 +279,7 @@ sdk-typescript/ - **`strands-ts/src/conversation-manager/`**: Conversation history management strategies - **`strands-ts/src/hooks/`**: Hooks system for event-driven extensibility - **`strands-ts/src/logging/`**: Structured logging utilities -- **`strands-ts/src/models/`**: Model provider implementations (Bedrock, Anthropic, OpenAI, Google, Vercel) +- **`strands-ts/src/models/`**: Model provider implementations (Bedrock, Anthropic, OpenAI, Google, Vercel, WebLLM) - **`strands-ts/src/multiagent/`**: Multi-agent orchestration patterns (Graph for DAG execution, Swarm for handoff-based routing) - **`strands-ts/src/plugins/`**: Plugin system for extending agent functionality - **`strands-ts/src/registry/`**: Tool registry implementation diff --git a/strands-ts/package.json b/strands-ts/package.json index 86d738b92..a0af1bae5 100644 --- a/strands-ts/package.json +++ b/strands-ts/package.json @@ -36,6 +36,10 @@ "types": "./dist/src/models/vercel.d.ts", "default": "./dist/src/models/vercel.js" }, + "./models/webllm": { + "types": "./dist/src/models/webllm/index.d.ts", + "default": "./dist/src/models/webllm/index.js" + }, 
"./multiagent": { "types": "./dist/src/multiagent/index.d.ts", "default": "./dist/src/multiagent/index.js" @@ -130,6 +134,7 @@ "@aws-sdk/credential-providers": "^3.943.0", "@eslint/js": "^9.39.4", "@google/genai": "^1.40.0", + "@mlc-ai/web-llm": "^0.2.79", "@opentelemetry/api": "^1.9.0", "@opentelemetry/exporter-metrics-otlp-http": "^0.214.0", "@opentelemetry/exporter-trace-otlp-http": "^0.214.0", @@ -178,6 +183,7 @@ "@aws-sdk/client-s3": "^3.943.0", "@google/genai": "^1.40.0", "@modelcontextprotocol/sdk": "^1.25.2", + "@mlc-ai/web-llm": "^0.2.79", "@opentelemetry/api": "^1.9.0", "@opentelemetry/exporter-metrics-otlp-http": "^0.214.0", "@opentelemetry/exporter-trace-otlp-http": "^0.214.0", @@ -208,6 +214,9 @@ "@google/genai": { "optional": true }, + "@mlc-ai/web-llm": { + "optional": true + }, "openai": { "optional": true }, diff --git a/strands-ts/src/models/webllm/__tests__/browser.test.browser.ts b/strands-ts/src/models/webllm/__tests__/browser.test.browser.ts new file mode 100644 index 000000000..20785fc0d --- /dev/null +++ b/strands-ts/src/models/webllm/__tests__/browser.test.browser.ts @@ -0,0 +1,27 @@ +// ABOUTME: Browser-only smoke test for the WebLLM provider. +// ABOUTME: Verifies the public module imports cleanly and listWebLLMModels works +// ABOUTME: against the real @mlc-ai/web-llm prebuilt app config in a browser. 
+ +import { describe, it, expect } from 'vitest' +import { isBrowser } from '../../../__fixtures__/environment.js' +import { WebLLMModel, listWebLLMModels } from '../index.js' + +describe('WebLLM browser smoke', () => { + it('runs in a browser environment', () => { + expect(isBrowser).toBe(true) + }) + + it('exposes WebLLMModel as a constructor', () => { + expect(typeof WebLLMModel).toBe('function') + const model = new WebLLMModel({ modelId: 'Llama-3.1-8B-Instruct-q4f32_1-MLC' }) + expect(model.getConfig().modelId).toBe('Llama-3.1-8B-Instruct-q4f32_1-MLC') + }) + + it('lists prebuilt models', async () => { + const models = await listWebLLMModels() + expect(models.length).toBeGreaterThan(0) + expect(models[0]).toHaveProperty('modelId') + expect(models[0]).toHaveProperty('modelUrl') + expect(models[0]).toHaveProperty('modelLib') + }) +}) diff --git a/strands-ts/src/models/webllm/__tests__/cache.test.node.ts b/strands-ts/src/models/webllm/__tests__/cache.test.node.ts new file mode 100644 index 000000000..57b2fa1f2 --- /dev/null +++ b/strands-ts/src/models/webllm/__tests__/cache.test.node.ts @@ -0,0 +1,210 @@ +// ABOUTME: Unit tests for WebLLM cache / download helpers. +// ABOUTME: The `@mlc-ai/web-llm` module is mocked so these run in node without WebGPU. + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import type { MockedFunction } from 'vitest' +import { + deleteWebLLMModel, + downloadWebLLMModel, + isWebLLMModelCached, + listWebLLMModels, + resetWebLLMModuleCache, + WebLLMModelNotFoundError, + WebLLMUnavailableError, +} from '../cache.js' + +// A minimal mock of the `@mlc-ai/web-llm` module surface we depend on. 
// Minimal stand-in for `prebuiltAppConfig`: one record with a VRAM estimate and
// one without, so both shapes of `toModelInfo` output are exercised.
const mockPrebuiltAppConfig = {
  model_list: [
    {
      model_id: 'test-model',
      model: 'https://example.com/test-model',
      model_lib: 'https://example.com/test-model.wasm',
      vram_required_MB: 2048,
    },
    {
      model_id: 'other-model',
      model: 'https://example.com/other',
      model_lib: 'https://example.com/other.wasm',
    },
  ],
}

// Stand-in for `CreateMLCEngine`; resolves to an engine exposing only the
// members the helpers under test touch.
const mockCreateEngine = vi.fn(
  async (
    _modelId: string | string[],
    _engineConfig?: { initProgressCallback?: (report: unknown) => void },
    _chatOpts?: unknown
  ) => ({
    unload: vi.fn(async () => undefined),
    chat: { completions: { create: vi.fn() } },
  })
)
const mockHasModelInCache = vi.fn(async () => false)
const mockDeleteModelAllInfoInCache = vi.fn(async () => undefined)

vi.mock('@mlc-ai/web-llm', () => ({
  CreateMLCEngine: mockCreateEngine,
  prebuiltAppConfig: mockPrebuiltAppConfig,
  hasModelInCache: mockHasModelInCache,
  deleteModelAllInfoInCache: mockDeleteModelAllInfoInCache,
}))

// Fake out the browser environment check so these helpers run in node.
const originalWindow = globalThis.window
beforeEach(() => {
  ;(globalThis as { window?: unknown }).window = {} as unknown
  vi.clearAllMocks()
  resetWebLLMModuleCache()
  // clearAllMocks does not restore implementations, so re-establish defaults.
  mockHasModelInCache.mockResolvedValue(false)
  mockDeleteModelAllInfoInCache.mockResolvedValue(undefined)
  mockCreateEngine.mockImplementation(async () => ({
    unload: vi.fn(async () => undefined),
    chat: { completions: { create: vi.fn() } },
  }))
})
afterEach(() => {
  // Put `window` back exactly as it was before the suite touched it.
  if (originalWindow === undefined) {
    delete (globalThis as { window?: unknown }).window
  } else {
    ;(globalThis as { window?: unknown }).window = originalWindow
  }
})

describe('isWebLLMModelCached', () => {
  it('returns true when the model is in cache', async () => {
    mockHasModelInCache.mockResolvedValueOnce(true)
    const result = await isWebLLMModelCached('test-model')
    expect(result).toBe(true)
    expect(mockHasModelInCache).toHaveBeenCalledWith('test-model', mockPrebuiltAppConfig)
  })

  it('returns false when the model is not in cache', async () => {
    const result = await isWebLLMModelCached('test-model')
    expect(result).toBe(false)
  })

  it('returns false when hasModelInCache throws (treats as not cached)', async () => {
    mockHasModelInCache.mockRejectedValueOnce(new Error('storage error'))
    const result = await isWebLLMModelCached('test-model')
    expect(result).toBe(false)
  })

  it('throws WebLLMModelNotFoundError for unknown modelId', async () => {
    await expect(isWebLLMModelCached('nonexistent-model')).rejects.toBeInstanceOf(WebLLMModelNotFoundError)
  })

  it('throws WebLLMUnavailableError when not in browser environment', async () => {
    delete (globalThis as { window?: unknown }).window
    await expect(isWebLLMModelCached('test-model')).rejects.toBeInstanceOf(WebLLMUnavailableError)
  })
})

describe('deleteWebLLMModel', () => {
  it('delegates to deleteModelAllInfoInCache', async () => {
    await deleteWebLLMModel('test-model')
    expect(mockDeleteModelAllInfoInCache).toHaveBeenCalledWith('test-model', mockPrebuiltAppConfig)
  })

  it('throws for unknown model', async () => {
    await expect(deleteWebLLMModel('nonexistent')).rejects.toBeInstanceOf(WebLLMModelNotFoundError)
  })
})

describe('listWebLLMModels', () => {
  it('returns all models from prebuiltAppConfig', async () => {
    const models = await listWebLLMModels()
    expect(models).toHaveLength(2)
    expect(models[0]).toEqual({
      modelId: 'test-model',
      modelUrl: 'https://example.com/test-model',
      modelLib: 'https://example.com/test-model.wasm',
      vramMB: 2048,
    })
    expect(models[1]).toEqual({
      modelId: 'other-model',
      modelUrl: 'https://example.com/other',
      modelLib: 'https://example.com/other.wasm',
    })
  })

  it('uses custom appConfig when provided', async () => {
    const custom = {
      model_list: [{ model_id: 'custom', model: 'x', model_lib: 'y' }],
    }
    const models = await listWebLLMModels(custom as never)
    expect(models).toEqual([{ modelId: 'custom', modelUrl: 'x', modelLib: 'y' }])
  })
})

describe('downloadWebLLMModel', () => {
  it('creates a temporary engine and unloads it after load', async () => {
    const unload = vi.fn(async () => undefined)
    mockCreateEngine.mockImplementationOnce(async () => ({
      unload,
      chat: { completions: { create: vi.fn() } },
    }))
    await downloadWebLLMModel({ modelId: 'test-model' })
    expect(mockCreateEngine).toHaveBeenCalledTimes(1)
    expect(mockCreateEngine).toHaveBeenCalledWith('test-model', { appConfig: mockPrebuiltAppConfig }, undefined)
    expect(unload).toHaveBeenCalledTimes(1)
  })

  it('forwards onProgress as the engine initProgressCallback', async () => {
    const onProgress = vi.fn()
    const unload = vi.fn(async () => undefined)
    mockCreateEngine.mockImplementationOnce(async (_modelId, engineConfig) => {
      // Simulate WebLLM reporting progress while the engine is being created.
      ;(engineConfig as { initProgressCallback?: (r: unknown) => void }).initProgressCallback?.({
        progress: 0.5,
        text: 'loading',
        timeElapsed: 1,
      })
      return { unload, chat: { completions: { create: vi.fn() } } }
    })
    await downloadWebLLMModel({ modelId: 'test-model', onProgress })
    expect(onProgress).toHaveBeenCalledWith({ progress: 0.5, text: 'loading', timeElapsed: 1 })
  })

  it('throws AbortError when signal is already aborted', async () => {
    const controller = new AbortController()
    controller.abort()
    await expect(downloadWebLLMModel({ modelId: 'test-model', signal: controller.signal })).rejects.toMatchObject({
      name: 'AbortError',
    })
    expect(mockCreateEngine).not.toHaveBeenCalled()
  })

  it('throws AbortError when aborted mid-download', async () => {
    const controller = new AbortController()
    const unload = vi.fn(async () => undefined)
    mockCreateEngine.mockImplementationOnce(async () => {
      // Abort while engine creation is still in flight.
      controller.abort()
      return { unload, chat: { completions: { create: vi.fn() } } }
    })
    await expect(downloadWebLLMModel({ modelId: 'test-model', signal: controller.signal })).rejects.toMatchObject({
      name: 'AbortError',
    })
    expect(unload).toHaveBeenCalled()
  })

  it('throws when model is not in app config', async () => {
    await expect(downloadWebLLMModel({ modelId: 'nonexistent' })).rejects.toBeInstanceOf(WebLLMModelNotFoundError)
  })

  it('surfaces engine errors via normalizeError', async () => {
    mockCreateEngine.mockImplementationOnce(async () => {
      throw new Error('webgpu unavailable')
    })
    await expect(downloadWebLLMModel({ modelId: 'test-model' })).rejects.toThrow('webgpu unavailable')
  })
})

describe('loadWebLLMModule error handling', () => {
  it('throws WebLLMUnavailableError when environment is not a browser', async () => {
    delete (globalThis as { window?: unknown }).window
    await expect(downloadWebLLMModel({ modelId: 'test-model' })).rejects.toBeInstanceOf(WebLLMUnavailableError)
  })
})

// Silence unused-helper lint noise
export type _Unused = MockedFunction<typeof mockCreateEngine>

// diff --git a/strands-ts/src/models/webllm/__tests__/model.test.ts
// b/strands-ts/src/models/webllm/__tests__/model.test.ts new file mode 100644 index 000000000..609c9aa40 --- /dev/null +++ b/strands-ts/src/models/webllm/__tests__/model.test.ts @@ -0,0 +1,411 @@
// ABOUTME: Unit tests for the WebLLM model provider using a mocked MLC engine.
// ABOUTME: Runs in both node and browser environments (engine is fully mocked).

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import type { Mock } from 'vitest'
import { WebLLMModel } from '../model.js'
import { Message, TextBlock, ToolUseBlock, ToolResultBlock, JsonBlock } from '../../../types/messages.js'
import { collectIterator } from '../../../__fixtures__/model-test-helpers.js'
import { warnOnce } from '../../../logging/warn-once.js'
import { resetWebLLMModuleCache } from '../cache.js'

/** Mocked `chat.completions.create`: takes a request, resolves to a chunk stream. */
type CreateMock = Mock<(req: unknown) => Promise<AsyncIterable<unknown>>>

/**
 * Builds a mock MLCEngineInterface-compatible object whose
 * `chat.completions.create` yields the given chunks.
 *
 * @param chunks - Streaming chunks the mocked `create` call yields, in order
 * @returns The mock engine plus direct access to its `create` mock
 */
function makeMockEngine(chunks: unknown[]): {
  engine: {
    chat: { completions: { create: CreateMock } }
    unload: Mock
    reload: Mock
  }
  create: CreateMock
} {
  const create = vi.fn(async () => {
    async function* gen(): AsyncGenerator<unknown> {
      for (const c of chunks) yield c
    }
    return gen()
  }) as unknown as CreateMock
  const engine = {
    chat: { completions: { create } },
    unload: vi.fn(async () => undefined),
    reload: vi.fn(async () => undefined),
  }
  return { engine, create }
}

vi.mock('../../../logging/warn-once.js', () => ({
  warnOnce: vi.fn(),
}))

describe('WebLLMModel', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    resetWebLLMModuleCache()
  })

  afterEach(() => {
    vi.clearAllMocks()
  })

  describe('constructor', () => {
    it('creates an instance with default modelId warning', () => {
      new WebLLMModel()
      expect(warnOnce).toHaveBeenCalledWith(
        expect.objectContaining({ warn: expect.any(Function) }),
        expect.stringContaining('using default WebLLM modelId')
      )
    })

    it('does not warn when modelId is explicitly set', () => {
      new WebLLMModel({ modelId: 'Phi-3.5-mini-instruct-q4f16_1-MLC' })
      expect(warnOnce).not.toHaveBeenCalled()
    })

    it('stores provided config', () => {
      const model = new WebLLMModel({
        modelId: 'custom-model',
        temperature: 0.3,
        maxTokens: 512,
        topP: 0.9,
      })
      expect(model.getConfig()).toStrictEqual({
        modelId: 'custom-model',
        temperature: 0.3,
        maxTokens: 512,
        topP: 0.9,
      })
    })

    it('accepts an external engine without triggering module load', async () => {
      const { engine } = makeMockEngine([])
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      expect(model.getConfig().modelId).toBe('test')
    })
  })

  describe('updateConfig', () => {
    it('merges new config with existing config', () => {
      const model = new WebLLMModel({ modelId: 'm', temperature: 0.5 })
      model.updateConfig({ temperature: 0.8, maxTokens: 1024 })
      expect(model.getConfig()).toStrictEqual({
        modelId: 'm',
        temperature: 0.8,
        maxTokens: 1024,
      })
    })
  })

  describe('stream', () => {
    it('yields correct events for a simple text response', async () => {
      const { engine } = makeMockEngine([
        { choices: [{ delta: { role: 'assistant' } }] },
        { choices: [{ delta: { content: 'Hello' } }] },
        { choices: [{ delta: { content: ' world' } }] },
        { choices: [{ delta: {}, finish_reason: 'stop' }] },
        { choices: [], usage: { prompt_tokens: 4, completion_tokens: 2, total_tokens: 6 } },
      ])
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

      const events = await collectIterator(model.stream(messages))

      expect(events).toEqual([
        { type: 'modelMessageStartEvent', role: 'assistant' },
        { type: 'modelContentBlockStartEvent' },
        { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'Hello' } },
        { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: ' world' } },
        { type: 'modelContentBlockStopEvent' },
        { type: 'modelMetadataEvent', usage: { inputTokens: 4, outputTokens: 2, totalTokens: 6 } },
        { type: 'modelMessageStopEvent', stopReason: 'endTurn' },
      ])
    })

    it('emits metadata before message stop when usage arrives mid-stream', async () => {
      // Some WebLLM builds emit usage on the same chunk as finish_reason.
      const { engine } = makeMockEngine([
        { choices: [{ delta: { role: 'assistant' } }] },
        { choices: [{ delta: { content: 'Hi' } }] },
        {
          choices: [{ delta: {}, finish_reason: 'stop' }],
          usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 },
        },
      ])
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      const events = await collectIterator(
        model.stream([new Message({ role: 'user', content: [new TextBlock('Hi')] })])
      )
      const metaIndex = events.findIndex((e) => e.type === 'modelMetadataEvent')
      const stopIndex = events.findIndex((e) => e.type === 'modelMessageStopEvent')
      expect(metaIndex).toBeGreaterThan(-1)
      expect(stopIndex).toBeGreaterThan(metaIndex)
    })

    it('maps tool calls to content block start/delta/stop events', async () => {
      const { engine } = makeMockEngine([
        { choices: [{ delta: { role: 'assistant' } }] },
        {
          choices: [
            {
              delta: {
                tool_calls: [
                  {
                    index: 0,
                    id: 'tool_1',
                    type: 'function',
                    function: { name: 'add', arguments: '{"a":1' },
                  },
                ],
              },
            },
          ],
        },
        {
          choices: [
            {
              delta: {
                tool_calls: [{ index: 0, function: { arguments: ',"b":2}' } }],
              },
            },
          ],
        },
        { choices: [{ delta: {}, finish_reason: 'tool_calls' }] },
      ])
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      const events = await collectIterator(
        model.stream([new Message({ role: 'user', content: [new TextBlock('add')] })])
      )
      expect(events).toContainEqual({
        type: 'modelContentBlockStartEvent',
        start: { type: 'toolUseStart', name: 'add', toolUseId: 'tool_1' },
      })
      expect(events).toContainEqual({
        type: 'modelContentBlockDeltaEvent',
        delta: { type: 'toolUseInputDelta', input: '{"a":1' },
      })
      expect(events).toContainEqual({
        type: 'modelContentBlockDeltaEvent',
        delta: { type: 'toolUseInputDelta', input: ',"b":2}' },
      })
      expect(events).toContainEqual({ type: 'modelMessageStopEvent', stopReason: 'toolUse' })
    })

    it('maps `length` finish reason to maxTokens stop reason', async () => {
      const { engine } = makeMockEngine([
        { choices: [{ delta: { role: 'assistant' } }] },
        { choices: [{ delta: { content: 'partial' } }] },
        { choices: [{ delta: {}, finish_reason: 'length' }] },
      ])
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      const events = await collectIterator(
        model.stream([new Message({ role: 'user', content: [new TextBlock('Hi')] })])
      )
      expect(events).toContainEqual({ type: 'modelMessageStopEvent', stopReason: 'maxTokens' })
    })

    it('throws when called with no messages', async () => {
      const { engine } = makeMockEngine([])
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      await expect(collectIterator(model.stream([]))).rejects.toThrow('At least one message is required')
    })

    it('propagates errors from the engine', async () => {
      const create = vi.fn(async () => {
        throw new Error('engine boom')
      }) as unknown as CreateMock
      const engine = {
        chat: { completions: { create } },
        unload: vi.fn(async () => undefined),
      }
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      await expect(
        collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('Hi')] })]))
      ).rejects.toThrow('engine boom')
    })
  })

  describe('request formatting', () => {
    /**
     * Streams `messages` through a model backed by a mock engine and returns
     * the request object the engine's `create` was called with.
     */
    async function captureRequest(
      streamOptions: Parameters<WebLLMModel['stream']>[1],
      messages: Message[],
      config: ConstructorParameters<typeof WebLLMModel>[0] = {}
    ): Promise<unknown> {
      let captured: unknown
      const create = vi.fn(async (req: unknown) => {
        captured = req
        async function* gen(): AsyncGenerator<unknown> {
          yield { choices: [{ delta: { role: 'assistant' } }] }
          yield { choices: [{ delta: {}, finish_reason: 'stop' }] }
        }
        return gen()
      }) as unknown as CreateMock
      const engine = {
        chat: { completions: { create } },
        unload: vi.fn(async () => undefined),
      }
      const model = new WebLLMModel({ ...config, engine: engine as never, modelId: 'test' })
      await collectIterator(model.stream(messages, streamOptions))
      return captured
    }

    it('emits an OpenAI-compatible streaming request with config fields', async () => {
      const req = (await captureRequest(undefined, [new Message({ role: 'user', content: [new TextBlock('hello')] })], {
        temperature: 0.4,
        maxTokens: 128,
        topP: 0.9,
        frequencyPenalty: 0.2,
        presencePenalty: 0.1,
      })) as Record<string, unknown>

      expect(req).toMatchObject({
        stream: true,
        stream_options: { include_usage: true },
        temperature: 0.4,
        max_tokens: 128,
        top_p: 0.9,
        frequency_penalty: 0.2,
        presence_penalty: 0.1,
        messages: [{ role: 'user', content: 'hello' }],
      })
    })

    it('includes a system message when systemPrompt is a string', async () => {
      const req = (await captureRequest({ systemPrompt: 'Be brief.' }, [
        new Message({ role: 'user', content: [new TextBlock('Hi')] }),
      ])) as { messages: Array<{ role: string; content: string }> }
      expect(req.messages[0]).toEqual({ role: 'system', content: 'Be brief.' })
    })

    it('flattens system prompt content blocks to a single string', async () => {
      const req = (await captureRequest(
        {
          systemPrompt: [new TextBlock('You are '), new TextBlock('helpful')],
        },
        [new Message({ role: 'user', content: [new TextBlock('Hi')] })]
      )) as { messages: Array<{ role: string; content: string }> }
      expect(req.messages[0]).toEqual({ role: 'system', content: 'You are helpful' })
    })

    it('formats tool specs and tool_choice', async () => {
      const req = (await captureRequest(
        {
          toolSpecs: [
            {
              name: 'add',
              description: 'Add two numbers',
              inputSchema: {
                type: 'object' as const,
                properties: { a: { type: 'number' }, b: { type: 'number' } },
              },
            },
          ],
          toolChoice: { any: {} },
        },
        [new Message({ role: 'user', content: [new TextBlock('2+2')] })]
      )) as { tools: unknown[]; tool_choice: unknown }
      expect(req.tools).toEqual([
        {
          type: 'function',
          function: {
            name: 'add',
            description: 'Add two numbers',
            parameters: { type: 'object', properties: { a: { type: 'number' }, b: { type: 'number' } } },
          },
        },
      ])
      expect(req.tool_choice).toBe('required')
    })

    it('maps `tool` tool_choice to named function', async () => {
      const req = (await captureRequest(
        {
          toolSpecs: [
            { name: 'foo', description: 'does foo', inputSchema: { type: 'object' as const, properties: {} } },
          ],
          toolChoice: { tool: { name: 'foo' } },
        },
        [new Message({ role: 'user', content: [new TextBlock('go')] })]
      )) as { tool_choice: unknown }
      expect(req.tool_choice).toEqual({ type: 'function', function: { name: 'foo' } })
    })

    it('emits assistant tool_calls from history', async () => {
      const req = (await captureRequest(undefined, [
        new Message({ role: 'user', content: [new TextBlock('add 1+2')] }),
        new Message({
          role: 'assistant',
          content: [new ToolUseBlock({ name: 'add', toolUseId: 't1', input: { a: 1, b: 2 } })],
        }),
        new Message({
          role: 'user',
          content: [
            new ToolResultBlock({
              toolUseId: 't1',
              status: 'success',
              content: [new JsonBlock({ json: 3 })],
            }),
          ],
        }),
      ])) as { messages: Array<Record<string, unknown>> }

      expect(req.messages).toEqual([
        { role: 'user', content: 'add 1+2' },
        {
          role: 'assistant',
          content: '',
          tool_calls: [{ id: 't1', type: 'function', function: { name: 'add', arguments: '{"a":1,"b":2}' } }],
        },
        { role: 'tool', tool_call_id: 't1', content: '3' },
      ])
    })

    it('wraps errored tool results with [ERROR] prefix', async () => {
      const req = (await captureRequest(undefined, [
        new Message({
          role: 'user',
          content: [
            new ToolResultBlock({
              toolUseId: 't1',
              status: 'error',
              content: [new TextBlock('boom')],
            }),
          ],
        }),
      ])) as { messages: Array<Record<string, unknown>> }
      expect(req.messages[0]).toEqual({ role: 'tool', tool_call_id: 't1', content: '[ERROR] boom' })
    })

    it('throws when a tool spec lacks name or description', async () => {
      const create = vi.fn(async () => {
        async function* gen(): AsyncGenerator<unknown> {
          yield { choices: [{ delta: {}, finish_reason: 'stop' }] }
        }
        return gen()
      }) as unknown as CreateMock
      const engine = {
        chat: { completions: { create } },
        unload: vi.fn(async () => undefined),
      }
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      await expect(
        collectIterator(
          model.stream([new Message({ role: 'user', content: [new TextBlock('Hi')] })], {
            toolSpecs: [{ name: '', description: 'x' } as never],
          })
        )
      ).rejects.toThrow('Tool specification must have both name and description')
    })
  })

  describe('unload', () => {
    it('is a no-op when the engine was externally provided', async () => {
      const { engine } = makeMockEngine([])
      const model = new WebLLMModel({ engine: engine as never, modelId: 'test' })
      await model.unload()
      expect(engine.unload).not.toHaveBeenCalled()
    })
  })
})

// diff --git a/strands-ts/src/models/webllm/cache.ts b/strands-ts/src/models/webllm/cache.ts new file mode 100644 index 000000000..3a808003a --- /dev/null +++
// b/strands-ts/src/models/webllm/cache.ts @@ -0,0 +1,311 @@
/**
 * WebLLM model cache and download helpers.
 *
 * These helpers are independent of an Agent or {@link WebLLMModel} — use them
 * from a settings UI to inspect, pre-download, or evict models before wiring
 * them into an agent.
 */

import type { AppConfig, ChatOptions, InitProgressReport, MLCEngineInterface, ModelRecord } from '@mlc-ai/web-llm'
import { logger } from '../../logging/logger.js'
import { ModelError, normalizeError } from '../../errors.js'

/**
 * Thrown when WebLLM cannot run in the current environment (no WebGPU,
 * no browser globals, or the `@mlc-ai/web-llm` package is missing).
 */
export class WebLLMUnavailableError extends ModelError {
  constructor(message: string, options?: { cause?: unknown }) {
    super(message, options)
    this.name = 'WebLLMUnavailableError'
  }
}

/**
 * Thrown when a requested model ID is not present in the active app config.
 */
export class WebLLMModelNotFoundError extends ModelError {
  /**
   * The model ID that could not be found.
   */
  public readonly modelId: string

  constructor(modelId: string, message?: string) {
    super(message ?? `WebLLM model '${modelId}' is not in the active appConfig.model_list`)
    this.name = 'WebLLMModelNotFoundError'
    this.modelId = modelId
  }
}

/**
 * Options for pre-downloading a WebLLM model without constructing a
 * {@link WebLLMModel}.
 */
export interface DownloadWebLLMModelOptions {
  /**
   * WebLLM model identifier to download.
   */
  modelId: string

  /**
   * Custom `AppConfig`. Defaults to WebLLM's `prebuiltAppConfig`.
   */
  appConfig?: AppConfig

  /**
   * Baseline `ChatConfig` overrides.
   */
  chatOpts?: ChatOptions

  /**
   * Progress callback with percent-complete, text, and elapsed time.
   */
  onProgress?: (report: InitProgressReport) => void

  /**
   * Signal to cancel the download. When aborted, the temporary engine is
   * unloaded and an `AbortError` is thrown.
   */
  signal?: AbortSignal
}

/**
 * Summary of a model entry in the WebLLM app config.
 */
export interface WebLLMModelInfo {
  /**
   * WebLLM model identifier (e.g. `Llama-3.1-8B-Instruct-q4f32_1-MLC`).
   */
  modelId: string

  /**
   * Model weights URL (from `ModelRecord.model`).
   */
  modelUrl: string

  /**
   * Model library (wasm) URL.
   */
  modelLib: string

  /**
   * Estimated VRAM requirement in megabytes, if provided by the registry.
   */
  vramMB?: number

  /**
   * Optional model type (e.g. `LLM`, `embedding`) from the registry.
   */
  modelType?: string
}

/**
 * Shape of the subset of `@mlc-ai/web-llm` exports used by this provider.
 *
 * @internal
 */
interface WebLLMModule {
  CreateMLCEngine: (
    modelId: string | string[],
    engineConfig?: {
      appConfig?: AppConfig
      initProgressCallback?: (report: InitProgressReport) => void
      logLevel?: 'TRACE' | 'DEBUG' | 'INFO' | 'WARN' | 'ERROR' | 'SILENT'
    },
    chatOpts?: unknown
  ) => Promise<MLCEngineInterface>
  prebuiltAppConfig: AppConfig
  hasModelInCache: (modelId: string, appConfig?: AppConfig) => Promise<boolean>
  deleteModelAllInfoInCache: (modelId: string, appConfig?: AppConfig) => Promise<void>
}

// Module-level memo of the dynamically imported package; cleared by
// resetWebLLMModuleCache().
let cachedModule: WebLLMModule | undefined

/**
 * Dynamically imports `@mlc-ai/web-llm`. Caches the module after first import.
 *
 * @returns The loaded module, narrowed to the subset this provider uses
 *
 * @throws {@link WebLLMUnavailableError} when the package is not installed
 * or the current environment does not support it.
 *
 * @internal
 */
export async function loadWebLLMModule(): Promise<WebLLMModule> {
  if (cachedModule) return cachedModule
  try {
    const mod = (await import('@mlc-ai/web-llm')) as unknown as WebLLMModule
    cachedModule = mod
    return mod
  } catch (error) {
    // Any import failure is reported as "unavailable"; the original error is kept as the cause.
    throw new WebLLMUnavailableError(
      "Failed to load '@mlc-ai/web-llm'. Install it as a peer dependency and ensure this code runs in a browser with WebGPU support.",
      { cause: error }
    )
  }
}

/**
 * Resets the cached WebLLM module reference. Intended for tests only.
 *
 * @internal
 */
export function resetWebLLMModuleCache(): void {
  cachedModule = undefined
}

/**
 * Verifies the current environment can run WebLLM and throws with a clear
 * message otherwise. Does not verify WebGPU — that check is deferred to
 * engine creation to surface WebLLM's own diagnostics.
 *
 * @throws {@link WebLLMUnavailableError} when `window` is not defined.
 *
 * @internal
 */
export function assertBrowserEnvironment(): void {
  if (typeof window === 'undefined') {
    throw new WebLLMUnavailableError('WebLLM requires a browser environment with WebGPU. Run this code in a browser.')
  }
}

/**
 * Returns true if the model is already cached locally (no download needed).
 *
 * A failure of the underlying cache lookup is logged at debug level and
 * reported as "not cached" rather than thrown.
 *
 * @param modelId - WebLLM model ID to check
 * @param appConfig - Optional custom app config. Defaults to `prebuiltAppConfig`.
 * @returns `true` if cached, `false` otherwise
 *
 * @throws {@link WebLLMModelNotFoundError} when `modelId` is not in the app config.
 * @throws {@link WebLLMUnavailableError} when WebLLM cannot be loaded.
 */
export async function isWebLLMModelCached(modelId: string, appConfig?: AppConfig): Promise<boolean> {
  assertBrowserEnvironment()
  const mod = await loadWebLLMModule()
  const config = appConfig ?? mod.prebuiltAppConfig
  ensureModelInConfig(modelId, config)
  try {
    return await mod.hasModelInCache(modelId, config)
  } catch (error) {
    logger.debug(`model_id=<${modelId}> | hasModelInCache failed, treating as not cached | error=<${error}>`)
    return false
  }
}

/**
 * Deletes all cached data for a model (weights, tokenizer, wasm, chat config).
 *
 * @param modelId - WebLLM model ID to evict
 * @param appConfig - Optional custom app config. Defaults to `prebuiltAppConfig`.
 *
 * @throws {@link WebLLMModelNotFoundError} when `modelId` is not in the app config.
 * @throws {@link WebLLMUnavailableError} when WebLLM cannot be loaded.
 */
export async function deleteWebLLMModel(modelId: string, appConfig?: AppConfig): Promise<void> {
  assertBrowserEnvironment()
  const mod = await loadWebLLMModule()
  const config = appConfig ?? mod.prebuiltAppConfig
  ensureModelInConfig(modelId, config)
  await mod.deleteModelAllInfoInCache(modelId, config)
}

/**
 * Lists models available in the active app config.
 *
 * Unlike the other helpers this does not call `assertBrowserEnvironment()`;
 * it only needs the module to be importable.
 *
 * @param appConfig - Optional custom app config. Defaults to `prebuiltAppConfig`.
 * @returns Array of {@link WebLLMModelInfo} entries.
 *
 * @throws {@link WebLLMUnavailableError} when WebLLM cannot be loaded.
 */
export async function listWebLLMModels(appConfig?: AppConfig): Promise<WebLLMModelInfo[]> {
  const mod = await loadWebLLMModule()
  const config = appConfig ?? mod.prebuiltAppConfig
  return config.model_list.map(toModelInfo)
}

/**
 * Pre-downloads a WebLLM model by creating a temporary engine, waiting for
 * the model to load, and then unloading the engine. The model weights remain
 * in browser cache (IndexedDB / CacheStorage) for subsequent use.
 *
 * Supports cancellation via `AbortSignal` — on abort, the engine is unloaded
 * and an `AbortError` is thrown. The signal is checked before engine creation
 * starts and again once `CreateMLCEngine` settles; the abort listener can only
 * unload an engine that already exists, so an abort raised while the engine is
 * still being created takes effect when creation finishes. In-flight chunk
 * fetches may therefore continue before fully stopping.
 *
 * @param options - Download options
 *
 * @throws {@link WebLLMModelNotFoundError} when `modelId` is not in the app config.
 * @throws {@link WebLLMUnavailableError} when WebLLM cannot be loaded.
 * @throws An `Error` named `AbortError` when `options.signal` is aborted.
 *
 * @example
 * ```typescript
 * await downloadWebLLMModel({
 *   modelId: 'Llama-3.1-8B-Instruct-q4f32_1-MLC',
 *   onProgress: (r) => console.log(r.text, r.progress),
 * })
 * ```
 */
export async function downloadWebLLMModel(options: DownloadWebLLMModelOptions): Promise<void> {
  assertBrowserEnvironment()
  const mod = await loadWebLLMModule()
  const config = options.appConfig ?? mod.prebuiltAppConfig
  ensureModelInConfig(options.modelId, config)

  // Bail out before doing any work if the caller already cancelled.
  if (options.signal?.aborted) {
    throw abortError()
  }

  let engine: MLCEngineInterface | undefined
  const abortHandler = (): void => {
    // `engine` stays undefined until CreateMLCEngine resolves, so this is a
    // no-op for aborts that arrive while the engine is still being created.
    if (engine) {
      engine.unload().catch((error: unknown) => {
        logger.debug(`model_id=<${options.modelId}> | unload on abort failed | error=<${error}>`)
      })
    }
  }
  options.signal?.addEventListener('abort', abortHandler, { once: true })

  try {
    const engineConfig: NonNullable<Parameters<WebLLMModule['CreateMLCEngine']>[1]> = { appConfig: config }
    // Only set the callback when provided so the config object stays minimal.
    if (options.onProgress) engineConfig.initProgressCallback = options.onProgress
    engine = await mod.CreateMLCEngine(options.modelId, engineConfig, options.chatOpts)
    if (options.signal?.aborted) {
      throw abortError()
    }
  } catch (error) {
    // An abort takes precedence over whatever error the engine raised.
    if (options.signal?.aborted) {
      throw abortError()
    }
    throw normalizeError(error)
  } finally {
    options.signal?.removeEventListener('abort', abortHandler)
    // The engine is temporary: always release it, on success, failure, or abort.
    if (engine) {
      try {
        await engine.unload()
      } catch (error) {
        logger.debug(`model_id=<${options.modelId}> | unload after download failed | error=<${error}>`)
      }
    }
  }
}

/**
 * Throws {@link WebLLMModelNotFoundError} unless `appConfig.model_list`
 * contains a record whose `model_id` equals `modelId`.
 */
function ensureModelInConfig(modelId: string, appConfig: AppConfig): void {
  if (!appConfig.model_list.some((m: ModelRecord) => m.model_id === modelId)) {
    throw new WebLLMModelNotFoundError(modelId)
  }
}

/**
 * Converts a registry `ModelRecord` into the public {@link WebLLMModelInfo}
 * shape. Optional fields are only set when the record provides them.
 */
function toModelInfo(record: ModelRecord): WebLLMModelInfo {
  const info: WebLLMModelInfo = {
    modelId: record.model_id,
    modelUrl: record.model,
    modelLib: record.model_lib,
  }
  if (record.vram_required_MB !== undefined) info.vramMB = record.vram_required_MB
  // `model_type` is read through a cast and only copied when it is a string.
  if (typeof (record as unknown as { model_type?: string }).model_type === 'string') {
    info.modelType = (record as unknown as { model_type: string }).model_type
  }
  return info
}

/**
 * Builds the error thrown on cancellation: a plain `Error` whose `name` is
 * `AbortError`.
 */
function abortError(): Error {
  const err = new Error('WebLLM download aborted')
  err.name = 'AbortError'
  return err
}

// diff --git a/strands-ts/src/models/webllm/index.ts b/strands-ts/src/models/webllm/index.ts new
file mode 100644 index 000000000..3cd24f5eb --- /dev/null +++ b/strands-ts/src/models/webllm/index.ts @@ -0,0 +1,45 @@ +/** + * WebLLM model provider — on-device inference in the browser via WebGPU. + * + * Powered by `@mlc-ai/web-llm`, this provider lets agents run LLMs locally + * without sending requests to a remote API. Models are downloaded on first + * use and cached in browser storage for subsequent runs. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { WebLLMModel } from '@strands-agents/sdk/models/webllm' + * + * const model = new WebLLMModel({ + * modelId: 'Llama-3.1-8B-Instruct-q4f32_1-MLC', + * onProgress: (r) => console.log(r.text, r.progress), + * }) + * const agent = new Agent({ model }) + * const result = await agent.invoke('Hello!') + * ``` + * + * @example + * ```typescript + * // Pre-download a model independently of an Agent, e.g. from a settings UI + * import { downloadWebLLMModel, isWebLLMModelCached } from '@strands-agents/sdk/models/webllm' + * + * if (!(await isWebLLMModelCached('Phi-3.5-mini-instruct-q4f16_1-MLC'))) { + * await downloadWebLLMModel({ + * modelId: 'Phi-3.5-mini-instruct-q4f16_1-MLC', + * onProgress: (r) => updateProgressBar(r.progress, r.text), + * }) + * } + * ``` + */ + +export { WebLLMModel } from './model.js' +export type { WebLLMModelConfig, WebLLMModelOptions } from './model.js' +export { + deleteWebLLMModel, + downloadWebLLMModel, + isWebLLMModelCached, + listWebLLMModels, + WebLLMModelNotFoundError, + WebLLMUnavailableError, +} from './cache.js' +export type { DownloadWebLLMModelOptions, WebLLMModelInfo } from './cache.js' diff --git a/strands-ts/src/models/webllm/model.ts b/strands-ts/src/models/webllm/model.ts new file mode 100644 index 000000000..41f7ea9cb --- /dev/null +++ b/strands-ts/src/models/webllm/model.ts @@ -0,0 +1,581 @@ +/** + * WebLLM model provider. + * + * Runs LLMs locally in the browser via WebGPU using `@mlc-ai/web-llm`. 
Models + * are downloaded on first use and cached in browser storage (IndexedDB / + * CacheStorage) for subsequent runs. Use the cache helpers (`downloadWebLLMModel`, + * `isWebLLMModelCached`, `deleteWebLLMModel`, `listWebLLMModels`) to pre-download + * models, check the cache, or evict models independently of an agent invocation. + * + * @see https://webllm.mlc.ai/ + */ + +import type { AppConfig, ChatOptions, InitProgressReport, MLCEngineInterface } from '@mlc-ai/web-llm' +import { Model, resolveConfigMetadata } from '../model.js' +import type { BaseModelConfig, StreamOptions } from '../model.js' +import type { Message, StopReason, ToolResultBlock } from '../../types/messages.js' +import type { ModelStreamEvent, Usage } from '../streaming.js' +import { normalizeError } from '../../errors.js' +import { logger } from '../../logging/logger.js' +import { warnOnce } from '../../logging/warn-once.js' +import { assertBrowserEnvironment, loadWebLLMModule } from './cache.js' + +const DEFAULT_MODEL_ID = 'Llama-3.1-8B-Instruct-q4f32_1-MLC' + +/** + * Configuration for the WebLLM model provider. + */ +export interface WebLLMModelConfig extends BaseModelConfig { + /** + * WebLLM model identifier. + * Must match a `model_id` in the active `AppConfig.model_list` + * (defaults to {@link https://github.com/mlc-ai/web-llm/blob/main/src/config.ts | prebuiltAppConfig}). + */ + modelId?: string + + /** + * Controls randomness in generation (0.0 to 2.0). + */ + temperature?: number + + /** + * Maximum number of tokens to generate in the response. + */ + maxTokens?: number + + /** + * Controls diversity via nucleus sampling (0.0 to 1.0). + */ + topP?: number + + /** + * Reduces repetition of token sequences (-2.0 to 2.0). + */ + frequencyPenalty?: number + + /** + * Encourages the model to talk about new topics (-2.0 to 2.0). + */ + presencePenalty?: number + + /** + * Additional parameters forwarded to `engine.chat.completions.create`. 
+ * + * Provider-managed fields (`messages`, `stream`, `stream_options`, `tools`, + * `tool_choice`, `temperature`, `max_tokens`, `top_p`, `frequency_penalty`, + * `presence_penalty`) are overwritten by the provider even if set here. + */ + params?: Record +} + +/** + * Options for constructing a {@link WebLLMModel}. + */ +export interface WebLLMModelOptions extends WebLLMModelConfig { + /** + * Pre-constructed WebLLM engine. If provided, the model will not create its own + * engine and will not call `reload()` — the caller is responsible for loading + * the desired model. Use this to share a single engine across multiple model + * instances or to use a web/service worker engine variant. + */ + engine?: MLCEngineInterface + + /** + * Custom WebLLM `AppConfig`. Needed when registering a model URL that is not + * part of WebLLM's built-in prebuilt list, or when overriding the cache backend. + * @see https://github.com/mlc-ai/web-llm#custom-models + */ + appConfig?: AppConfig + + /** + * Baseline `ChatConfig` overrides passed to `engine.reload()` + * (e.g. `context_window_size`, `repetition_penalty`). + * Not used when `engine` is provided. + */ + chatOpts?: ChatOptions + + /** + * Called during the initial model load/download with progress updates. + * Only invoked when the model creates its own engine. + */ + onProgress?: (report: InitProgressReport) => void +} + +/** + * Minimal structural shape of a WebLLM / OpenAI chat completion chunk. + * We declare this locally rather than importing WebLLM's own type because + * WebLLM doesn't expose OpenAI protocol types at its package root with + * `verbatimModuleSyntax`-friendly resolution. 
+ */ +interface ChatCompletionChunkLike { + choices?: Array<{ + delta?: { + role?: string + content?: string | null + tool_calls?: Array<{ + index?: number + id?: string + type?: string + function?: { name?: string; arguments?: string } + }> + } + finish_reason?: string | null + }> + usage?: { + prompt_tokens?: number + completion_tokens?: number + total_tokens?: number + } | null +} + +/** + * Minimal structural shape for an OpenAI-style chat completion message param + * (system/user/assistant/tool). WebLLM accepts OpenAI's message format. + */ +type ChatMessageParam = + | { role: 'system'; content: string } + | { role: 'user'; content: string } + | { + role: 'assistant' + content: string + tool_calls?: Array<{ id: string; type: 'function'; function: { name: string; arguments: string } }> + } + | { role: 'tool'; tool_call_id: string; content: string } + +/** + * Minimal structural shape for a chat completion tool definition. + */ +interface ChatTool { + type: 'function' + function: { + name: string + description: string + parameters: Record + } +} + +/** + * Parameters passed to `engine.chat.completions.create` for streaming WebLLM + * chat completions. WebLLM's streaming overload is structurally identical to + * OpenAI's streaming API. + */ +interface WebLLMChatCreateParams { + messages: ChatMessageParam[] + stream: true + stream_options?: { include_usage: boolean } + temperature?: number + max_tokens?: number + top_p?: number + frequency_penalty?: number + presence_penalty?: number + tools?: ChatTool[] + tool_choice?: 'auto' | 'required' | { type: 'function'; function: { name: string } } +} + +/** + * WebLLM model provider — on-device inference via WebGPU. 
+ * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { WebLLMModel } from '@strands-agents/sdk/models/webllm' + * + * const model = new WebLLMModel({ + * modelId: 'Llama-3.1-8B-Instruct-q4f32_1-MLC', + * onProgress: (r) => console.log(r.text, r.progress), + * }) + * + * const agent = new Agent({ model }) + * const result = await agent.invoke('Hello!') + * ``` + * + * @example + * ```typescript + * // Share a pre-created engine (e.g. to use a web worker) + * import { CreateWebWorkerMLCEngine } from '@mlc-ai/web-llm' + * const engine = await CreateWebWorkerMLCEngine(worker, 'Phi-3.5-mini-instruct-q4f16_1-MLC') + * const model = new WebLLMModel({ engine, modelId: 'Phi-3.5-mini-instruct-q4f16_1-MLC' }) + * ``` + */ +export class WebLLMModel extends Model { + private _config: WebLLMModelConfig + private readonly _appConfig: AppConfig | undefined + private readonly _chatOpts: ChatOptions | undefined + private readonly _onProgress: ((report: InitProgressReport) => void) | undefined + private readonly _externalEngine: MLCEngineInterface | undefined + private _enginePromise: Promise | undefined + + constructor(options?: WebLLMModelOptions) { + super() + const { engine, appConfig, chatOpts, onProgress, ...modelConfig } = options ?? {} + + this._config = { ...modelConfig } + this._appConfig = appConfig + this._chatOpts = chatOpts + this._onProgress = onProgress + this._externalEngine = engine + + if (modelConfig.modelId === undefined) { + warnOnce( + logger, + `model_id=<${DEFAULT_MODEL_ID}> | using default WebLLM modelId, which is subject to change | set modelId explicitly to pin the value` + ) + } + } + + updateConfig(modelConfig: WebLLMModelConfig): void { + this._config = { ...this._config, ...modelConfig } + } + + getConfig(): WebLLMModelConfig { + return resolveConfigMetadata(this._config, this._config.modelId ?? DEFAULT_MODEL_ID) + } + + /** + * Unloads the underlying WebLLM engine, freeing GPU memory. 
No-op if an + * external engine was provided via the `engine` option — lifecycle of + * externally-supplied engines is the caller's responsibility. + */ + async unload(): Promise { + if (this._externalEngine) return + if (!this._enginePromise) return + try { + const engine = await this._enginePromise + await engine.unload() + } catch (error) { + logger.debug(`unload failed | error=<${error}>`) + } finally { + this._enginePromise = undefined + } + } + + async *stream(messages: Message[], options?: StreamOptions): AsyncIterable { + if (!messages || messages.length === 0) { + throw new Error('At least one message is required') + } + + const engine = await this._getEngine() + + try { + const request = this._formatRequest(messages, options) + // WebLLM's typings for chat.completions.create are narrower than the + // streaming variant we always use — cast to the streaming return type. + const stream = (await engine.chat.completions.create( + request as unknown as Parameters[0] + )) as unknown as AsyncIterable + + const state: StreamState = { messageStarted: false, textContentBlockStarted: false } + const activeToolCalls = new Map() + let bufferedUsage: ModelStreamEvent | undefined + let bufferedStop: ModelStreamEvent | undefined + + for await (const chunk of stream) { + const usage = extractUsage(chunk) + if (usage) { + bufferedUsage = { type: 'modelMetadataEvent', usage } + } + + for (const event of mapChunkToEvents(chunk, state, activeToolCalls)) { + if (event.type === 'modelMessageStopEvent') { + // Hold the stop event until the stream drains so any trailing + // usage-only chunk produces metadata before stop. 
+ bufferedStop = event + } else { + yield event + } + } + } + + if (bufferedUsage) yield bufferedUsage + if (bufferedStop) yield bufferedStop + } catch (error) { + throw normalizeError(error) + } + } + + private async _getEngine(): Promise { + if (this._externalEngine) return this._externalEngine + if (!this._enginePromise) { + this._enginePromise = this._createEngine().catch((error) => { + // Allow retry after a failed init + this._enginePromise = undefined + throw error + }) + } + return this._enginePromise + } + + private async _createEngine(): Promise { + assertBrowserEnvironment() + const mod = await loadWebLLMModule() + const modelId = this._config.modelId ?? DEFAULT_MODEL_ID + const engineConfig: Parameters[1] = {} + if (this._appConfig) engineConfig.appConfig = this._appConfig + if (this._onProgress) engineConfig.initProgressCallback = this._onProgress + return mod.CreateMLCEngine(modelId, engineConfig, this._chatOpts) + } + + private _formatRequest(messages: Message[], options?: StreamOptions): WebLLMChatCreateParams { + const request: WebLLMChatCreateParams = { + ...(this._config.params ?? 
{}), + messages: this._formatMessages(messages, options), + stream: true, + stream_options: { include_usage: true }, + } + + if (this._config.temperature !== undefined) request.temperature = this._config.temperature + if (this._config.maxTokens !== undefined) request.max_tokens = this._config.maxTokens + if (this._config.topP !== undefined) request.top_p = this._config.topP + if (this._config.frequencyPenalty !== undefined) request.frequency_penalty = this._config.frequencyPenalty + if (this._config.presencePenalty !== undefined) request.presence_penalty = this._config.presencePenalty + + if (options?.toolSpecs && options.toolSpecs.length > 0) { + request.tools = options.toolSpecs.map((spec) => { + if (!spec.name || !spec.description) { + throw new Error('Tool specification must have both name and description') + } + return { + type: 'function', + function: { + name: spec.name, + description: spec.description, + parameters: (spec.inputSchema ?? { type: 'object', properties: {} }) as Record, + }, + } + }) + + if (options.toolChoice) { + if ('auto' in options.toolChoice) { + request.tool_choice = 'auto' + } else if ('any' in options.toolChoice) { + request.tool_choice = 'required' + } else if ('tool' in options.toolChoice) { + request.tool_choice = { + type: 'function', + function: { name: options.toolChoice.tool.name }, + } + } + } + } + + return request + } + + private _formatMessages(messages: Message[], options?: StreamOptions): ChatMessageParam[] { + const result: ChatMessageParam[] = [] + + if (options?.systemPrompt !== undefined) { + const systemText = extractSystemText(options.systemPrompt) + if (systemText.length > 0) { + result.push({ role: 'system', content: systemText }) + } + } + + for (const message of messages) { + if (message.role === 'user') { + const toolResults = message.content.filter((b): b is ToolResultBlock => b.type === 'toolResultBlock') + const textParts: string[] = [] + for (const block of message.content) { + if (block.type === 
'textBlock') { + textParts.push(block.text) + } else if (block.type !== 'toolResultBlock') { + logger.warn( + `block_type=<${block.type}> | webllm user messages only support text and tool results, skipping` + ) + } + } + + const text = textParts.join('').trim() + if (text.length > 0) { + result.push({ role: 'user', content: text }) + } + + for (const toolResult of toolResults) { + result.push({ + role: 'tool', + tool_call_id: toolResult.toolUseId, + content: formatToolResultContent(toolResult), + }) + } + } else { + const assistantText: string[] = [] + const toolCalls: Array<{ id: string; type: 'function'; function: { name: string; arguments: string } }> = [] + + for (const block of message.content) { + if (block.type === 'textBlock') { + assistantText.push(block.text) + } else if (block.type === 'toolUseBlock') { + toolCalls.push({ + id: block.toolUseId, + type: 'function', + function: { + name: block.name, + arguments: JSON.stringify(block.input), + }, + }) + } else if (block.type === 'reasoningBlock') { + if (block.text) assistantText.push(block.text) + } else { + logger.warn( + `block_type=<${block.type}> | webllm assistant messages only support text and tool use, skipping` + ) + } + } + + const content = assistantText.join('').trim() + if (content.length === 0 && toolCalls.length === 0) continue + + if (toolCalls.length > 0) { + result.push({ + role: 'assistant', + content, + tool_calls: toolCalls, + }) + } else { + result.push({ role: 'assistant', content }) + } + } + } + + return result + } +} + +interface StreamState { + messageStarted: boolean + textContentBlockStarted: boolean +} + +function mapChunkToEvents( + chunk: ChatCompletionChunkLike, + state: StreamState, + activeToolCalls: Map +): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + const choice = chunk.choices?.[0] + if (!choice) return events + + const delta = choice.delta + + if (delta?.role && !state.messageStarted) { + state.messageStarted = true + events.push({ type: 
'modelMessageStartEvent', role: delta.role as 'user' | 'assistant' }) + } + + if (delta?.content && delta.content.length > 0) { + if (!state.textContentBlockStarted) { + state.textContentBlockStarted = true + events.push({ type: 'modelContentBlockStartEvent' }) + } + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: delta.content }, + }) + } + + if (delta?.tool_calls && delta.tool_calls.length > 0) { + for (const toolCall of delta.tool_calls) { + if (toolCall.index === undefined || typeof toolCall.index !== 'number') { + logger.warn(`tool_call=<${JSON.stringify(toolCall)}> | received tool call with invalid index`) + continue + } + + if (toolCall.id && toolCall.function?.name) { + events.push({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: toolCall.function.name, toolUseId: toolCall.id }, + }) + activeToolCalls.set(toolCall.index, true) + } + + if (toolCall.function?.arguments) { + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: toolCall.function.arguments }, + }) + } + } + } + + if (choice.finish_reason) { + if (state.textContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + state.textContentBlockStarted = false + } + + for (const [index] of activeToolCalls) { + events.push({ type: 'modelContentBlockStopEvent' }) + activeToolCalls.delete(index) + } + + events.push({ type: 'modelMessageStopEvent', stopReason: mapFinishReason(choice.finish_reason) }) + } + + return events +} + +function extractUsage(chunk: ChatCompletionChunkLike): Usage | undefined { + if (!chunk.usage) return undefined + return { + inputTokens: chunk.usage.prompt_tokens ?? 0, + outputTokens: chunk.usage.completion_tokens ?? 0, + totalTokens: chunk.usage.total_tokens ?? 
0, + } +} + +function mapFinishReason(finishReason: string): StopReason { + switch (finishReason) { + case 'stop': + return 'endTurn' + case 'tool_calls': + return 'toolUse' + case 'length': + return 'maxTokens' + case 'content_filter': + return 'contentFiltered' + default: + logger.warn(`finish_reason=<${finishReason}> | unknown webllm stop reason, passing through`) + return finishReason + } +} + +function extractSystemText(systemPrompt: StreamOptions['systemPrompt']): string { + if (typeof systemPrompt === 'string') return systemPrompt.trim() + if (!Array.isArray(systemPrompt)) return '' + const parts: string[] = [] + for (const block of systemPrompt) { + if (block.type === 'textBlock') { + parts.push(block.text) + } else if (block.type === 'cachePointBlock') { + logger.warn('cache points are not supported in webllm system prompts, ignoring cache points') + } else if (block.type === 'guardContentBlock') { + logger.warn('guard content is not supported in webllm system prompts, removing guard content block') + } + } + return parts.join('').trim() +} + +function formatToolResultContent(toolResult: ToolResultBlock): string { + const parts: string[] = [] + for (const block of toolResult.content) { + if (block.type === 'textBlock') { + parts.push(block.text) + } else if (block.type === 'jsonBlock') { + try { + parts.push(JSON.stringify(block.json)) + } catch (error) { + logger.warn(`tool_use_id=<${toolResult.toolUseId}> | failed to stringify json block | error=<${error}>`) + } + } else { + logger.warn( + `tool_use_id=<${toolResult.toolUseId}>, block_type=<${block.type}> | webllm tool results only support text and json, skipping` + ) + } + } + + const text = parts.join('').trim() + if (text.length === 0) { + return toolResult.status === 'error' ? '[ERROR] (empty)' : '(empty)' + } + return toolResult.status === 'error' ? 
`[ERROR] ${text}` : text +} diff --git a/strands-ts/test/packages/cjs-module/cjs.js b/strands-ts/test/packages/cjs-module/cjs.js index 543facaac..ea2aaaa99 100644 --- a/strands-ts/test/packages/cjs-module/cjs.js +++ b/strands-ts/test/packages/cjs-module/cjs.js @@ -15,6 +15,7 @@ const { BedrockModel: BedrockFromSubpath } = require('@strands-agents/sdk/models const { OpenAIModel } = require('@strands-agents/sdk/models/openai') const { AnthropicModel } = require('@strands-agents/sdk/models/anthropic') const { GoogleModel } = require('@strands-agents/sdk/models/google') +const { WebLLMModel, listWebLLMModels, WebLLMUnavailableError } = require('@strands-agents/sdk/models/webllm') const { z } = require('zod') @@ -85,6 +86,21 @@ async function main() { throw new Error('BedrockModel from subpath should match main export') } console.log('✓ Model subpath exports verified') + + // Verify WebLLM subpath export — constructor is available even in node without the peer dep. + if (typeof WebLLMModel !== 'function') { + throw new Error('WebLLMModel should be a constructor') + } + try { + await listWebLLMModels() + console.log('✓ WebLLM subpath exports loaded (with peer available)') + } catch (err) { + if (err instanceof WebLLMUnavailableError) { + console.log('✓ WebLLM subpath exports resolve and fail cleanly without peer dep') + } else { + throw err + } + } } main().catch((error) => { diff --git a/strands-ts/test/packages/esm-module/esm.js b/strands-ts/test/packages/esm-module/esm.js index 440a4ecfa..9b78227c2 100644 --- a/strands-ts/test/packages/esm-module/esm.js +++ b/strands-ts/test/packages/esm-module/esm.js @@ -15,6 +15,7 @@ import { BedrockModel as BedrockFromSubpath } from '@strands-agents/sdk/models/b import { OpenAIModel } from '@strands-agents/sdk/models/openai' import { AnthropicModel } from '@strands-agents/sdk/models/anthropic' import { GoogleModel } from '@strands-agents/sdk/models/google' +import { WebLLMModel, listWebLLMModels, WebLLMUnavailableError } from 
'@strands-agents/sdk/models/webllm' import { z } from 'zod' @@ -110,3 +111,22 @@ if (BedrockFromSubpath !== BedrockModel) { throw new Error('BedrockModel from subpath should match main export') } console.log('✓ Model subpath exports verified') + +// Verify WebLLM subpath export — constructor is available and helpers work without a browser +// as long as we don't touch the engine. Calling listWebLLMModels triggers module load, which +// should fail cleanly with WebLLMUnavailableError in node (no @mlc-ai/web-llm installed). +if (typeof WebLLMModel !== 'function') { + throw new Error('WebLLMModel should be a constructor') +} +try { + await listWebLLMModels() + // In the package-test environment there is no @mlc-ai/web-llm peer, so this should throw. + // If it ever resolves, we still want to succeed. + console.log('✓ WebLLM subpath exports loaded (with peer available)') +} catch (err) { + if (err instanceof WebLLMUnavailableError) { + console.log('✓ WebLLM subpath exports resolve and fail cleanly without peer dep') + } else { + throw err + } +}