VinF Hybrid Inference: migrate to LanguageModelMessage #9027

Open · wants to merge 1 commit into base: vaihi-exp-google-ai
144 changes: 118 additions & 26 deletions packages/vertexai/src/methods/chrome-adapter.test.ts
@@ -24,7 +24,7 @@ import {
Availability,
LanguageModel,
LanguageModelCreateOptions,
LanguageModelMessageContent
LanguageModelMessage
} from '../types/language-model';
import { match, stub } from 'sinon';
import { GenerateContentRequest, AIErrorCode } from '../types';
@@ -138,7 +138,7 @@ describe('ChromeAdapter', () => {
})
).to.be.false;
});
it('returns false if request content has non-user role', async () => {
it('returns false if request content has "function" role', async () => {
const adapter = new ChromeAdapter(
{
availability: async () => Availability.available
@@ -149,7 +149,7 @@
await adapter.isAvailable({
contents: [
{
role: 'model',
role: 'function',
parts: []
}
]
@@ -306,7 +306,7 @@ describe('ChromeAdapter', () => {
} as LanguageModel;
const languageModel = {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
} as LanguageModel;
const createStub = stub(languageModelProvider, 'create').resolves(
languageModel
@@ -331,8 +331,13 @@ describe('ChromeAdapter', () => {
// Asserts Vertex input type is mapped to Chrome type.
expect(promptStub).to.have.been.calledOnceWith([
{
type: 'text',
content: request.contents[0].parts[0].text
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
}
]);
// Asserts expected output.
@@ -352,7 +357,7 @@ describe('ChromeAdapter', () => {
} as LanguageModel;
const languageModel = {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
} as LanguageModel;
const createStub = stub(languageModelProvider, 'create').resolves(
languageModel
@@ -390,12 +395,17 @@ describe('ChromeAdapter', () => {
// Asserts Vertex input type is mapped to Chrome type.
expect(promptStub).to.have.been.calledOnceWith([
{
type: 'text',
content: request.contents[0].parts[0].text
},
{
type: 'image',
content: match.instanceOf(ImageBitmap)
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
},
{
type: 'image',
content: match.instanceOf(ImageBitmap)
}
]
}
]);
// Asserts expected output.
@@ -412,7 +422,7 @@ describe('ChromeAdapter', () => {
it('honors prompt options', async () => {
const languageModel = {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
} as LanguageModel;
const languageModelProvider = {
create: () => Promise.resolve(languageModel)
@@ -436,13 +446,48 @@ describe('ChromeAdapter', () => {
expect(promptStub).to.have.been.calledOnceWith(
[
{
type: 'text',
content: request.contents[0].parts[0].text
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
}
],
promptOptions
);
});
it('normalizes roles', async () => {
const languageModel = {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
prompt: (p: LanguageModelMessage[]) => Promise.resolve('unused')
} as LanguageModel;
const promptStub = stub(languageModel, 'prompt').resolves('unused');
const languageModelProvider = {
create: () => Promise.resolve(languageModel)
} as LanguageModel;
const adapter = new ChromeAdapter(
languageModelProvider,
'prefer_on_device'
);
const request = {
contents: [{ role: 'model', parts: [{ text: 'unused' }] }]
} as GenerateContentRequest;
await adapter.generateContent(request);
expect(promptStub).to.have.been.calledOnceWith([
{
// Asserts Vertex's "model" role normalized to Chrome's "assistant" role.
role: 'assistant',
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
}
]);
});
});
describe('countTokens', () => {
it('counts tokens is not yet available', async () => {
@@ -514,8 +559,13 @@ describe('ChromeAdapter', () => {
expect(createStub).to.have.been.calledOnceWith(createOptions);
expect(promptStub).to.have.been.calledOnceWith([
{
type: 'text',
content: request.contents[0].parts[0].text
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
}
]);
const actual = await toStringArray(response.body!);
@@ -570,12 +620,17 @@ describe('ChromeAdapter', () => {
expect(createStub).to.have.been.calledOnceWith(createOptions);
expect(promptStub).to.have.been.calledOnceWith([
{
type: 'text',
content: request.contents[0].parts[0].text
},
{
type: 'image',
content: match.instanceOf(ImageBitmap)
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
},
{
type: 'image',
content: match.instanceOf(ImageBitmap)
}
]
}
]);
const actual = await toStringArray(response.body!);
@@ -611,13 +666,50 @@ describe('ChromeAdapter', () => {
expect(promptStub).to.have.been.calledOnceWith(
[
{
type: 'text',
content: request.contents[0].parts[0].text
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
}
],
promptOptions
);
});
it('normalizes roles', async () => {
const languageModel = {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
promptStreaming: p => new ReadableStream()
} as LanguageModel;
const promptStub = stub(languageModel, 'promptStreaming').returns(
new ReadableStream()
);
const languageModelProvider = {
create: () => Promise.resolve(languageModel)
} as LanguageModel;
const adapter = new ChromeAdapter(
languageModelProvider,
'prefer_on_device'
);
const request = {
contents: [{ role: 'model', parts: [{ text: 'unused' }] }]
} as GenerateContentRequest;
await adapter.generateContentStream(request);
expect(promptStub).to.have.been.calledOnceWith([
{
// Asserts Vertex's "model" role normalized to Chrome's "assistant" role.
role: 'assistant',
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
}
]);
});
});
});

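For reviewers skimming the assertions above, here is a minimal sketch (using the shapes this PR introduces; the literal values are illustrative only) of how the payload passed to `prompt()` changes — from a flat content array to one `LanguageModelMessage` per Vertex `Content` entry:

```ts
import { LanguageModelMessage } from '../types/language-model';

// Before this PR: a flat array of content chunks, with no role information.
const oldPayload = [{ type: 'text', content: 'hi' }];

// After this PR: one message per Vertex Content entry; the role is carried
// through (and normalized, e.g. 'model' -> 'assistant'), and the converted
// parts become the message's content array.
const newPayload: LanguageModelMessage[] = [
  {
    role: 'user',
    content: [{ type: 'text', content: 'hi' }]
  }
];
```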
49 changes: 35 additions & 14 deletions packages/vertexai/src/methods/chrome-adapter.ts
@@ -23,12 +23,16 @@ import {
InferenceMode,
Part,
AIErrorCode,
OnDeviceParams
OnDeviceParams,
Content,
Role
} from '../types';
import {
Availability,
LanguageModel,
LanguageModelMessageContent
LanguageModelMessage,
LanguageModelMessageContent,
LanguageModelMessageRole
} from '../types/language-model';

/**
@@ -109,10 +113,8 @@ export class ChromeAdapter {
*/
async generateContent(request: GenerateContentRequest): Promise<Response> {
const session = await this.createSession();
// TODO: support multiple content objects when Chrome supports
// sequence<LanguageModelMessage>
const contents = await Promise.all(
request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent)
request.contents.map(ChromeAdapter.toLanguageModelMessage)
);
const text = await session.prompt(
contents,
@@ -133,10 +135,8 @@
request: GenerateContentRequest
): Promise<Response> {
const session = await this.createSession();
// TODO: support multiple content objects when Chrome supports
// sequence<LanguageModelMessage>
const contents = await Promise.all(
request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent)
request.contents.map(ChromeAdapter.toLanguageModelMessage)
);
const stream = await session.promptStreaming(
contents,
@@ -163,12 +163,8 @@
}

for (const content of request.contents) {
// Returns false if the request contains multiple roles, eg a chat history.
// TODO: remove this guard once LanguageModelMessage is supported.
if (content.role !== 'user') {
logger.debug(
`Non-user role "${content.role}" rejected for on-device inference.`
);
if (content.role === 'function') {
logger.debug(`"Function" role rejected for on-device inference.`);
return false;
}

@@ -227,6 +223,21 @@
});
}

/**
* Converts Vertex {@link Content} object to a Chrome {@link LanguageModelMessage} object.
*/
private static async toLanguageModelMessage(
content: Content
): Promise<LanguageModelMessage> {
const languageModelMessageContents = await Promise.all(
content.parts.map(ChromeAdapter.toLanguageModelMessageContent)
);
return {
role: ChromeAdapter.toLanguageModelMessageRole(content.role),
content: languageModelMessageContents
};
}

/**
* Converts a Vertex Part object to a Chrome LanguageModelMessageContent object.
*/
@@ -254,6 +265,16 @@
throw new Error('Not yet implemented');
}

/**
* Converts a Vertex {@link Role} string to a {@link LanguageModelMessageRole} string.
*/
private static toLanguageModelMessageRole(
role: Role
): LanguageModelMessageRole {
// Assumes the 'function' role has been filtered out by isOnDeviceRequest
return role === 'model' ? 'assistant' : 'user';
}

/**
* Abstracts Chrome session creation.
*
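As a condensed, standalone sketch of the conversion the adapter now performs (adapted from the diff above; availability checks, prompt options, and image-part conversion are omitted, and the local types are simplified stand-ins for the Vertex and Chrome declarations):

```ts
// Simplified local stand-ins for the Vertex and Chrome types used in the diff.
type Role = 'user' | 'model' | 'function';
type LanguageModelMessageRole = 'system' | 'user' | 'assistant';

interface Part { text?: string }
interface Content { role: Role; parts: Part[] }

// 'function' content is rejected earlier by the availability check,
// so only 'user' and 'model' reach this mapping.
function toLanguageModelMessageRole(role: Role): LanguageModelMessageRole {
  return role === 'model' ? 'assistant' : 'user';
}

function toLanguageModelMessage(content: Content) {
  return {
    role: toLanguageModelMessageRole(content.role),
    content: content.parts.map(part => ({
      type: 'text' as const,
      content: part.text ?? ''
    }))
  };
}

// Example: a Vertex "model" turn becomes a Chrome "assistant" message.
console.log(toLanguageModelMessage({ role: 'model', parts: [{ text: 'hi' }] }));
// -> { role: 'assistant', content: [{ type: 'text', content: 'hi' }] }
```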
14 changes: 9 additions & 5 deletions packages/vertexai/src/types/language-model.ts
@@ -14,7 +14,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* {@see https://github.com/webmachinelearning/prompt-api#full-api-surface-in-web-idl}
*/
export interface LanguageModel extends EventTarget {
create(options?: LanguageModelCreateOptions): Promise<LanguageModel>;
availability(options?: LanguageModelCreateCoreOptions): Promise<Availability>;
@@ -57,12 +59,14 @@ interface LanguageModelExpectedInput {
type: LanguageModelMessageType;
languages?: string[];
}
// TODO: revert to type from Prompt API explainer once it's supported.
export type LanguageModelPrompt = LanguageModelMessageContent[];
export type LanguageModelPrompt =
| LanguageModelMessage[]
| LanguageModelMessageShorthand[]
| string;
type LanguageModelInitialPrompts =
| LanguageModelMessage[]
| LanguageModelMessageShorthand[];
interface LanguageModelMessage {
export interface LanguageModelMessage {
role: LanguageModelMessageRole;
content: LanguageModelMessageContent[];
}
@@ -74,7 +78,7 @@ export interface LanguageModelMessageContent {
type: LanguageModelMessageType;
content: LanguageModelMessageContentValue;
}
type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
export type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
type LanguageModelMessageType = 'text' | 'image' | 'audio';
type LanguageModelMessageContentValue =
| ImageBitmapSource
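With `LanguageModelPrompt` widened and `LanguageModelMessage` exported, a multi-turn history can be typed directly against these declarations. A usage sketch (the import path assumes a caller elsewhere in this package):

```ts
import {
  LanguageModelMessage,
  LanguageModelPrompt
} from '../types/language-model';

// A chat-style history; Vertex 'model' turns arrive here as 'assistant'.
const history: LanguageModelMessage[] = [
  { role: 'user', content: [{ type: 'text', content: 'Hello' }] },
  { role: 'assistant', content: [{ type: 'text', content: 'Hi! How can I help?' }] },
  { role: 'user', content: [{ type: 'text', content: 'Summarize this page.' }] }
];

// Both forms satisfy the widened LanguageModelPrompt union.
const asMessages: LanguageModelPrompt = history;
const asString: LanguageModelPrompt = 'Summarize this page.';
```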