[Cache] Add cachedPrefixes for caching repeated system prompts #664
base: main
Changes from 2 commits
@@ -9,6 +9,7 @@ import {
   NonNegativeError,
   RangeError,
 } from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
 
 /**
  * Conversation template config
@@ -114,6 +115,7 @@ export interface MLCEngineConfig {
   initProgressCallback?: InitProgressCallback;
   logitProcessorRegistry?: Map<string, LogitProcessor>;
   logLevel?: LogLevel;
+  cachedPrefixes?: ChatCompletionMessageParam[][];
 }
 
 /**

Review comment on the `cachedPrefixes` line: Could you also add an …
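For orientation, a minimal sketch of how the new option might be passed when creating an engine. The option name and type come from this diff; the model id and prompt text are placeholders.

```ts
import { CreateMLCEngine } from "@mlc-ai/web-llm";

// Each inner array is one prefix (a sequence of messages) whose KV entries
// the engine would prefill once and reuse for later chats that start the same way.
const engine = await CreateMLCEngine("Llama-3.1-8B-Instruct-q4f32_1-MLC", {
  cachedPrefixes: [
    [{ role: "system", content: "You are a helpful assistant." }],
  ],
});
```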
@@ -34,6 +34,7 @@ import {
   PrefillChunkSizeSmallerThanImageError,
   CannotFindImageEmbedError,
 } from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
 
 type ImageURL = ChatCompletionContentPartImage.ImageURL;
 
@@ -128,6 +129,8 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  private seqIdToPrefix: Map<number, number[]>;
+  private nextSequenceId: number;
 
   constructor(
     tvm: tvmjs.Instance,
@@ -173,6 +176,8 @@ export class LLMChatPipeline {
     log.info("token_postproc_method: ", this.token_postproc_method);
     log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);
 
+    this.seqIdToPrefix = new Map<number, number[]>();
+    this.nextSequenceId = 0;
     this.device = this.tvm.webgpu();
 
     // 1. Create VM and get the core functions
@@ -344,7 +349,12 @@ export class LLMChatPipeline {
    * Reset KV Cache
    */
   resetKVCache() {
-    this.fclearKVCaches(this.kvCache);
+    // Check whether to keep prefixes in the KV cache
+    if (this.seqIdToPrefix.size === 0) {
+      this.fclearKVCaches(this.kvCache);
+    } else {
+      this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+    }
     this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
     if (this.slidingWindowSize != -1) {
       this.fKVCacheEnableSlidingWindowForSeq(

Review comment on the `fKVCacheRemoveSequence` line: Now that we have multiple sequence IDs, let's make a constant, say …
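The reviewer's suggested constant name is cut off above; the sketch below shows the kind of change being requested. The name `DEFAULT_SEQ_ID` is a placeholder, not taken from the PR.

```ts
// Hypothetical named constant for the main conversation's sequence ID,
// replacing the bare 0 passed to the KV-cache helpers.
const DEFAULT_SEQ_ID = 0;

// e.g. inside resetKVCache():
//   this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(DEFAULT_SEQ_ID, "int64"));
//   this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(DEFAULT_SEQ_ID, "int64"));
```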
@@ -483,6 +493,15 @@ export class LLMChatPipeline {
     await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
   }
 
+  matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
+    for (let i = 0; i < prefixTokens.length; i++) {
+      if (inputTokens[i] !== prefixTokens[i]) {
+        return i;
+      }
+    }
+    return prefixTokens.length;
+  }
+
   /**
    * Generate the first token given input prompt
    */
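Not part of the diff, but as a quick illustration of the matching semantics: `matchPrefix` returns the length of the longest shared leading run of tokens, capped at the prefix length.

```ts
// matchPrefix(inputTokens, prefixTokens) examples:
// matchPrefix([1, 2, 3, 4], [1, 2, 9]) === 2  // diverges at index 2
// matchPrefix([1, 2, 3, 4], [1, 2])    === 2  // whole prefix matches
// matchPrefix([9, 2, 3],    [1, 2])    === 0  // nothing shared
```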
@@ -491,11 +510,17 @@ export class LLMChatPipeline {
     msgRole: Role, // either user or tool
     inp_role_str?: string,
     genConfig?: GenerationConfig,
+    seqID = 0,
   ): Promise<void> {
-    if (msgRole !== Role.user && msgRole !== Role.tool) {
-      throw new MessageOrderError(
-        "The last message should be from `user` or `tool`.",
-      );
+    if (seqID === 0) {
+      if (msgRole !== Role.user && msgRole !== Role.tool) {
+        throw new MessageOrderError(
+          "The last message should be from `user` or `tool`.",
+        );
+      }
+    } else {
+      // Set the input as system prompt during prefix prefilling
+      this.conversation.override_system_message = inp;
     }
     if (this.resetStatsPerPrefill) {
       this.resetRuntimeStats();
@@ -583,11 +608,13 @@
     }
 
     // 0. Get inputData from conversation
-    if (conversation.isTextCompletion) {
-      conversation.prompt = inp;
-    } else {
-      conversation.appendMessage(msgRole, inp, inp_role_str);
-      conversation.appendReplyHeader(Role.assistant);
+    if (seqID === 0) {
+      if (conversation.isTextCompletion) {
+        conversation.prompt = inp;
+      } else {
+        conversation.appendMessage(msgRole, inp, inp_role_str);
+        conversation.appendReplyHeader(Role.assistant);
+      }
     }
     const retGetInputData = this.getInputData();
     const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
@@ -610,11 +637,68 @@ export class LLMChatPipeline {
       throw new CannotFindImageEmbedError();
     }
 
+    let maxMatchedLen = -1;
+    let matchedSeqId = -1;
+
+    // Prefix matching and forking
+    const inputTokens = inputData.flat() as number[];
+    for (const [id, prefixTokens] of this.seqIdToPrefix) {
+      const matchedLen = this.matchPrefix(inputTokens, prefixTokens);
+      if (matchedLen > maxMatchedLen) {
+        maxMatchedLen = matchedLen;
+        matchedSeqId = id;
+      }
+    }
+
+    // If a match is found, fork the sequence
+    if (matchedSeqId !== -1 && maxMatchedLen > 0) {
+      console.log(
+        "Forking sequence",
+        matchedSeqId,
+        "at position",
+        maxMatchedLen,
+      );
+      if (seqID === 0) {
+        this.fKVCacheRemoveSequence!(
+          this.kvCache,
+          new tvmjs.Scalar(seqID, "int64"),
+        );
+      }
+      this.tvm.beginScope();
+      this.tvm.getGlobalFunc("vm.builtin.kv_state_fork_sequence")(
+        this.kvCache,
+        new tvmjs.Scalar(matchedSeqId, "int64"), // fork_parent_id
+        new tvmjs.Scalar(seqID, "int64"), // fork_child_id
+        new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
+      );
+      this.tvm.endScope();
+    } else if (seqID !== 0) {
+      // If no match is found, add the new sequence to the KV cache
+      console.log("Adding prefix to KV cache: ", seqID);
+      this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
+    }
+
+    // Add the new sequence to the seqIdToPrefix map (if it is a prefix)
+    if (seqID !== 0) {
+      this.seqIdToPrefix.set(seqID, inputTokens);
+    }
+
     // 1. Chunk inputData to embed and forward in one shot for each, minimize intermediate data
-    const retGetChunks = getChunkedPrefillInputData(
-      inputData,
-      this.prefillChunkSize,
-    );
+    let retGetChunks;
+    if (maxMatchedLen === -1) {
+      retGetChunks = getChunkedPrefillInputData(
+        inputData,
+        this.prefillChunkSize,
+      );
+    } else {
+      // If a matched prefix exists, only forward the remaining tokens
+      retGetChunks = getChunkedPrefillInputData(
+        inputData.map((arr) =>
+          Array.isArray(arr) ? arr.slice(maxMatchedLen) : arr,
+        ),
+        this.prefillChunkSize,
+      );
+    }
     const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
     const chunkLens: Array<number> = retGetChunks[1];

Review comment on the first `console.log`: Use …

Review comment on the second `console.log`: Use …
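A quick worked illustration of the branch above, with made-up numbers:

```ts
// Suppose a cached prefix of 100 tokens fully matches a 130-token input.
// Then maxMatchedLen === 100, so only the 30 remaining tokens are chunked and
// forwarded; the first 100 positions are reused from the forked sequence's KV entries.
```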
@@ -626,7 +710,7 @@ export class LLMChatPipeline {
       const chunkLen = chunkLens[i];
       const prevFilledLen = this.filledKVCacheLength;
       logits = this.tvm.detachFromCurrentScope(
-        await this.embedAndForward(chunk, chunkLen),
+        await this.embedAndForward(chunk, chunkLen, seqID),
       );
       if (this.filledKVCacheLength !== prevFilledLen + chunkLen) {
         throw new Error(
@@ -651,6 +735,41 @@ export class LLMChatPipeline {
     this.processNextToken(nextToken, genConfig);
   }
 
+  async prefillConvSequence(
+    messages: ChatCompletionMessageParam[],
+    inp_role_str?: string,
+    genConfig?: GenerationConfig,
+  ): Promise<void> {
+    for (const message of messages) {
+      this.nextSequenceId = this.nextSequenceId + 1;
+      const newSeqId = this.nextSequenceId;
+      // Call the regular prefillStep with the new seqID
+      if (typeof message.content === "string") {
+        // Support long system prompt
+        if (message.role === "system") {
+          await this.prefillStep(
+            message.content,
+            Role.tool,
+            inp_role_str,
+            genConfig,
+            newSeqId,
+          );
+        } else {
+          throw Error(
+            "Invalid role in prefix message: " +
+              message.role +
+              ", expected 'system'.",
+          );
+        }
+      } else {
+        throw Error(
+          "Invalid content in prefix message, does not support image input.",
+        );
+      }
+    }
+    this.conversation.reset();
+  }
+
   async decodeStep(genConfig?: GenerationConfig): Promise<void> {
     if (this.stopTriggered) {
       throw Error("Cannot run decode when stopped");
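A sketch of how the engine side might drive `prefillConvSequence` for each configured prefix when the model loads; this glue code is hypothetical and not part of the diff (only `prefillConvSequence` itself is).

```ts
// Hypothetical engine-side glue: after the pipeline is created, prefill every
// configured prefix once so later chats that reuse the same system prompt can
// fork its KV entries instead of recomputing them. `pipeline` is the freshly
// constructed LLMChatPipeline and `cachedPrefixes` comes from MLCEngineConfig.
if (cachedPrefixes !== undefined && cachedPrefixes.length > 0) {
  for (const prefixMessages of cachedPrefixes) {
    await pipeline.prefillConvSequence(prefixMessages);
  }
}
```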
@@ -869,13 +988,15 @@ export class LLMChatPipeline {
    *
    * @param inputData data to embed and forward
    * @param inputDataLen length of this inputData, should smaller than prefill chunk size.
+   * @param seqID sequence ID of the input data in KV cache for prefix caching
    * @returns The logits returned by this forward as tvmjs.NDArray on GPU.
    *
    * @note Precondition: inputData's data length is smaller than prefill chunk size
    */
   private async embedAndForward(
     inputData: Array<Array<number> | ImageURL>,
     inputDataLen: number,
+    seqID = 0,
   ): Promise<tvmjs.NDArray> {
     if (inputDataLen > this.prefillChunkSize) {
       throw new Error(
@@ -913,7 +1034,8 @@ export class LLMChatPipeline {
 
     // 3. Forward the concatenated embeddings
     const inputLenShape = this.tvm.makeShapeTuple([inputDataLen]);
-    const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+    // set seqIdsTuple to be childID
+    const seqIdsTuple = this.tvm.makeShapeTuple([seqID]);
     this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
     let retValue;
     if (inputDataLen > 1) {
Review comment: Let's add docs to `MLCEngineConfig`, specifying the behavior of `cachedPrefixes` (e.g. it will prefill when loading the engine to create the prefixes' KV, and will only dispose of these KV entries when reloading the engine). Perhaps we can also mark this as `experimental` to signify potential future API/behavior change.
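A possible shape for the requested documentation, marking the option as experimental; the wording below is a suggestion based on the comment above, not text from the PR.

```ts
export interface MLCEngineConfig {
  // ...
  /**
   * Experimental: conversation prefixes (e.g. shared system prompts) to cache.
   * Each prefix is prefilled once when the engine loads, creating its KV entries;
   * those entries are only disposed of when the engine is reloaded.
   * The API and behavior may change in future releases.
   */
  cachedPrefixes?: ChatCompletionMessageParam[][];
}
```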