This repository was archived by the owner on Jun 3, 2026. It is now read-only.
feat(models): add WebLLM model provider for on-device browser inference #1036
Draft
jsamuel1 wants to merge 1 commit into strands-agents:main from jsamuel1:feat/webllm-model-provider
27 changes: 27 additions & 0 deletions
strands-ts/src/models/webllm/__tests__/browser.test.browser.ts
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| // ABOUTME: Browser-only smoke test for the WebLLM provider. | ||
| // ABOUTME: Verifies the public module imports cleanly and listWebLLMModels works | ||
| // ABOUTME: against the real @mlc-ai/web-llm prebuilt app config in a browser. | ||
|
|
||
| import { describe, it, expect } from 'vitest' | ||
| import { isBrowser } from '../../../__fixtures__/environment.js' | ||
| import { WebLLMModel, listWebLLMModels } from '../index.js' | ||
|
|
||
| describe('WebLLM browser smoke', () => { | ||
| it('runs in a browser environment', () => { | ||
| expect(isBrowser).toBe(true) | ||
| }) | ||
|
|
||
| it('exposes WebLLMModel as a constructor', () => { | ||
| expect(typeof WebLLMModel).toBe('function') | ||
| const model = new WebLLMModel({ modelId: 'Llama-3.1-8B-Instruct-q4f32_1-MLC' }) | ||
| expect(model.getConfig().modelId).toBe('Llama-3.1-8B-Instruct-q4f32_1-MLC') | ||
| }) | ||
|
|
||
| it('lists prebuilt models', async () => { | ||
| const models = await listWebLLMModels() | ||
| expect(models.length).toBeGreaterThan(0) | ||
| expect(models[0]).toHaveProperty('modelId') | ||
| expect(models[0]).toHaveProperty('modelUrl') | ||
| expect(models[0]).toHaveProperty('modelLib') | ||
| }) | ||
| }) |
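
The smoke test above only asserts that each listed model exposes `modelId`, `modelUrl`, and `modelLib`. A minimal sketch of how such a listing could be derived from a web-llm style app config follows. The snake_case record fields (`model_id`, `model`, `model_lib`, `vram_required_MB`) are the ones used in this PR's unit-test mock; the names `ModelRecordLike`, `WebLLMModelInfo`, `toModelInfo`, and `listModels` are hypothetical and are not taken from the PR's implementation.

```typescript
// Subset of a web-llm model record, limited to the fields that appear in
// the PR's unit-test mock. Assumption: the real record has more fields.
interface ModelRecordLike {
  model_id: string
  model: string
  model_lib: string
  vram_required_MB?: number
}

// Hypothetical camelCase summary; property names follow the test assertions.
interface WebLLMModelInfo {
  modelId: string
  modelUrl: string
  modelLib: string
  vramMB?: number
}

function toModelInfo(record: ModelRecordLike): WebLLMModelInfo {
  const info: WebLLMModelInfo = {
    modelId: record.model_id,
    modelUrl: record.model,
    modelLib: record.model_lib,
  }
  // Omit vramMB entirely when the record has no VRAM figure, rather than
  // storing an explicit undefined.
  if (record.vram_required_MB !== undefined) {
    info.vramMB = record.vram_required_MB
  }
  return info
}

// Maps every record in the config's model_list, preserving order.
function listModels(appConfig: { model_list: ModelRecordLike[] }): WebLLMModelInfo[] {
  return appConfig.model_list.map(toModelInfo)
}
```

Keeping the mapping in one pure function means it can be unit-tested in node without WebGPU, independent of how the app config is loaded.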
210 changes: 210 additions & 0 deletions
strands-ts/src/models/webllm/__tests__/cache.test.node.ts
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| // ABOUTME: Unit tests for WebLLM cache / download helpers. | ||
| // ABOUTME: The `@mlc-ai/web-llm` module is mocked so these run in node without WebGPU. | ||
|
|
||
| import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' | ||
| import type { MockedFunction } from 'vitest' | ||
| import { | ||
| deleteWebLLMModel, | ||
| downloadWebLLMModel, | ||
| isWebLLMModelCached, | ||
| listWebLLMModels, | ||
| resetWebLLMModuleCache, | ||
| WebLLMModelNotFoundError, | ||
| WebLLMUnavailableError, | ||
| } from '../cache.js' | ||
|
|
||
| // A minimal mock of the `@mlc-ai/web-llm` module surface we depend on. | ||
| const mockPrebuiltAppConfig = { | ||
| model_list: [ | ||
| { | ||
| model_id: 'test-model', | ||
| model: 'https://example.com/test-model', | ||
| model_lib: 'https://example.com/test-model.wasm', | ||
| vram_required_MB: 2048, | ||
| }, | ||
| { | ||
| model_id: 'other-model', | ||
| model: 'https://example.com/other', | ||
| model_lib: 'https://example.com/other.wasm', | ||
| }, | ||
| ], | ||
| } | ||
|
|
||
| const mockCreateEngine = vi.fn( | ||
| async ( | ||
| _modelId: string | string[], | ||
| _engineConfig?: { initProgressCallback?: (report: unknown) => void }, | ||
| _chatOpts?: unknown | ||
| ) => ({ | ||
| unload: vi.fn(async () => undefined), | ||
| chat: { completions: { create: vi.fn() } }, | ||
| }) | ||
| ) | ||
| const mockHasModelInCache = vi.fn(async () => false) | ||
| const mockDeleteModelAllInfoInCache = vi.fn(async () => undefined) | ||
|
|
||
| vi.mock('@mlc-ai/web-llm', () => ({ | ||
| CreateMLCEngine: mockCreateEngine, | ||
| prebuiltAppConfig: mockPrebuiltAppConfig, | ||
| hasModelInCache: mockHasModelInCache, | ||
| deleteModelAllInfoInCache: mockDeleteModelAllInfoInCache, | ||
| })) | ||
|
|
||
| // Fake out the browser environment check so these helpers run in node. | ||
| const originalWindow = globalThis.window | ||
| beforeEach(() => { | ||
| ;(globalThis as { window?: unknown }).window = {} as unknown | ||
| vi.clearAllMocks() | ||
| resetWebLLMModuleCache() | ||
| mockHasModelInCache.mockResolvedValue(false) | ||
| mockDeleteModelAllInfoInCache.mockResolvedValue(undefined) | ||
| mockCreateEngine.mockImplementation(async () => ({ | ||
| unload: vi.fn(async () => undefined), | ||
| chat: { completions: { create: vi.fn() } }, | ||
| })) | ||
| }) | ||
| afterEach(() => { | ||
| if (originalWindow === undefined) { | ||
| delete (globalThis as { window?: unknown }).window | ||
| } else { | ||
| ;(globalThis as { window?: unknown }).window = originalWindow | ||
| } | ||
| }) | ||
|
|
||
| describe('isWebLLMModelCached', () => { | ||
| it('returns true when the model is in cache', async () => { | ||
| mockHasModelInCache.mockResolvedValueOnce(true) | ||
| const result = await isWebLLMModelCached('test-model') | ||
| expect(result).toBe(true) | ||
| expect(mockHasModelInCache).toHaveBeenCalledWith('test-model', mockPrebuiltAppConfig) | ||
| }) | ||
|
|
||
| it('returns false when the model is not in cache', async () => { | ||
| const result = await isWebLLMModelCached('test-model') | ||
| expect(result).toBe(false) | ||
| }) | ||
|
|
||
| it('returns false when hasModelInCache throws (treats as not cached)', async () => { | ||
| mockHasModelInCache.mockRejectedValueOnce(new Error('storage error')) | ||
| const result = await isWebLLMModelCached('test-model') | ||
| expect(result).toBe(false) | ||
| }) | ||
|
|
||
| it('throws WebLLMModelNotFoundError for unknown modelId', async () => { | ||
| await expect(isWebLLMModelCached('nonexistent-model')).rejects.toBeInstanceOf(WebLLMModelNotFoundError) | ||
| }) | ||
|
|
||
| it('throws WebLLMUnavailableError when not in browser environment', async () => { | ||
| delete (globalThis as { window?: unknown }).window | ||
| await expect(isWebLLMModelCached('test-model')).rejects.toBeInstanceOf(WebLLMUnavailableError) | ||
| }) | ||
| }) | ||
|
|
||
| describe('deleteWebLLMModel', () => { | ||
| it('delegates to deleteModelAllInfoInCache', async () => { | ||
| await deleteWebLLMModel('test-model') | ||
| expect(mockDeleteModelAllInfoInCache).toHaveBeenCalledWith('test-model', mockPrebuiltAppConfig) | ||
| }) | ||
|
|
||
| it('throws for unknown model', async () => { | ||
| await expect(deleteWebLLMModel('nonexistent')).rejects.toBeInstanceOf(WebLLMModelNotFoundError) | ||
| }) | ||
| }) | ||
|
|
||
| describe('listWebLLMModels', () => { | ||
| it('returns all models from prebuiltAppConfig', async () => { | ||
| const models = await listWebLLMModels() | ||
| expect(models).toHaveLength(2) | ||
| expect(models[0]).toEqual({ | ||
| modelId: 'test-model', | ||
| modelUrl: 'https://example.com/test-model', | ||
| modelLib: 'https://example.com/test-model.wasm', | ||
| vramMB: 2048, | ||
| }) | ||
| expect(models[1]).toEqual({ | ||
| modelId: 'other-model', | ||
| modelUrl: 'https://example.com/other', | ||
| modelLib: 'https://example.com/other.wasm', | ||
| }) | ||
| }) | ||
|
|
||
| it('uses custom appConfig when provided', async () => { | ||
| const custom = { | ||
| model_list: [{ model_id: 'custom', model: 'x', model_lib: 'y' }], | ||
| } | ||
| const models = await listWebLLMModels(custom as never) | ||
| expect(models).toEqual([{ modelId: 'custom', modelUrl: 'x', modelLib: 'y' }]) | ||
| }) | ||
| }) | ||
|
|
||
| describe('downloadWebLLMModel', () => { | ||
| it('creates a temporary engine and unloads it after load', async () => { | ||
| const unload = vi.fn(async () => undefined) | ||
| mockCreateEngine.mockImplementationOnce(async () => ({ | ||
| unload, | ||
| chat: { completions: { create: vi.fn() } }, | ||
| })) | ||
| await downloadWebLLMModel({ modelId: 'test-model' }) | ||
| expect(mockCreateEngine).toHaveBeenCalledTimes(1) | ||
| expect(mockCreateEngine).toHaveBeenCalledWith('test-model', { appConfig: mockPrebuiltAppConfig }, undefined) | ||
| expect(unload).toHaveBeenCalledTimes(1) | ||
| }) | ||
|
|
||
| it('forwards onProgress as the engine initProgressCallback', async () => { | ||
| const onProgress = vi.fn() | ||
| const unload = vi.fn(async () => undefined) | ||
| mockCreateEngine.mockImplementationOnce(async (_modelId, engineConfig) => { | ||
| ;(engineConfig as { initProgressCallback?: (r: unknown) => void }).initProgressCallback?.({ | ||
| progress: 0.5, | ||
| text: 'loading', | ||
| timeElapsed: 1, | ||
| }) | ||
| return { unload, chat: { completions: { create: vi.fn() } } } | ||
| }) | ||
| await downloadWebLLMModel({ modelId: 'test-model', onProgress }) | ||
| expect(onProgress).toHaveBeenCalledWith({ progress: 0.5, text: 'loading', timeElapsed: 1 }) | ||
| }) | ||
|
|
||
| it('throws AbortError when signal is already aborted', async () => { | ||
| const controller = new AbortController() | ||
| controller.abort() | ||
| await expect(downloadWebLLMModel({ modelId: 'test-model', signal: controller.signal })).rejects.toMatchObject({ | ||
| name: 'AbortError', | ||
| }) | ||
| expect(mockCreateEngine).not.toHaveBeenCalled() | ||
| }) | ||
|
|
||
| it('throws AbortError when aborted mid-download', async () => { | ||
| const controller = new AbortController() | ||
| const unload = vi.fn(async () => undefined) | ||
| mockCreateEngine.mockImplementationOnce(async () => { | ||
| controller.abort() | ||
| return { unload, chat: { completions: { create: vi.fn() } } } | ||
| }) | ||
| await expect(downloadWebLLMModel({ modelId: 'test-model', signal: controller.signal })).rejects.toMatchObject({ | ||
| name: 'AbortError', | ||
| }) | ||
| expect(unload).toHaveBeenCalled() | ||
| }) | ||
|
|
||
| it('throws when model is not in app config', async () => { | ||
| await expect(downloadWebLLMModel({ modelId: 'nonexistent' })).rejects.toBeInstanceOf(WebLLMModelNotFoundError) | ||
| }) | ||
|
|
||
| it('surfaces engine errors via normalizeError', async () => { | ||
| mockCreateEngine.mockImplementationOnce(async () => { | ||
| throw new Error('webgpu unavailable') | ||
| }) | ||
| await expect(downloadWebLLMModel({ modelId: 'test-model' })).rejects.toThrow('webgpu unavailable') | ||
| }) | ||
| }) | ||
|
|
||
| describe('loadWebLLMModule error handling', () => { | ||
| it('throws WebLLMUnavailableError when environment is not a browser', async () => { | ||
| delete (globalThis as { window?: unknown }).window | ||
| await expect(downloadWebLLMModel({ modelId: 'test-model' })).rejects.toBeInstanceOf(WebLLMUnavailableError) | ||
| }) | ||
| }) | ||
|
|
||
| // Silence unused-helper lint noise | ||
| export type _Unused = MockedFunction<typeof mockCreateEngine> |
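
The `downloadWebLLMModel` tests above describe three behaviours: an already-aborted signal never creates an engine, an abort during loading still unloads the temporary engine, and a normal run unloads exactly once. The following is a simplified sketch of a helper with that shape, not the PR's implementation. The engine factory is injected so the example is self-contained; `EngineLike`, `CreateEngine`, `DownloadOptions`, `abortError`, and `downloadWithAbort` are hypothetical names, and the app-config lookup the tests also cover is left out.

```typescript
// Minimal engine surface needed here; a real engine exposes much more.
interface EngineLike {
  unload(): Promise<void>
}

// Hypothetical factory type standing in for the engine constructor.
type CreateEngine = (
  modelId: string,
  config: { initProgressCallback?: (report: unknown) => void }
) => Promise<EngineLike>

interface DownloadOptions {
  modelId: string
  onProgress?: (report: unknown) => void
  signal?: AbortSignal
}

function abortError(): Error {
  const err = new Error('The download was aborted')
  err.name = 'AbortError'
  return err
}

// Loads a model only to populate the cache, then releases the engine.
async function downloadWithAbort(createEngine: CreateEngine, opts: DownloadOptions): Promise<void> {
  // Fail fast: an already-aborted signal must not start a download.
  if (opts.signal?.aborted) throw abortError()

  // Forward the progress callback only when the caller supplied one.
  const config = opts.onProgress ? { initProgressCallback: opts.onProgress } : {}
  const engine = await createEngine(opts.modelId, config)
  try {
    // The signal may have fired while the engine was loading.
    if (opts.signal?.aborted) throw abortError()
  } finally {
    // Always release the temporary engine, on success or on abort.
    await engine.unload()
  }
}
```

Putting the unload in a `finally` block is what makes the mid-download case release the engine even though the function rejects.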
Issue: The peer dependency is specified as `"^0.2.79"`, which for a pre-1.0 package (semver treats 0.x specially) only allows `0.2.x` patches. This is correctly conservative. However, `@mlc-ai/web-llm` has a history of frequent breaking changes within minor versions (their API changed between 0.2.x releases).

Suggestion: Consider whether pinning to the exact version `0.2.79` would be safer. Note that `~0.2.79` resolves to the same `>=0.2.79 <0.3.0` range as `^0.2.79`, so it would not be any tighter. Alternatively, document in the module TSDoc which web-llm API surface you depend on. If the intent is to support a range, add a comment in package.json or the README noting the tested/verified version range.
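
To make the range semantics concrete, here is a small sketch of how caret and tilde ranges resolve for plain `x.y.z` versions, following the rules documented for node-semver. It is a simplified illustration rather than the `semver` package itself: prerelease tags and partial versions are not handled, and the helper names (`parse`, `compare`, `caretUpper`, `tildeUpper`, `satisfies`) are made up for this example.

```typescript
type Version = [major: number, minor: number, patch: number]

// Parses a plain "x.y.z" string; prerelease and build tags are not supported.
function parse(v: string): Version {
  const parts = v.split('.').map(Number)
  if (parts.length !== 3 || parts.some((n) => !Number.isInteger(n) || n < 0)) {
    throw new Error(`not a plain x.y.z version: ${v}`)
  }
  return [parts[0]!, parts[1]!, parts[2]!]
}

// Negative if a < b, zero if equal, positive if a > b.
function compare(a: Version, b: Version): number {
  if (a[0] !== b[0]) return a[0] - b[0]
  if (a[1] !== b[1]) return a[1] - b[1]
  return a[2] - b[2]
}

// Exclusive upper bound of a caret range: the leftmost non-zero component
// is the one that may not change.
// ^1.2.3 -> <2.0.0, ^0.2.3 -> <0.3.0, ^0.0.3 -> <0.0.4
function caretUpper([major, minor, patch]: Version): Version {
  if (major > 0) return [major + 1, 0, 0]
  if (minor > 0) return [0, minor + 1, 0]
  return [0, 0, patch + 1]
}

// Exclusive upper bound of a tilde range on a full version: ~x.y.z -> <x.(y+1).0
function tildeUpper([major, minor]: Version): Version {
  return [major, minor + 1, 0]
}

// Checks a version against a "^x.y.z" or "~x.y.z" range.
function satisfies(version: string, range: string): boolean {
  const op = range[0]
  if (op !== '^' && op !== '~') throw new Error(`unsupported range: ${range}`)
  const base = parse(range.slice(1))
  const upper = op === '^' ? caretUpper(base) : tildeUpper(base)
  const v = parse(version)
  return compare(v, base) >= 0 && compare(v, upper) < 0
}
```

Under these rules `^0.2.79` and `~0.2.79` accept exactly the same versions, so only an exact pin excludes later `0.2.x` releases.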