Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions intelligence/ts/src/engines/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ interface ModelResponse {
model: string | undefined;
}

export async function checkSupport(model: string, engine: string): Promise<Result<string>> {
export async function getEngineModelName(model: string, engine: string): Promise<Result<string>> {
try {
const response = await fetch(`${REMOTE_URL}/v1/fetch-model-config`, {
method: 'POST',
Expand Down Expand Up @@ -54,7 +54,7 @@ export async function checkSupport(model: string, engine: string): Promise<Resul
ok: false,
failure: {
code: FailureCode.UnsupportedModelError,
description: `Model '${model}' is not supported on the webllm engine.`,
description: `Model '${model}' is not supported on the ${engine} engine.`,
},
};
}
Expand Down
9 changes: 3 additions & 6 deletions intelligence/ts/src/engines/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export interface Engine {
encrypt?: boolean
): Promise<ChatResponseResult>;
fetchModel(model: string, callback: (progress: Progress) => void): Promise<Result<void>>;
isSupported(model: string): Promise<Result<string>>;
isSupported(model: string): Promise<boolean>;
}

export abstract class BaseEngine implements Engine {
Expand Down Expand Up @@ -64,11 +64,8 @@ export abstract class BaseEngine implements Engine {
};
}

async isSupported(_model: string): Promise<Result<string>> {
async isSupported(_model: string): Promise<boolean> {
await Promise.resolve();
return {
ok: false,
failure: { code: FailureCode.NotImplementedError, description: 'Method not implemented.' },
};
return false;
}
}
7 changes: 2 additions & 5 deletions intelligence/ts/src/engines/remoteEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,9 @@ export class RemoteEngine extends BaseEngine {
};
}

async isSupported(model: string): Promise<Result<string>> {
async isSupported(_model: string): Promise<boolean> {
await Promise.resolve();
return {
ok: true,
value: model,
};
return true;
}

private createRequestData(
Expand Down
33 changes: 27 additions & 6 deletions intelligence/ts/src/engines/transformersEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import type { ProgressInfo, TextGenerationConfig } from '@huggingface/transforme
import { FailureCode, Message, Result, Progress, ChatResponseResult } from '../typing';

import { BaseEngine } from './engine';
import { checkSupport } from './common';
import { getEngineModelName } from './common';

const stoppingCriteria = new InterruptableStoppingCriteria();
const choice = 0;
Expand All @@ -42,10 +42,20 @@ export class TransformersEngine extends BaseEngine {
stream?: boolean,
onStreamEvent?: (event: { chunk: string }) => void
): Promise<ChatResponseResult> {
const modelNameRes = await getEngineModelName(model, 'onnx');
if (!modelNameRes.ok) {
return {
ok: false,
failure: {
code: FailureCode.UnsupportedModelError,
description: `The model ${model} is not supported on the Transformers.js engine.`,
},
};
}
try {
if (!(model in this.generationPipelines)) {
let options = {};
const modelElems = model.split('|');
const modelElems = modelNameRes.value.split('|');
const modelId = modelElems[0];
if (modelElems.length > 1) {
options = {
Expand Down Expand Up @@ -126,16 +136,26 @@ export class TransformersEngine extends BaseEngine {
ok: false,
failure: {
code: FailureCode.LocalEngineChatError,
description: `TransformersEngine failed with: ${String(error)}`,
description: `Transformers.js engine failed with: ${String(error)}`,
},
};
}
}

async fetchModel(model: string, callback: (progress: Progress) => void): Promise<Result<void>> {
const modelNameRes = await getEngineModelName(model, 'onnx');
if (!modelNameRes.ok) {
return {
ok: false,
failure: {
code: FailureCode.UnsupportedModelError,
description: `The model ${model} is not supported on the Transformers.js engine.`,
},
};
}
try {
if (!(model in this.generationPipelines)) {
this.generationPipelines.model = await pipeline('text-generation', model, {
this.generationPipelines.model = await pipeline('text-generation', modelNameRes.value, {
dtype: 'q4',
progress_callback: (progressInfo: ProgressInfo) => {
let percentage = 0;
Expand Down Expand Up @@ -169,7 +189,8 @@ export class TransformersEngine extends BaseEngine {
}
}

async isSupported(model: string): Promise<Result<string>> {
return await checkSupport(model, 'onnx');
async isSupported(model: string): Promise<boolean> {
const modelNameRes = await getEngineModelName(model, 'onnx');
return modelNameRes.ok;
}
}
31 changes: 26 additions & 5 deletions intelligence/ts/src/engines/webllmEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
Tool,
} from '../typing';
import { BaseEngine } from './engine';
import { checkSupport } from './common';
import { getEngineModelName } from './common';

async function runQuery(
engine: MLCEngineInterface,
Expand Down Expand Up @@ -72,10 +72,20 @@ export class WebllmEngine extends BaseEngine {
onStreamEvent?: (event: StreamEvent) => void,
_tools?: Tool[]
): Promise<ChatResponseResult> {
const modelNameRes = await getEngineModelName(model, 'webllm');
if (!modelNameRes.ok) {
return {
ok: false,
failure: {
code: FailureCode.UnsupportedModelError,
description: `The model ${model} is not supported on the WebLLM engine.`,
},
};
}
try {
if (!(model in this.#loadedEngines)) {
this.#loadedEngines.model = await CreateMLCEngine(
model,
modelNameRes.value,
{},
{
context_window_size: 2048,
Expand Down Expand Up @@ -109,10 +119,20 @@ export class WebllmEngine extends BaseEngine {
}

async fetchModel(model: string, callback: (progress: Progress) => void): Promise<Result<void>> {
const modelNameRes = await getEngineModelName(model, 'webllm');
if (!modelNameRes.ok) {
return {
ok: false,
failure: {
code: FailureCode.UnsupportedModelError,
description: `The model ${model} is not supported on the WebLLM engine.`,
},
};
}
try {
if (!(model in this.#loadedEngines)) {
this.#loadedEngines.model = await CreateMLCEngine(
model,
modelNameRes.value,
{
initProgressCallback: (report: InitProgressReport) => {
callback({ percentage: report.progress, description: report.text });
Expand All @@ -132,7 +152,8 @@ export class WebllmEngine extends BaseEngine {
}
}

async isSupported(model: string): Promise<Result<string>> {
return await checkSupport(model, 'webllm');
async isSupported(model: string): Promise<boolean> {
const modelNameRes = await getEngineModelName(model, 'webllm');
return modelNameRes.ok;
}
}
6 changes: 3 additions & 3 deletions intelligence/ts/src/flowerintelligence.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@
const getEngineRes = await fi['getEngine']('meta/llama3.2-1b/instruct-fp16', true, false);
expect(getEngineRes.ok).toBe(true);
if (getEngineRes.ok) {
expect(getEngineRes.value[0]).toBeInstanceOf(RemoteEngine);
expect(getEngineRes.value).toBeInstanceOf(RemoteEngine);
}
});

it('should return a local engine when the model can run locally', async () => {
const getEngineRes = await fi['getEngine']('meta/llama3.2-1b/instruct-fp16', false, false);
expect(getEngineRes.ok).toBe(true);
if (getEngineRes.ok) {
expect(getEngineRes.value[0]).toBeInstanceOf(TransformersEngine);
expect(getEngineRes.value).toBeInstanceOf(TransformersEngine);

Check failure on line 54 in intelligence/ts/src/flowerintelligence.test.ts

View workflow job for this annotation

GitHub Actions / Tests

src/flowerintelligence.test.ts > FlowerIntelligence > getEngine > should return a local engine when the model can run locally

AssertionError: expected RemoteEngine{ …(3) } to be an instance of TransformersEngine ❯ src/flowerintelligence.test.ts:54:36
}
});

Expand Down Expand Up @@ -86,9 +86,9 @@
describe('chooseLocalEngine', () => {
it('should return a local engine for a valid provider', async () => {
const chooseEngineRes = await fi['chooseLocalEngine']('meta/llama3.2-1b/instruct-fp16');
expect(chooseEngineRes.ok).toBe(true);

Check failure on line 89 in intelligence/ts/src/flowerintelligence.test.ts

View workflow job for this annotation

GitHub Actions / Tests

src/flowerintelligence.test.ts > FlowerIntelligence > chooseLocalEngine > should return a local engine for a valid provider

AssertionError: expected false to be true // Object.is equality
- Expected
+ Received
- true
+ false
❯ src/flowerintelligence.test.ts:89:34
if (chooseEngineRes.ok) {
expect(chooseEngineRes.value[0]).toBeInstanceOf(TransformersEngine);
expect(chooseEngineRes.value).toBeInstanceOf(TransformersEngine);
}
});
});
Expand Down Expand Up @@ -136,7 +136,7 @@
maxCompletionTokens: 5,
});
if (!data.ok) {
assert.fail(data.failure.description);

Check failure on line 139 in intelligence/ts/src/flowerintelligence.test.ts

View workflow job for this annotation

GitHub Actions / Tests

src/flowerintelligence.test.ts > FlowerIntelligence > Chat > generates some reduced text

AssertionError: 500: Internal Server Error ❯ src/flowerintelligence.test.ts:139:16
} else {
expect(data.message.content).not.toBeNull();
if (data.message.content) {
Expand All @@ -162,7 +162,7 @@
const regex = emojiRegex();

if (!data.ok) {
assert.fail(data.failure.description);

Check failure on line 165 in intelligence/ts/src/flowerintelligence.test.ts

View workflow job for this annotation

GitHub Actions / Tests

src/flowerintelligence.test.ts > FlowerIntelligence > Chat > generates some text with custom system prompt

AssertionError: 500: Internal Server Error ❯ src/flowerintelligence.test.ts:165:16
} else {
expect(data.message.content).not.toBeNull();
if (data.message.content) {
Expand Down
31 changes: 16 additions & 15 deletions intelligence/ts/src/flowerintelligence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ export class FlowerIntelligence {
if (!engineResult.ok) {
return engineResult;
} else {
const [engine, modelId] = engineResult.value;
return await engine.fetchModel(modelId, callback);
return await engineResult.value.fetchModel(model, callback);
}
}

Expand Down Expand Up @@ -153,8 +152,9 @@ export class FlowerIntelligence {
({ messages, ...options } = inputOrOptions);
}

const model = options.model ?? DEFAULT_MODEL;
const engineResult = await this.getEngine(
options.model ?? DEFAULT_MODEL,
model,
options.forceRemote ?? false,
options.forceLocal ?? false
);
Expand All @@ -163,10 +163,9 @@ export class FlowerIntelligence {
return engineResult;
}

const [engine, modelId] = engineResult.value;
return await engine.chat(
return await engineResult.value.chat(
messages,
modelId,
model,
options.temperature,
options.maxCompletionTokens,
options.stream,
Expand All @@ -180,33 +179,32 @@ export class FlowerIntelligence {
modelId: string,
forceRemote: boolean,
forceLocal: boolean
): Promise<Result<[Engine, string]>> {
): Promise<Result<Engine>> {
const argsResult = this.validateArgs(forceRemote, forceLocal);
if (!argsResult.ok) {
return argsResult;
}

if (forceRemote) {
return this.getOrCreateRemoteEngine(modelId);
return this.getOrCreateRemoteEngine();
}

const localEngineResult = await this.chooseLocalEngine(modelId);
if (localEngineResult.ok) {
return localEngineResult;
}

return this.getOrCreateRemoteEngine(modelId);
return this.getOrCreateRemoteEngine(localEngineResult);
}

private async chooseLocalEngine(modelId: string): Promise<Result<[Engine, string]>> {
private async chooseLocalEngine(modelId: string): Promise<Result<Engine>> {
const compatibleEngines = (
await Promise.all(
this.#availableLocalEngines.map(async (engine) => {
const supportedResult = await engine.isSupported(modelId);
return supportedResult.ok ? [engine, supportedResult.value] : null;
return (await engine.isSupported(modelId)) ? engine : null;
})
)
).filter((item): item is [Engine, string] => item !== null);
).filter((item): item is Engine => item !== null);

if (compatibleEngines.length > 0) {
// Currently we just select the first compatible localEngine without further check
Expand All @@ -222,7 +220,10 @@ export class FlowerIntelligence {
}
}

private getOrCreateRemoteEngine(modelId: string): Result<[Engine, string]> {
private getOrCreateRemoteEngine(localFailure?: Result<Engine>): Result<Engine> {
if (localFailure && !FlowerIntelligence.#remoteHandoff && !FlowerIntelligence.#apiKey) {
return localFailure;
}
if (!FlowerIntelligence.#remoteHandoff) {
return {
ok: false,
Expand All @@ -242,7 +243,7 @@ export class FlowerIntelligence {
};
}
this.#remoteEngine = this.#remoteEngine ?? new RemoteEngine(FlowerIntelligence.#apiKey);
return { ok: true, value: [this.#remoteEngine, modelId] };
return { ok: true, value: this.#remoteEngine };
}

private validateArgs(forceRemote: boolean, forceLocal: boolean): Result<void> {
Expand Down
Loading