From c22ab66e69c5352aee76f48dbc949ec4dfc602f0 Mon Sep 17 00:00:00 2001 From: Erik Eldridge Date: Tue, 13 May 2025 10:22:55 -0700 Subject: [PATCH] Migrate to LanguageModelMessage --- .../src/methods/chrome-adapter.test.ts | 144 ++++++++++++++---- .../vertexai/src/methods/chrome-adapter.ts | 49 ++++-- packages/vertexai/src/types/language-model.ts | 14 +- 3 files changed, 162 insertions(+), 45 deletions(-) diff --git a/packages/vertexai/src/methods/chrome-adapter.test.ts b/packages/vertexai/src/methods/chrome-adapter.test.ts index fbe7ec1a5c5..20150845e0f 100644 --- a/packages/vertexai/src/methods/chrome-adapter.test.ts +++ b/packages/vertexai/src/methods/chrome-adapter.test.ts @@ -24,7 +24,7 @@ import { Availability, LanguageModel, LanguageModelCreateOptions, - LanguageModelMessageContent + LanguageModelMessage } from '../types/language-model'; import { match, stub } from 'sinon'; import { GenerateContentRequest, AIErrorCode } from '../types'; @@ -138,7 +138,7 @@ describe('ChromeAdapter', () => { }) ).to.be.false; }); - it('returns false if request content has non-user role', async () => { + it('returns false if request content has "function" role', async () => { const adapter = new ChromeAdapter( { availability: async () => Availability.available @@ -149,7 +149,7 @@ describe('ChromeAdapter', () => { await adapter.isAvailable({ contents: [ { - role: 'model', + role: 'function', parts: [] } ] @@ -306,7 +306,7 @@ describe('ChromeAdapter', () => { } as LanguageModel; const languageModel = { // eslint-disable-next-line @typescript-eslint/no-unused-vars - prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('') + prompt: (p: LanguageModelMessage[]) => Promise.resolve('') } as LanguageModel; const createStub = stub(languageModelProvider, 'create').resolves( languageModel @@ -331,8 +331,13 @@ describe('ChromeAdapter', () => { // Asserts Vertex input type is mapped to Chrome type. 
expect(promptStub).to.have.been.calledOnceWith([ { - type: 'text', - content: request.contents[0].parts[0].text + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] } ]); // Asserts expected output. @@ -352,7 +357,7 @@ describe('ChromeAdapter', () => { } as LanguageModel; const languageModel = { // eslint-disable-next-line @typescript-eslint/no-unused-vars - prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('') + prompt: (p: LanguageModelMessage[]) => Promise.resolve('') } as LanguageModel; const createStub = stub(languageModelProvider, 'create').resolves( languageModel @@ -390,12 +395,17 @@ describe('ChromeAdapter', () => { // Asserts Vertex input type is mapped to Chrome type. expect(promptStub).to.have.been.calledOnceWith([ { - type: 'text', - content: request.contents[0].parts[0].text - }, - { - type: 'image', - content: match.instanceOf(ImageBitmap) + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + }, + { + type: 'image', + content: match.instanceOf(ImageBitmap) + } + ] } ]); // Asserts expected output. 
@@ -412,7 +422,7 @@ describe('ChromeAdapter', () => { it('honors prompt options', async () => { const languageModel = { // eslint-disable-next-line @typescript-eslint/no-unused-vars - prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('') + prompt: (p: LanguageModelMessage[]) => Promise.resolve('') } as LanguageModel; const languageModelProvider = { create: () => Promise.resolve(languageModel) @@ -436,13 +446,48 @@ describe('ChromeAdapter', () => { expect(promptStub).to.have.been.calledOnceWith( [ { - type: 'text', - content: request.contents[0].parts[0].text + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] } ], promptOptions ); }); + it('normalizes roles', async () => { + const languageModel = { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + prompt: (p: LanguageModelMessage[]) => Promise.resolve('unused') + } as LanguageModel; + const promptStub = stub(languageModel, 'prompt').resolves('unused'); + const languageModelProvider = { + create: () => Promise.resolve(languageModel) + } as LanguageModel; + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device' + ); + const request = { + contents: [{ role: 'model', parts: [{ text: 'unused' }] }] + } as GenerateContentRequest; + await adapter.generateContent(request); + expect(promptStub).to.have.been.calledOnceWith([ + { + // Asserts Vertex's "model" role normalized to Chrome's "assistant" role. 
+ role: 'assistant', + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] + } + ]); + }); }); describe('countTokens', () => { it('counts tokens is not yet available', async () => { @@ -514,8 +559,13 @@ describe('ChromeAdapter', () => { expect(createStub).to.have.been.calledOnceWith(createOptions); expect(promptStub).to.have.been.calledOnceWith([ { - type: 'text', - content: request.contents[0].parts[0].text + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] } ]); const actual = await toStringArray(response.body!); @@ -570,12 +620,17 @@ describe('ChromeAdapter', () => { expect(createStub).to.have.been.calledOnceWith(createOptions); expect(promptStub).to.have.been.calledOnceWith([ { - type: 'text', - content: request.contents[0].parts[0].text - }, - { - type: 'image', - content: match.instanceOf(ImageBitmap) + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + }, + { + type: 'image', + content: match.instanceOf(ImageBitmap) + } + ] } ]); const actual = await toStringArray(response.body!); @@ -611,13 +666,50 @@ describe('ChromeAdapter', () => { expect(promptStub).to.have.been.calledOnceWith( [ { - type: 'text', - content: request.contents[0].parts[0].text + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] } ], promptOptions ); }); + it('normalizes roles', async () => { + const languageModel = { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + promptStreaming: p => new ReadableStream() + } as LanguageModel; + const promptStub = stub(languageModel, 'promptStreaming').returns( + new ReadableStream() + ); + const languageModelProvider = { + create: () => Promise.resolve(languageModel) + } as LanguageModel; + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device' + ); + const request = { + contents: [{ 
role: 'model', parts: [{ text: 'unused' }] }] + } as GenerateContentRequest; + await adapter.generateContentStream(request); + expect(promptStub).to.have.been.calledOnceWith([ + { + // Asserts Vertex's "model" role normalized to Chrome's "assistant" role. + role: 'assistant', + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] + } + ]); + }); }); }); diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts index aa3709048a2..5683ecc2fd5 100644 --- a/packages/vertexai/src/methods/chrome-adapter.ts +++ b/packages/vertexai/src/methods/chrome-adapter.ts @@ -23,12 +23,16 @@ import { InferenceMode, Part, AIErrorCode, - OnDeviceParams + OnDeviceParams, + Content, + Role } from '../types'; import { Availability, LanguageModel, - LanguageModelMessageContent + LanguageModelMessage, + LanguageModelMessageContent, + LanguageModelMessageRole } from '../types/language-model'; /** @@ -109,10 +113,8 @@ export class ChromeAdapter { */ async generateContent(request: GenerateContentRequest): Promise { const session = await this.createSession(); - // TODO: support multiple content objects when Chrome supports - // sequence const contents = await Promise.all( - request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent) + request.contents.map(ChromeAdapter.toLanguageModelMessage) ); const text = await session.prompt( contents, @@ -133,10 +135,8 @@ export class ChromeAdapter { request: GenerateContentRequest ): Promise { const session = await this.createSession(); - // TODO: support multiple content objects when Chrome supports - // sequence const contents = await Promise.all( - request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent) + request.contents.map(ChromeAdapter.toLanguageModelMessage) ); const stream = await session.promptStreaming( contents, @@ -163,12 +163,8 @@ export class ChromeAdapter { } for (const content of request.contents) { - // Returns false if 
the request contains multiple roles, eg a chat history. - // TODO: remove this guard once LanguageModelMessage is supported. - if (content.role !== 'user') { - logger.debug( - `Non-user role "${content.role}" rejected for on-device inference.` - ); + if (content.role === 'function') { + logger.debug(`"Function" role rejected for on-device inference.`); return false; } @@ -227,6 +223,21 @@ export class ChromeAdapter { }); } + /** + * Converts a Vertex {@link Content} object to a Chrome {@link LanguageModelMessage} object. + */ + private static async toLanguageModelMessage( + content: Content + ): Promise<LanguageModelMessage> { + const languageModelMessageContents = await Promise.all( + content.parts.map(ChromeAdapter.toLanguageModelMessageContent) + ); + return { + role: ChromeAdapter.toLanguageModelMessageRole(content.role), + content: languageModelMessageContents + }; + } + /** * Converts a Vertex Part object to a Chrome LanguageModelMessageContent object. */ @@ -254,6 +265,16 @@ export class ChromeAdapter { throw new Error('Not yet implemented'); } + /** + * Converts a Vertex {@link Role} string to a {@link LanguageModelMessageRole} string. + */ + private static toLanguageModelMessageRole( + role: Role + ): LanguageModelMessageRole { + // Assumes 'function' role has been filtered by isOnDeviceRequest + return role === 'model' ? 'assistant' : 'user'; + } + /** * Abstracts Chrome session creation. * diff --git a/packages/vertexai/src/types/language-model.ts b/packages/vertexai/src/types/language-model.ts index 22916e7ff96..d8751ff5eed 100644 --- a/packages/vertexai/src/types/language-model.ts +++ b/packages/vertexai/src/types/language-model.ts @@ -14,7 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - +/** + * @see {@link https://github.com/webmachinelearning/prompt-api#full-api-surface-in-web-idl} + */ export interface LanguageModel extends EventTarget { create(options?: LanguageModelCreateOptions): Promise<LanguageModel>; availability(options?: LanguageModelCreateCoreOptions): Promise<Availability>; @@ -57,12 +59,14 @@ interface LanguageModelExpectedInput { type: LanguageModelMessageType; languages?: string[]; } -// TODO: revert to type from Prompt API explainer once it's supported. -export type LanguageModelPrompt = LanguageModelMessageContent[]; +export type LanguageModelPrompt = + | LanguageModelMessage[] + | LanguageModelMessageShorthand[] + | string; type LanguageModelInitialPrompts = | LanguageModelMessage[] | LanguageModelMessageShorthand[]; -interface LanguageModelMessage { +export interface LanguageModelMessage { role: LanguageModelMessageRole; content: LanguageModelMessageContent[]; } @@ -74,7 +78,7 @@ export interface LanguageModelMessageContent { type: LanguageModelMessageType; content: LanguageModelMessageContentValue; } -type LanguageModelMessageRole = 'system' | 'user' | 'assistant'; +export type LanguageModelMessageRole = 'system' | 'user' | 'assistant'; type LanguageModelMessageType = 'text' | 'image' | 'audio'; type LanguageModelMessageContentValue = | ImageBitmapSource