diff --git a/examples/prefix-caching/README.md b/examples/prefix-caching/README.md
new file mode 100644
index 00000000..419734b1
--- /dev/null
+++ b/examples/prefix-caching/README.md
@@ -0,0 +1,14 @@
+# WebLLM App for Prefix Caching Demo
+
+This example demonstrates the use of `cachedPrefixes` in WebLLM, which prefills the KV cache for
+a common prefix (such as a shared system prompt) so that later requests can reuse it.
+To try it out, run the following commands in this folder:
+
+```bash
+npm install
+npm start
+```
+
+If you would like to hack on the WebLLM core package, change the `@mlc-ai/web-llm` dependency
+to `"file:../.."` and follow the project's build-from-source instructions to build WebLLM locally.
+This is only recommended if you need to modify the core package itself.
diff --git a/examples/prefix-caching/package.json b/examples/prefix-caching/package.json
new file mode 100644
index 00000000..d005440f
--- /dev/null
+++ b/examples/prefix-caching/package.json
@@ -0,0 +1,20 @@
+{
+  "name": "prefix-caching-example",
+  "version": "0.1.0",
+  "private": true,
+  "scripts": {
+    "start": "parcel src/prefix-caching.html --port 8888",
+    "build": "parcel build src/prefix-caching.html --dist-dir lib"
+  },
+  "devDependencies": {
+    "buffer": "^5.7.1",
+    "parcel": "^2.8.3",
+    "process": "^0.11.10",
+    "tslib": "^2.3.1",
+    "typescript": "^4.9.5",
+    "url": "^0.11.3"
+  },
+  "dependencies": {
+    "@mlc-ai/web-llm": "^0.2.78"
+  }
+}
diff --git a/examples/prefix-caching/src/prefix-caching.html b/examples/prefix-caching/src/prefix-caching.html
new file mode 100644
index 00000000..944e94fe
--- /dev/null
+++ b/examples/prefix-caching/src/prefix-caching.html
@@ -0,0 +1,23 @@
+<!DOCTYPE html>
+<html>
+  <script>
+    webLLMGlobal = {};
+  </script>
+  <body>
+    <h2>WebLLM Prefix Caching Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./prefix-caching.ts"></script>
+  </body>
+</html>
diff --git a/examples/prefix-caching/src/prefix-caching.ts b/examples/prefix-caching/src/prefix-caching.ts
new file mode 100644
index 00000000..456e9b35
--- /dev/null
+++ b/examples/prefix-caching/src/prefix-caching.ts
@@ -0,0 +1,142 @@
+import * as webllm from "@mlc-ai/web-llm";
+
+const SYSTEM_PROMPT_PREFIX =
+  "You are a helpful assistant running in the user's browser, responsible for answering questions.";
+
+function setLabel(id: string, text: string) {
+  const label = document.getElementById(id);
+  if (label == null) {
+    throw Error("Cannot find label " + id);
+  }
+  label.innerText = text;
+}
+
+async function testPrefix() {
+  const initProgressCallback = (report: webllm.InitProgressReport) => {
+    setLabel("init-label", report.text);
+  };
+
+  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+    selectedModel,
+    {
+      initProgressCallback: initProgressCallback,
+      logLevel: "INFO",
+      // Prefilling KV cache for efficiency
+      cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
+    },
+    {
+      context_window_size: 2048,
+    },
+  );
+
+  const reply_using_prefix = await engine.chat.completions.create({
+    messages: [
+      { role: "system", content: SYSTEM_PROMPT_PREFIX },
+      { role: "user", content: "List three US states." },
+    ],
+    // below configurations are all optional
+    n: 1,
+    temperature: 1.5,
+    max_tokens: 64,
+    logprobs: true,
+    top_logprobs: 2,
+  });
+  console.log(reply_using_prefix);
+  console.log(reply_using_prefix.usage);
+}
+
+async function testWithoutPrefix() {
+  const initProgressCallback = (report: webllm.InitProgressReport) => {
+    setLabel("init-label", report.text);
+  };
+
+  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+  // Engine Initialization without cachedPrefixes
+  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+    selectedModel,
+    {
+      initProgressCallback: initProgressCallback,
+      logLevel: "INFO",
+    },
+    {
+      context_window_size: 2048,
+    },
+  );
+
+  const reply_without_prefix = await engine.chat.completions.create({
+    messages: [
+      { role: "system", content: SYSTEM_PROMPT_PREFIX },
+      { role: "user", content: "List three US states." },
+    ],
+    // below configurations are all optional
+    n: 1,
+    temperature: 1.5,
+    max_tokens: 64,
+    logprobs: true,
+    top_logprobs: 2,
+  });
+  console.log(reply_without_prefix);
+  console.log(reply_without_prefix.usage);
+}
+
+async function testMultiRound() {
+  const initProgressCallback = (report: webllm.InitProgressReport) => {
+    setLabel("init-label", report.text);
+  };
+
+  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+    selectedModel,
+    {
+      initProgressCallback: initProgressCallback,
+      logLevel: "INFO",
+      cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]], // Prefilling KV cache for efficiency
+    },
+    {
+      context_window_size: 2048,
+    },
+  );
+
+  // First Completion with cachedPrefixes
+  const reply0 = await engine.chat.completions.create({
+    messages: [
+      { role: "system", content: SYSTEM_PROMPT_PREFIX },
+      { role: "user", content: "List three US states." },
}, + ], + // below configurations are all optional + n: 1, + temperature: 1.5, + max_tokens: 64, + logprobs: true, + top_logprobs: 2, + }); + console.log(reply0); + console.log(reply0.usage); + + // Second Completion with cachedPrefixes + const reply1 = await engine.chat.completions.create({ + messages: [ + { role: "system", content: SYSTEM_PROMPT_PREFIX }, + { role: "user", content: "Where is the US capital?" }, + ], + // below configurations are all optional + n: 1, + temperature: 1.5, + max_tokens: 64, + logprobs: true, + top_logprobs: 2, + }); + console.log(reply1); + console.log(reply1.usage); +} + +async function main() { + await testPrefix(); + + await testWithoutPrefix(); + + await testMultiRound(); +} + +main(); diff --git a/src/config.ts b/src/config.ts index 3f6a39a6..d1c855a7 100644 --- a/src/config.ts +++ b/src/config.ts @@ -9,6 +9,7 @@ import { NonNegativeError, RangeError, } from "./error"; +import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion"; /** * Conversation template config @@ -105,15 +106,20 @@ export interface ChatOptions extends Partial {} * appConfig: Configure the app, including the list of models and whether to use IndexedDB cache. * initProgressCallback: A callback for showing the progress of loading the model. * logitProcessorRegistry: A register for stateful logit processors, see `webllm.LogitProcessor`. + * cachedPrefixes: Specifies a system prompt (prefix) that will be prefilled when loading the engine + * to create their corresponding KV cache and store them for reuse. These cached kv pairs persist + * until the engine is reloaded. * * @note All fields are optional, and `logitProcessorRegistry` is only used for `MLCEngine` and not * other `MLCEngine`s. + * @note cachedPrefixes is experimental. It may change in future versions. 
  */
 export interface MLCEngineConfig {
   appConfig?: AppConfig;
   initProgressCallback?: InitProgressCallback;
   logitProcessorRegistry?: Map<string, LogitProcessor>;
   logLevel?: LogLevel;
+  cachedPrefixes?: ChatCompletionMessageParam[][];
 }
 
 /**
diff --git a/src/engine.ts b/src/engine.ts
index 6fe3f577..c10aefc8 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -131,6 +131,7 @@ export class MLCEngine implements MLCEngineInterface {
   private logitProcessorRegistry?: Map<string, LogitProcessor>;
   private initProgressCallback?: InitProgressCallback;
   private appConfig: AppConfig;
+  private cachedPrefixes: ChatCompletionMessageParam[][];
 
   // Signals and flags
   private interruptSignal = false;
@@ -149,6 +150,7 @@ export class MLCEngine implements MLCEngineInterface {
     this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel);
     this.setInitProgressCallback(engineConfig?.initProgressCallback);
     this.setLogitProcessorRegistry(engineConfig?.logitProcessorRegistry);
+    this.cachedPrefixes = engineConfig?.cachedPrefixes || [];
 
     this.chat = new API.Chat(this);
     this.completions = new API.Completions(this);
@@ -392,6 +394,16 @@ export class MLCEngine implements MLCEngineInterface {
     this.loadedModelIdToPipeline.set(modelId, newPipeline);
     this.loadedModelIdToLock.set(modelId, new CustomLock());
 
+    // Call prefillConvSequence() if cachedPrefixes is specified
+    if (
+      newPipeline instanceof LLMChatPipeline &&
+      this.cachedPrefixes.length > 0
+    ) {
+      for (let i = 0; i < this.cachedPrefixes.length; i++) {
+        await newPipeline.prefillConvSequence(this.cachedPrefixes[i]);
+      }
+    }
+
     // Clean up
     const tend = performance.now();
     if (this.initProgressCallback !== undefined) {
diff --git a/src/llm_chat.ts b/src/llm_chat.ts
index 7050e00b..96ecbcf7 100644
--- a/src/llm_chat.ts
+++ b/src/llm_chat.ts
@@ -34,9 +34,13 @@ import {
   PrefillChunkSizeSmallerThanImageError,
   CannotFindImageEmbedError,
 } from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
 
 type ImageURL = ChatCompletionContentPartImage.ImageURL;
 
+// Default sequence ID for chat completion
+const CHAT_SEQUENCE_ID = 0;
+
 export class LLMChatPipeline {
   private config: ChatConfig;
   private tokenizer: Tokenizer;
@@ -128,6 +132,8 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  private seqIdToPrefix: Map<number, number[]>;
+  private nextSequenceId: number;
 
   constructor(
     tvm: tvmjs.Instance,
@@ -173,6 +179,8 @@ export class LLMChatPipeline {
     log.info("token_postproc_method: ", this.token_postproc_method);
     log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);
 
+    this.seqIdToPrefix = new Map();
+    this.nextSequenceId = 0;
     this.device = this.tvm.webgpu();
 
     // 1. Create VM and get the core functions
@@ -344,7 +352,12 @@ export class LLMChatPipeline {
    * Reset KV Cache
    */
   resetKVCache() {
-    this.fclearKVCaches(this.kvCache);
+    // Check whether to keep prefixes in the KV cache
+    if (this.seqIdToPrefix.size === 0) {
+      this.fclearKVCaches(this.kvCache);
+    } else {
+      this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+    }
     this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
     if (this.slidingWindowSize != -1) {
       this.fKVCacheEnableSlidingWindowForSeq(
@@ -483,6 +496,15 @@ export class LLMChatPipeline {
     await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
   }
 
+  matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
+    for (let i = 0; i < prefixTokens.length; i++) {
+      if (inputTokens[i] !== prefixTokens[i]) {
+        return i;
+      }
+    }
+    return prefixTokens.length;
+  }
+
   /**
    * Generate the first token given input prompt
    */
@@ -491,11 +513,17 @@ export class LLMChatPipeline {
     msgRole: Role, // either user or tool
     inp_role_str?: string,
     genConfig?: GenerationConfig,
+    seqID = CHAT_SEQUENCE_ID,
   ): Promise<void> {
-    if (msgRole !== Role.user && msgRole !== Role.tool) {
-      throw new MessageOrderError(
-        "The last message should be from `user` or `tool`.",
-      );
+    if (seqID === CHAT_SEQUENCE_ID) {
+      if (msgRole !== Role.user && msgRole !== Role.tool) {
+        throw new MessageOrderError(
+          "The last message should be from `user` or `tool`.",
+        );
+      }
+    } else {
+      // Set the input as system prompt during prefix prefilling
+      this.conversation.override_system_message = inp;
     }
     if (this.resetStatsPerPrefill) {
       this.resetRuntimeStats();
@@ -583,11 +611,13 @@ export class LLMChatPipeline {
     }
 
     // 0. Get inputData from conversation
-    if (conversation.isTextCompletion) {
-      conversation.prompt = inp;
-    } else {
-      conversation.appendMessage(msgRole, inp, inp_role_str);
-      conversation.appendReplyHeader(Role.assistant);
+    if (seqID === CHAT_SEQUENCE_ID) {
+      if (conversation.isTextCompletion) {
+        conversation.prompt = inp;
+      } else {
+        conversation.appendMessage(msgRole, inp, inp_role_str);
+        conversation.appendReplyHeader(Role.assistant);
+      }
     }
     const retGetInputData = this.getInputData();
     const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
@@ -610,11 +640,63 @@ export class LLMChatPipeline {
       throw new CannotFindImageEmbedError();
     }
 
+    let maxMatchedLen = -1;
+    let matchedSeqId = -1;
+
+    // Prefix matching and forking
+    const inputTokens = inputData.flat() as number[];
+    for (const [id, prefixTokens] of this.seqIdToPrefix) {
+      const matchedLen = this.matchPrefix(inputTokens, prefixTokens);
+      if (matchedLen > maxMatchedLen) {
+        maxMatchedLen = matchedLen;
+        matchedSeqId = id;
+      }
+    }
+
+    // If a match is found, fork the sequence
+    if (matchedSeqId !== -1 && maxMatchedLen > 0) {
+      log.info("Forking sequence", matchedSeqId, "at position", maxMatchedLen);
+      if (seqID === CHAT_SEQUENCE_ID) {
+        this.fKVCacheRemoveSequence!(
+          this.kvCache,
+          new tvmjs.Scalar(seqID, "int64"),
+        );
+      }
+      this.tvm.beginScope();
+      this.tvm.getGlobalFunc("vm.builtin.kv_state_fork_sequence")(
+        this.kvCache,
+        new tvmjs.Scalar(matchedSeqId, "int64"), // fork_parent_id
+        new tvmjs.Scalar(seqID, "int64"), // fork_child_id
+        new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
+      );
+      this.tvm.endScope();
+    } else if (seqID !== CHAT_SEQUENCE_ID) {
+      // If no match is found, add the new sequence to the KV cache
+      log.info("Adding prefix to KV cache: ", seqID);
+      this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
+    }
+
+    // Add the new sequence to the seqIdToPrefix map (if it is a prefix)
+    if (seqID !== CHAT_SEQUENCE_ID) {
+      this.seqIdToPrefix.set(seqID, inputTokens);
+    }
+
     // 1. Chunk inputData to embed and forward in one shot for each, minimize intermediate data
-    const retGetChunks = getChunkedPrefillInputData(
-      inputData,
-      this.prefillChunkSize,
-    );
+    let retGetChunks;
+    if (maxMatchedLen === -1) {
+      retGetChunks = getChunkedPrefillInputData(
+        inputData,
+        this.prefillChunkSize,
+      );
+    } else {
+      // If a matched prefix exists, only forward the remaining tokens
+      retGetChunks = getChunkedPrefillInputData(
+        inputData.map((arr) =>
+          Array.isArray(arr) ? arr.slice(maxMatchedLen) : arr,
+        ),
+        this.prefillChunkSize,
+      );
+    }
     const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
     const chunkLens: Array<number> = retGetChunks[1];
 
@@ -626,7 +708,7 @@ export class LLMChatPipeline {
       const chunkLen = chunkLens[i];
       const prevFilledLen = this.filledKVCacheLength;
       logits = this.tvm.detachFromCurrentScope(
-        await this.embedAndForward(chunk, chunkLen),
+        await this.embedAndForward(chunk, chunkLen, seqID),
       );
       if (this.filledKVCacheLength !== prevFilledLen + chunkLen) {
         throw new Error(
@@ -651,6 +733,41 @@ export class LLMChatPipeline {
     this.processNextToken(nextToken, genConfig);
   }
 
+  async prefillConvSequence(
+    messages: ChatCompletionMessageParam[],
+    inp_role_str?: string,
+    genConfig?: GenerationConfig,
+  ): Promise<void> {
+    for (const message of messages) {
+      this.nextSequenceId = this.nextSequenceId + 1;
+      const newSeqId = this.nextSequenceId;
+      // Call the regular prefillStep with the new seqID
+      if (typeof message.content === "string") {
+        // Support long system prompt
+        if (message.role === "system") {
+          await this.prefillStep(
+            message.content,
+            Role.tool,
+            inp_role_str,
+            genConfig,
+            newSeqId,
+          );
+        } else {
+          throw Error(
+            "Invalid role in prefix message: " +
+              message.role +
+              ", expected 'system'.",
+          );
+        }
+      } else {
+        throw Error(
+          "Invalid content in prefix message, does not support image input.",
+        );
+      }
+    }
+    this.conversation.reset();
+  }
+
   async decodeStep(genConfig?: GenerationConfig): Promise<void> {
     if (this.stopTriggered) {
       throw Error("Cannot run decode when stopped");
@@ -869,6 +986,7 @@ export class LLMChatPipeline {
    *
    * @param inputData data to embed and forward
    * @param inputDataLen length of this inputData, should smaller than prefill chunk size.
+   * @param seqID sequence ID of the input data in KV cache for prefix caching
    * @returns The logits returned by this forward as tvmjs.NDArray on GPU.
    *
    * @note Precondition: inputData's data length is smaller than prefill chunk size
@@ -876,6 +994,7 @@ export class LLMChatPipeline {
   private async embedAndForward(
     inputData: Array<Array<number> | ImageURL>,
     inputDataLen: number,
+    seqID = CHAT_SEQUENCE_ID,
   ): Promise<tvmjs.NDArray> {
     if (inputDataLen > this.prefillChunkSize) {
       throw new Error(
@@ -913,7 +1032,8 @@ export class LLMChatPipeline {
 
     // 3. Forward the concatenated embeddings
     const inputLenShape = this.tvm.makeShapeTuple([inputDataLen]);
-    const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+    // set seqIdsTuple to be childID
+    const seqIdsTuple = this.tvm.makeShapeTuple([seqID]);
     this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
     let retValue;
     if (inputDataLen > 1) {
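
To make the intended benefit concrete: with `cachedPrefixes`, the KV cache for the shared system prompt is built once at engine load time, so the first chat completion only has to prefill the user turn. A rough way to observe this is to time the first completion on an engine loaded with the prefix versus a plain engine. The sketch below is illustrative only and is not part of the diff; it reuses the model ID and system prompt from the example above, and `timeFirstCompletion` is a hypothetical helper rather than a WebLLM API.

```typescript
import * as webllm from "@mlc-ai/web-llm";

const SYSTEM_PROMPT_PREFIX =
  "You are a helpful assistant running in the user's browser, responsible for answering questions.";
const MODEL = "Llama-3.1-8B-Instruct-q4f32_1-MLC";

// Hypothetical helper: wall-clock time of one non-streaming completion.
async function timeFirstCompletion(
  engine: webllm.MLCEngineInterface,
): Promise<number> {
  const start = performance.now();
  await engine.chat.completions.create({
    messages: [
      { role: "system", content: SYSTEM_PROMPT_PREFIX },
      { role: "user", content: "List three US states." },
    ],
    max_tokens: 16,
  });
  return performance.now() - start;
}

async function comparePrefixCaching() {
  // Engine whose system-prompt KV cache is prefilled at load time.
  const cachedEngine = await webllm.CreateMLCEngine(MODEL, {
    cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
  });
  console.log("with cachedPrefixes:", await timeFirstCompletion(cachedEngine), "ms");

  // Baseline engine that prefills the full prompt on the first request.
  const plainEngine = await webllm.CreateMLCEngine(MODEL);
  console.log("without cachedPrefixes:", await timeFirstCompletion(plainEngine), "ms");
}

comparePrefixCaching();
```

The gap should grow with the length of the shared prefix, since that is the portion of prefill work the forked sequence avoids repeating.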