diff --git a/examples/prefix-caching/README.md b/examples/prefix-caching/README.md
new file mode 100644
index 00000000..419734b1
--- /dev/null
+++ b/examples/prefix-caching/README.md
@@ -0,0 +1,14 @@
+# WebLLM App for Prefix Caching Demo
+
+This example demonstrates the use of `cachedPrefixes` in WebLLM.
+To try it out, run the following commands in this folder:
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack on the WebLLM core package, you can change the
+`@mlc-ai/web-llm` dependency to `"file:../.."` and follow the build-from-source
+instructions in the project to build WebLLM locally. This option is only
+recommended if you intend to modify the WebLLM core package.
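+
+The key piece is the `cachedPrefixes` option passed when creating the engine: the
+listed messages are prefilled into the KV cache at load time, and later requests
+that start with the same prefix reuse those cached entries instead of prefilling
+them again. A minimal sketch, mirroring `src/prefix-caching.ts`:
+
+```typescript
+import * as webllm from "@mlc-ai/web-llm";
+
+const SYSTEM_PROMPT_PREFIX =
+  "You are a helpful assistant running in the user's browser, responsible for answering questions.";
+
+async function main() {
+  const engine = await webllm.CreateMLCEngine(
+    "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+    {
+      // Prefill the system prompt into the KV cache when the engine loads
+      cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
+    },
+  );
+
+  // This request starts with the cached system prompt, so only the user
+  // message needs to be prefilled.
+  const reply = await engine.chat.completions.create({
+    messages: [
+      { role: "system", content: SYSTEM_PROMPT_PREFIX },
+      { role: "user", content: "List three US states." },
+    ],
+  });
+  console.log(reply.choices[0].message.content);
+}
+
+main();
+```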
diff --git a/examples/prefix-caching/package.json b/examples/prefix-caching/package.json
new file mode 100644
index 00000000..d005440f
--- /dev/null
+++ b/examples/prefix-caching/package.json
@@ -0,0 +1,20 @@
+{
+ "name": "prefix-caching-example",
+ "version": "0.1.0",
+ "private": true,
+ "scripts": {
+ "start": "parcel src/prefix-caching.html --port 8888",
+ "build": "parcel build src/prefix-caching.html --dist-dir lib"
+ },
+ "devDependencies": {
+ "buffer": "^5.7.1",
+ "parcel": "^2.8.3",
+ "process": "^0.11.10",
+ "tslib": "^2.3.1",
+ "typescript": "^4.9.5",
+ "url": "^0.11.3"
+ },
+ "dependencies": {
+ "@mlc-ai/web-llm": "^0.2.78"
+ }
+}
diff --git a/examples/prefix-caching/src/prefix-caching.html b/examples/prefix-caching/src/prefix-caching.html
new file mode 100644
index 00000000..944e94fe
--- /dev/null
+++ b/examples/prefix-caching/src/prefix-caching.html
@@ -0,0 +1,23 @@
+<!DOCTYPE html>
+<html>
+  <head>
+    <title>WebLLM Prefix Caching Test Page</title>
+  </head>
+
+  <body>
+    <h2>WebLLM Prefix Caching Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+
+    <script type="module" src="./prefix-caching.ts"></script>
+  </body>
+</html>
diff --git a/examples/prefix-caching/src/prefix-caching.ts b/examples/prefix-caching/src/prefix-caching.ts
new file mode 100644
index 00000000..456e9b35
--- /dev/null
+++ b/examples/prefix-caching/src/prefix-caching.ts
@@ -0,0 +1,142 @@
+import * as webllm from "@mlc-ai/web-llm";
+
+const SYSTEM_PROMPT_PREFIX =
+ "You are a helpful assistant running in the user's browser, responsible for answering questions.";
+
+function setLabel(id: string, text: string) {
+ const label = document.getElementById(id);
+ if (label == null) {
+ throw Error("Cannot find label " + id);
+ }
+ label.innerText = text;
+}
+
+async function testPrefix() {
+ const initProgressCallback = (report: webllm.InitProgressReport) => {
+ setLabel("init-label", report.text);
+ };
+
+ const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+ const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+ selectedModel,
+ {
+ initProgressCallback: initProgressCallback,
+ logLevel: "INFO",
+ // Prefilling KV cache for efficiency
+ cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
+ },
+ {
+ context_window_size: 2048,
+ },
+ );
+
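+  // This request begins with SYSTEM_PROMPT_PREFIX, which was prefilled above via
+  // cachedPrefixes, so the engine reuses those KV entries and only prefills the
+  // user message.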
+ const reply_using_prefix = await engine.chat.completions.create({
+ messages: [
+ { role: "system", content: SYSTEM_PROMPT_PREFIX },
+ { role: "user", content: "List three US states." },
+ ],
+ // below configurations are all optional
+ n: 1,
+ temperature: 1.5,
+ max_tokens: 64,
+ logprobs: true,
+ top_logprobs: 2,
+ });
+ console.log(reply_using_prefix);
+ console.log(reply_using_prefix.usage);
+}
+
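+// Baseline for comparison: the engine is created without cachedPrefixes, so the
+// full system prompt is prefilled from scratch for this request.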
+async function testWithoutPrefix() {
+ const initProgressCallback = (report: webllm.InitProgressReport) => {
+ setLabel("init-label", report.text);
+ };
+
+ const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+ // Engine Initialization without cachedPrefixes
+ const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+ selectedModel,
+ {
+ initProgressCallback: initProgressCallback,
+ logLevel: "INFO",
+ },
+ {
+ context_window_size: 2048,
+ },
+ );
+
+ const reply_without_prefix = await engine.chat.completions.create({
+ messages: [
+ { role: "system", content: SYSTEM_PROMPT_PREFIX },
+ { role: "user", content: "List three US states." },
+ ],
+ // below configurations are all optional
+ n: 1,
+ temperature: 1.5,
+ max_tokens: 64,
+ logprobs: true,
+ top_logprobs: 2,
+ });
+ console.log(reply_without_prefix);
+ console.log(reply_without_prefix.usage);
+}
+
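+// Both rounds share the cached system prompt prefix, so each request only needs
+// to prefill its own user message on top of the cached KV entries.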
+async function testMultiRound() {
+ const initProgressCallback = (report: webllm.InitProgressReport) => {
+ setLabel("init-label", report.text);
+ };
+
+ const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+ const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+ selectedModel,
+ {
+ initProgressCallback: initProgressCallback,
+ logLevel: "INFO",
+ cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]], // Prefilling KV cache for efficiency
+ },
+ {
+ context_window_size: 2048,
+ },
+ );
+
+ // First Completion with cachedPrefixes
+ const reply0 = await engine.chat.completions.create({
+ messages: [
+ { role: "system", content: SYSTEM_PROMPT_PREFIX },
+ { role: "user", content: "List three US states." },
+ ],
+ // below configurations are all optional
+ n: 1,
+ temperature: 1.5,
+ max_tokens: 64,
+ logprobs: true,
+ top_logprobs: 2,
+ });
+ console.log(reply0);
+ console.log(reply0.usage);
+
+ // Second Completion with cachedPrefixes
+ const reply1 = await engine.chat.completions.create({
+ messages: [
+ { role: "system", content: SYSTEM_PROMPT_PREFIX },
+ { role: "user", content: "Where is the US capital?" },
+ ],
+ // below configurations are all optional
+ n: 1,
+ temperature: 1.5,
+ max_tokens: 64,
+ logprobs: true,
+ top_logprobs: 2,
+ });
+ console.log(reply1);
+ console.log(reply1.usage);
+}
+
+async function main() {
+ await testPrefix();
+
+ await testWithoutPrefix();
+
+ await testMultiRound();
+}
+
+main();
diff --git a/src/config.ts b/src/config.ts
index 3f6a39a6..d1c855a7 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -9,6 +9,7 @@ import {
NonNegativeError,
RangeError,
} from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
/**
* Conversation template config
@@ -105,15 +106,20 @@ export interface ChatOptions extends Partial<ChatConfig> {}
* appConfig: Configure the app, including the list of models and whether to use IndexedDB cache.
* initProgressCallback: A callback for showing the progress of loading the model.
* logitProcessorRegistry: A register for stateful logit processors, see `webllm.LogitProcessor`.
+ * cachedPrefixes: Specifies system prompts (prefixes) to prefill when the engine is loaded,
+ * creating their corresponding KV cache entries and storing them for reuse. These cached KV
+ * entries persist until the engine is reloaded.
*
* @note All fields are optional, and `logitProcessorRegistry` is only used for `MLCEngine` and not
* other `MLCEngine`s.
+ * @note cachedPrefixes is experimental. It may change in future versions.
*/
export interface MLCEngineConfig {
appConfig?: AppConfig;
initProgressCallback?: InitProgressCallback;
logitProcessorRegistry?: Map<string, LogitProcessor>;
logLevel?: LogLevel;
+ cachedPrefixes?: ChatCompletionMessageParam[][];
}
/**
diff --git a/src/engine.ts b/src/engine.ts
index 6fe3f577..c10aefc8 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -131,6 +131,7 @@ export class MLCEngine implements MLCEngineInterface {
private logitProcessorRegistry?: Map<string, LogitProcessor>;
private initProgressCallback?: InitProgressCallback;
private appConfig: AppConfig;
+ private cachedPrefixes: ChatCompletionMessageParam[][];
// Signals and flags
private interruptSignal = false;
@@ -149,6 +150,7 @@ export class MLCEngine implements MLCEngineInterface {
this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel);
this.setInitProgressCallback(engineConfig?.initProgressCallback);
this.setLogitProcessorRegistry(engineConfig?.logitProcessorRegistry);
+ this.cachedPrefixes = engineConfig?.cachedPrefixes || [];
this.chat = new API.Chat(this);
this.completions = new API.Completions(this);
@@ -392,6 +394,16 @@ export class MLCEngine implements MLCEngineInterface {
this.loadedModelIdToPipeline.set(modelId, newPipeline);
this.loadedModelIdToLock.set(modelId, new CustomLock());
+ // Call prefillConvSequence() if cachedPrefixes is specified
+ if (
+ newPipeline instanceof LLMChatPipeline &&
+ this.cachedPrefixes.length > 0
+ ) {
+ for (let i = 0; i < this.cachedPrefixes.length; i++) {
+ await newPipeline.prefillConvSequence(this.cachedPrefixes[i]);
+ }
+ }
+
// Clean up
const tend = performance.now();
if (this.initProgressCallback !== undefined) {
diff --git a/src/llm_chat.ts b/src/llm_chat.ts
index 7050e00b..96ecbcf7 100644
--- a/src/llm_chat.ts
+++ b/src/llm_chat.ts
@@ -34,9 +34,13 @@ import {
PrefillChunkSizeSmallerThanImageError,
CannotFindImageEmbedError,
} from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
type ImageURL = ChatCompletionContentPartImage.ImageURL;
+// Default sequence ID for chat completion
+const CHAT_SEQUENCE_ID = 0;
+
export class LLMChatPipeline {
private config: ChatConfig;
private tokenizer: Tokenizer;
@@ -128,6 +132,8 @@ export class LLMChatPipeline {
private curRoundGrammarInitTotalTime = 0;
// Total time of getting next bitmask and accepting token in seconds
private curRoundGrammarPerTokenTotalTime = 0;
+ private seqIdToPrefix: Map<number, number[]>;
+ private nextSequenceId: number;
constructor(
tvm: tvmjs.Instance,
@@ -173,6 +179,8 @@ export class LLMChatPipeline {
log.info("token_postproc_method: ", this.token_postproc_method);
log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);
+ this.seqIdToPrefix = new Map();
+ this.nextSequenceId = 0;
this.device = this.tvm.webgpu();
// 1. Create VM and get the core functions
@@ -344,7 +352,12 @@ export class LLMChatPipeline {
* Reset KV Cache
*/
resetKVCache() {
- this.fclearKVCaches(this.kvCache);
+ // If prefixes are cached, only remove the chat sequence so the cached prefix
+ // sequences are preserved; otherwise clear the entire KV cache
+ if (this.seqIdToPrefix.size === 0) {
+ this.fclearKVCaches(this.kvCache);
+ } else {
+ this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+ }
this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
if (this.slidingWindowSize != -1) {
this.fKVCacheEnableSlidingWindowForSeq(
@@ -483,6 +496,15 @@ export class LLMChatPipeline {
await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
}
+ matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
+ for (let i = 0; i < prefixTokens.length; i++) {
+ if (inputTokens[i] !== prefixTokens[i]) {
+ return i;
+ }
+ }
+ return prefixTokens.length;
+ }
+
/**
* Generate the first token given input prompt
*/
@@ -491,11 +513,17 @@ export class LLMChatPipeline {
msgRole: Role, // either user or tool
inp_role_str?: string,
genConfig?: GenerationConfig,
+ seqID = CHAT_SEQUENCE_ID,
): Promise<void> {
- if (msgRole !== Role.user && msgRole !== Role.tool) {
- throw new MessageOrderError(
- "The last message should be from `user` or `tool`.",
- );
+ if (seqID === CHAT_SEQUENCE_ID) {
+ if (msgRole !== Role.user && msgRole !== Role.tool) {
+ throw new MessageOrderError(
+ "The last message should be from `user` or `tool`.",
+ );
+ }
+ } else {
+ // Set the input as system prompt during prefix prefilling
+ this.conversation.override_system_message = inp;
}
if (this.resetStatsPerPrefill) {
this.resetRuntimeStats();
@@ -583,11 +611,13 @@ export class LLMChatPipeline {
}
// 0. Get inputData from conversation
- if (conversation.isTextCompletion) {
- conversation.prompt = inp;
- } else {
- conversation.appendMessage(msgRole, inp, inp_role_str);
- conversation.appendReplyHeader(Role.assistant);
+ if (seqID === CHAT_SEQUENCE_ID) {
+ if (conversation.isTextCompletion) {
+ conversation.prompt = inp;
+ } else {
+ conversation.appendMessage(msgRole, inp, inp_role_str);
+ conversation.appendReplyHeader(Role.assistant);
+ }
}
const retGetInputData = this.getInputData();
const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
@@ -610,11 +640,63 @@ export class LLMChatPipeline {
throw new CannotFindImageEmbedError();
}
+ let maxMatchedLen = -1;
+ let matchedSeqId = -1;
+
+ // Prefix matching and forking
+ const inputTokens = inputData.flat() as number[];
+ for (const [id, prefixTokens] of this.seqIdToPrefix) {
+ const matchedLen = this.matchPrefix(inputTokens, prefixTokens);
+ if (matchedLen > maxMatchedLen) {
+ maxMatchedLen = matchedLen;
+ matchedSeqId = id;
+ }
+ }
+
+ // If a match is found, fork the sequence
+ if (matchedSeqId !== -1 && maxMatchedLen > 0) {
+ log.info("Forking sequence", matchedSeqId, "at position", maxMatchedLen);
+ if (seqID === CHAT_SEQUENCE_ID) {
+ this.fKVCacheRemoveSequence!(
+ this.kvCache,
+ new tvmjs.Scalar(seqID, "int64"),
+ );
+ }
+ this.tvm.beginScope();
+ this.tvm.getGlobalFunc("vm.builtin.kv_state_fork_sequence")(
+ this.kvCache,
+ new tvmjs.Scalar(matchedSeqId, "int64"), // fork_parent_id
+ new tvmjs.Scalar(seqID, "int64"), // fork_child_id
+ new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
+ );
+ this.tvm.endScope();
+ } else if (seqID !== CHAT_SEQUENCE_ID) {
+ // If no match is found, add the new sequence to the KV cache
+ log.info("Adding prefix to KV cache: ", seqID);
+ this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
+ }
+
+ // Add the new sequence to the seqIdToPrefix map (if it is a prefix)
+ if (seqID !== CHAT_SEQUENCE_ID) {
+ this.seqIdToPrefix.set(seqID, inputTokens);
+ }
+
// 1. Chunk inputData to embed and forward in one shot for each, minimize intermediate data
- const retGetChunks = getChunkedPrefillInputData(
- inputData,
- this.prefillChunkSize,
- );
+ let retGetChunks;
+ if (maxMatchedLen === -1) {
+ retGetChunks = getChunkedPrefillInputData(
+ inputData,
+ this.prefillChunkSize,
+ );
+ } else {
+ // If a matched prefix exists, only forward the remaining tokens
+ retGetChunks = getChunkedPrefillInputData(
+ inputData.map((arr) =>
+ Array.isArray(arr) ? arr.slice(maxMatchedLen) : arr,
+ ),
+ this.prefillChunkSize,
+ );
+ }
const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
const chunkLens: Array<number> = retGetChunks[1];
@@ -626,7 +708,7 @@ export class LLMChatPipeline {
const chunkLen = chunkLens[i];
const prevFilledLen = this.filledKVCacheLength;
logits = this.tvm.detachFromCurrentScope(
- await this.embedAndForward(chunk, chunkLen),
+ await this.embedAndForward(chunk, chunkLen, seqID),
);
if (this.filledKVCacheLength !== prevFilledLen + chunkLen) {
throw new Error(
@@ -651,6 +733,41 @@ export class LLMChatPipeline {
this.processNextToken(nextToken, genConfig);
}
+ async prefillConvSequence(
+ messages: ChatCompletionMessageParam[],
+ inp_role_str?: string,
+ genConfig?: GenerationConfig,
+ ): Promise<void> {
+ for (const message of messages) {
+ this.nextSequenceId = this.nextSequenceId + 1;
+ const newSeqId = this.nextSequenceId;
+ // Call the regular prefillStep with the new seqID
+ if (typeof message.content === "string") {
+ // Prefill the (potentially long) system prompt into its own KV-cache sequence
+ if (message.role === "system") {
+ await this.prefillStep(
+ message.content,
+ Role.tool,
+ inp_role_str,
+ genConfig,
+ newSeqId,
+ );
+ } else {
+ throw Error(
+ "Invalid role in prefix message: " +
+ message.role +
+ ", expected 'system'.",
+ );
+ }
+ } else {
+ throw Error(
+ "Invalid content in prefix message, does not support image input.",
+ );
+ }
+ }
+ this.conversation.reset();
+ }
+
async decodeStep(genConfig?: GenerationConfig): Promise {
if (this.stopTriggered) {
throw Error("Cannot run decode when stopped");
@@ -869,6 +986,7 @@ export class LLMChatPipeline {
*
* @param inputData data to embed and forward
* @param inputDataLen length of this inputData, should be smaller than prefill chunk size.
+ * @param seqID sequence ID of the input data in KV cache for prefix caching
* @returns The logits returned by this forward as tvmjs.NDArray on GPU.
*
* @note Precondition: inputData's data length is smaller than prefill chunk size
@@ -876,6 +994,7 @@ export class LLMChatPipeline {
private async embedAndForward(
inputData: Array<Array<number> | ImageURL>,
inputDataLen: number,
+ seqID = CHAT_SEQUENCE_ID,
): Promise<tvmjs.NDArray> {
if (inputDataLen > this.prefillChunkSize) {
throw new Error(
@@ -913,7 +1032,8 @@ export class LLMChatPipeline {
// 3. Forward the concatenated embeddings
const inputLenShape = this.tvm.makeShapeTuple([inputDataLen]);
- const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+ // Forward into the given sequence (seqID), which may be a forked child sequence for prefix caching
+ const seqIdsTuple = this.tvm.makeShapeTuple([seqID]);
this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
let retValue;
if (inputDataLen > 1) {