
[Cache] Add cachedPrefixes for caching repeated system prompts #664

Open: wants to merge 3 commits into main
Changes from 2 commits
2 changes: 2 additions & 0 deletions src/config.ts
@@ -9,6 +9,7 @@ import {
NonNegativeError,
RangeError,
} from "./error";
import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";

/**
* Conversation template config
@@ -114,6 +115,7 @@ export interface MLCEngineConfig {
initProgressCallback?: InitProgressCallback;
logitProcessorRegistry?: Map<string, LogitProcessor>;
logLevel?: LogLevel;
cachedPrefixes?: ChatCompletionMessageParam[][];
Contributor:

Let's add docs to MLCEngineConfig specifying the behavior of cachedPrefixes (e.g. it will prefill when loading the engine to create the prefixes' KV entries, and will only dispose of those entries when the engine is reloaded). Perhaps we can also mark this as experimental to signify potential future API/behavior changes.
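As a rough sketch of what that doc comment could cover (the wording here is only a suggestion, not the final text):

```ts
export interface MLCEngineConfig {
  // ...existing fields...
  /**
   * (Experimental) Prefixes, e.g. repeated system prompts, whose KV entries should be
   * precomputed. The engine prefills each prefix once when the model is loaded to
   * create its KV cache entries, and only disposes of those entries when the engine
   * is reloaded. API and behavior may change in future releases.
   */
  cachedPrefixes?: ChatCompletionMessageParam[][];
}
```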

Contributor:

Could you also add an examples/cached_prefixes example, where we can demonstrate the prefill-time difference between using cachedPrefixes and not using it? We should also test whether the behavior is as expected in multi-turn conversations.
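A minimal sketch of what such an example might time, assuming the usual CreateMLCEngine entry point; the model ID and system prompt below are placeholders, not part of this PR:

```ts
import { CreateMLCEngine, ChatCompletionMessageParam } from "@mlc-ai/web-llm";

// Placeholder system prompt; in the real example it would be long enough for
// the prefill-time difference to be visible.
const cachedSystemPrompt: ChatCompletionMessageParam[] = [
  { role: "system", content: "You are a helpful assistant. <long instructions...>" },
];

async function timeFirstPrefill(useCachedPrefixes: boolean): Promise<void> {
  const engine = await CreateMLCEngine(
    "Llama-3.1-8B-Instruct-q4f32_1-MLC", // placeholder model ID
    useCachedPrefixes ? { cachedPrefixes: [cachedSystemPrompt] } : {},
  );
  const start = performance.now();
  await engine.chat.completions.create({
    messages: [...cachedSystemPrompt, { role: "user", content: "Hello!" }],
    max_tokens: 1, // keep decode short so the measurement is dominated by prefill
  });
  console.log(
    `cachedPrefixes=${useCachedPrefixes}: first request took ` +
      `${(performance.now() - start).toFixed(1)} ms`,
  );
}

await timeFirstPrefill(false);
await timeFirstPrefill(true);
```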

}

/**
12 changes: 12 additions & 0 deletions src/engine.ts
@@ -131,6 +131,7 @@ export class MLCEngine implements MLCEngineInterface {
private logitProcessorRegistry?: Map<string, LogitProcessor>;
private initProgressCallback?: InitProgressCallback;
private appConfig: AppConfig;
private cachedPrefixes: ChatCompletionMessageParam[][];

// Signals and flags
private interruptSignal = false;
@@ -149,6 +150,7 @@
this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel);
this.setInitProgressCallback(engineConfig?.initProgressCallback);
this.setLogitProcessorRegistry(engineConfig?.logitProcessorRegistry);
this.cachedPrefixes = engineConfig?.cachedPrefixes || [];

this.chat = new API.Chat(this);
this.completions = new API.Completions(this);
@@ -392,6 +394,16 @@
this.loadedModelIdToPipeline.set(modelId, newPipeline);
this.loadedModelIdToLock.set(modelId, new CustomLock());

// Call prefillConvSequence() if cachedPrefixes is specified
if (
newPipeline instanceof LLMChatPipeline &&
this.cachedPrefixes.length > 0
) {
for (let i = 0; i < this.cachedPrefixes.length; i++) {
await newPipeline.prefillConvSequence(this.cachedPrefixes[i]);
}
}

// Clean up
const tend = performance.now();
if (this.initProgressCallback !== undefined) {
154 changes: 138 additions & 16 deletions src/llm_chat.ts
@@ -34,6 +34,7 @@ import {
PrefillChunkSizeSmallerThanImageError,
CannotFindImageEmbedError,
} from "./error";
import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";

type ImageURL = ChatCompletionContentPartImage.ImageURL;

@@ -128,6 +129,8 @@ export class LLMChatPipeline {
private curRoundGrammarInitTotalTime = 0;
// Total time of getting next bitmask and accepting token in seconds
private curRoundGrammarPerTokenTotalTime = 0;
private seqIdToPrefix: Map<number, number[]>;
private nextSequenceId: number;

constructor(
tvm: tvmjs.Instance,
@@ -173,6 +176,8 @@
log.info("token_postproc_method: ", this.token_postproc_method);
log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);

this.seqIdToPrefix = new Map<number, number[]>();
this.nextSequenceId = 0;
this.device = this.tvm.webgpu();

// 1. Create VM and get the core functions
@@ -344,7 +349,12 @@
* Reset KV Cache
*/
resetKVCache() {
this.fclearKVCaches(this.kvCache);
// Check whether to keep prefixes in the KV cache
if (this.seqIdToPrefix.size === 0) {
this.fclearKVCaches(this.kvCache);
} else {
this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
Contributor:
Now that we have multiple sequence IDs, let's introduce a constant, say CHAT_SEQUENCE_ID = 0 (or maybe a better name), instead of using the magic number 0, which may be hard to keep track of.
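For illustration, one possible shape for that constant (the name and placement are just a suggestion; the snippet assumes it is used inside resetKVCache()):

```ts
// Sequence ID reserved for the main chat conversation; cached prefixes use IDs >= 1.
const CHAT_SEQUENCE_ID = 0;

// ...inside resetKVCache() and related call sites:
this.fKVCacheRemoveSequence!(
  this.kvCache,
  new tvmjs.Scalar(CHAT_SEQUENCE_ID, "int64"),
);
this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(CHAT_SEQUENCE_ID, "int64"));
```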

}
this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
if (this.slidingWindowSize != -1) {
this.fKVCacheEnableSlidingWindowForSeq(
@@ -483,6 +493,15 @@
await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
}

matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
for (let i = 0; i < prefixTokens.length; i++) {
if (inputTokens[i] !== prefixTokens[i]) {
return i;
}
}
return prefixTokens.length;
}

/**
* Generate the first token given input prompt
*/
@@ -491,11 +510,17 @@
msgRole: Role, // either user or tool
inp_role_str?: string,
genConfig?: GenerationConfig,
seqID = 0,
): Promise<void> {
if (msgRole !== Role.user && msgRole !== Role.tool) {
throw new MessageOrderError(
"The last message should be from `user` or `tool`.",
);
if (seqID === 0) {
if (msgRole !== Role.user && msgRole !== Role.tool) {
throw new MessageOrderError(
"The last message should be from `user` or `tool`.",
);
}
} else {
// Set the input as system prompt during prefix prefilling
this.conversation.override_system_message = inp;
}
if (this.resetStatsPerPrefill) {
this.resetRuntimeStats();
@@ -583,11 +608,13 @@
}

// 0. Get inputData from conversation
if (conversation.isTextCompletion) {
conversation.prompt = inp;
} else {
conversation.appendMessage(msgRole, inp, inp_role_str);
conversation.appendReplyHeader(Role.assistant);
if (seqID === 0) {
if (conversation.isTextCompletion) {
conversation.prompt = inp;
} else {
conversation.appendMessage(msgRole, inp, inp_role_str);
conversation.appendReplyHeader(Role.assistant);
}
}
const retGetInputData = this.getInputData();
const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
Expand All @@ -610,11 +637,68 @@ export class LLMChatPipeline {
throw new CannotFindImageEmbedError();
}

let maxMatchedLen = -1;
let matchedSeqId = -1;

// Prefix matching and forking
const inputTokens = inputData.flat() as number[];
for (const [id, prefixTokens] of this.seqIdToPrefix) {
const matchedLen = this.matchPrefix(inputTokens, prefixTokens);
if (matchedLen > maxMatchedLen) {
maxMatchedLen = matchedLen;
matchedSeqId = id;
}
}

// If a match is found, fork the sequence
if (matchedSeqId !== -1 && maxMatchedLen > 0) {
console.log(
Contributor:
Use log.info() instead of console.log()

"Forking sequence",
matchedSeqId,
"at position",
maxMatchedLen,
);
if (seqID === 0) {
this.fKVCacheRemoveSequence!(
this.kvCache,
new tvmjs.Scalar(seqID, "int64"),
);
}
this.tvm.beginScope();
this.tvm.getGlobalFunc("vm.builtin.kv_state_fork_sequence")(
this.kvCache,
new tvmjs.Scalar(matchedSeqId, "int64"), // fork_parent_id
new tvmjs.Scalar(seqID, "int64"), // fork_child_id
new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
);
this.tvm.endScope();
} else if (seqID !== 0) {
// If no match is found, add the new sequence to the KV cache
console.log("Adding prefix to KV cache: ", seqID);
Contributor:
Use log.info() instead of console.log()

this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
}

// Add the new sequence to the seqIdToPrefix map (if it is a prefix)
if (seqID !== 0) {
this.seqIdToPrefix.set(seqID, inputTokens);
}

// 1. Chunk inputData to embed and forward in one shot for each, minimize intermediate data
const retGetChunks = getChunkedPrefillInputData(
inputData,
this.prefillChunkSize,
);
let retGetChunks;
if (maxMatchedLen === -1) {
retGetChunks = getChunkedPrefillInputData(
inputData,
this.prefillChunkSize,
);
} else {
// If a matched prefix exists, only forward the remaining tokens
retGetChunks = getChunkedPrefillInputData(
inputData.map((arr) =>
Array.isArray(arr) ? arr.slice(maxMatchedLen) : arr,
),
this.prefillChunkSize,
);
}
const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
const chunkLens: Array<number> = retGetChunks[1];

@@ -626,7 +710,7 @@
const chunkLen = chunkLens[i];
const prevFilledLen = this.filledKVCacheLength;
logits = this.tvm.detachFromCurrentScope(
await this.embedAndForward(chunk, chunkLen),
await this.embedAndForward(chunk, chunkLen, seqID),
);
if (this.filledKVCacheLength !== prevFilledLen + chunkLen) {
throw new Error(
@@ -651,6 +735,41 @@
this.processNextToken(nextToken, genConfig);
}

async prefillConvSequence(
messages: ChatCompletionMessageParam[],
inp_role_str?: string,
genConfig?: GenerationConfig,
): Promise<void> {
for (const message of messages) {
this.nextSequenceId = this.nextSequenceId + 1;
const newSeqId = this.nextSequenceId;
// Call the regular prefillStep with the new seqID
if (typeof message.content === "string") {
// Support long system prompt
if (message.role === "system") {
await this.prefillStep(
message.content,
Role.tool,
inp_role_str,
genConfig,
newSeqId,
);
} else {
throw Error(
"Invalid role in prefix message: " +
message.role +
", expected 'system'.",
);
}
} else {
throw Error(
"Invalid content in prefix message, does not support image input.",
);
}
}
this.conversation.reset();
}

async decodeStep(genConfig?: GenerationConfig): Promise<void> {
if (this.stopTriggered) {
throw Error("Cannot run decode when stopped");
@@ -869,13 +988,15 @@
*
* @param inputData data to embed and forward
* @param inputDataLen length of this inputData, should smaller than prefill chunk size.
* @param seqID sequence ID of the input data in KV cache for prefix caching
* @returns The logits returned by this forward as tvmjs.NDArray on GPU.
*
* @note Precondition: inputData's data length is smaller than prefill chunk size
*/
private async embedAndForward(
inputData: Array<Array<number> | ImageURL>,
inputDataLen: number,
seqID = 0,
): Promise<tvmjs.NDArray> {
if (inputDataLen > this.prefillChunkSize) {
throw new Error(
@@ -913,7 +1034,8 @@

// 3. Forward the concatenated embeddings
const inputLenShape = this.tvm.makeShapeTuple([inputDataLen]);
const seqIdsTuple = this.tvm.makeShapeTuple([0]);
// set seqIdsTuple to be childID
const seqIdsTuple = this.tvm.makeShapeTuple([seqID]);
this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
let retValue;
if (inputDataLen > 1) {