Skip to content

Commit 8e6fef4

Browse files
committed
Fix post schedule and simplify embedding cache
1 parent d0b74cd commit 8e6fef4

2 files changed

Lines changed: 13 additions & 12 deletions

File tree

gllm/model_runner.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ class EmbeddingInfo:
3535
embedding: torch.Tensor = None
3636
prompt_positions: torch.Tensor = None
3737
mrope_position_delta: torch.Tensor = None
38-
stale: bool = False
39-
4038

4139
class ModelRunner:
4240
def __init__(
@@ -257,7 +255,6 @@ def mm_prepare_inputs(self, seqs: List[Sequence]):
257255
position = None
258256
if seq.computed_prompt:
259257
embedding_info = self.embedding_cache[seq.seq_id]
260-
assert embedding_info.stale
261258
embedding = self.model.embed_input_ids(
262259
torch.tensor(seq.to_compute_tokens)
263260
)
@@ -269,10 +266,7 @@ def mm_prepare_inputs(self, seqs: List[Sequence]):
269266
position = torch.tensor(position, device="cpu")
270267
else:
271268
embedding_info = None
272-
if (
273-
seq.seq_id not in self.embedding_cache
274-
or self.embedding_cache[seq.seq_id].stale
275-
):
269+
if seq.seq_id not in self.embedding_cache:
276270
mm_embeddings = None
277271
image_grid_thw: torch.Tensor = None
278272
video_grid_thw: torch.Tensor = None
@@ -338,7 +332,6 @@ def mm_prepare_inputs(self, seqs: List[Sequence]):
338332
]
339333
if seq.seq_len == seq.prompt_len:
340334
# invalidate embedding_cache
341-
embedding_info.stale = True
342335
embedding_info.embedding = None
343336
batch_embeddings.append(embedding)
344337
batch_positions.append(position)
@@ -482,3 +475,5 @@ def step_once(self):
482475

483476
def free(self, seq: Sequence):
    """Release the resources held for ``seq``.

    Frees the sequence through ``self.memory_manager`` and, when
    ``self.use_mm`` is set, drops the entry keyed by ``seq.seq_id`` from
    ``self.embedding_cache``.

    Args:
        seq: The sequence being released.
    """
    self.memory_manager.free(seq)
    if self.use_mm:
        # Tolerate a missing entry: a sequence can be freed before its
        # embeddings were ever cached (presumably e.g. an early abort --
        # TODO confirm against callers), and a bare pop() would raise
        # KeyError in that case.
        self.embedding_cache.pop(seq.seq_id, None)

gllm/scheduler.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,16 @@ def post_schedule(self, schedule_seqs: List[Sequence]):
160160
for seq in schedule_seqs:
161161
if seq.has_schedule:
162162
post_schedule_seq = copy.copy(seq)
163-
post_schedule_seq.to_compute_tokens = seq[
164-
seq.computed_token_num : seq.seq_len
165-
]
166-
post_schedule_seq.token_ids = None
163+
# MM prefill may still need full prompt token_ids to
164+
# build (or rebuild) cached multimodal embeddings/positions.
165+
# Keep token_ids for unfinished MM prefills; drop otherwise to
166+
# reduce IPC payload.
167+
keep_full_token_ids = self.model_runner.use_mm and (not seq.computed_prompt)
168+
post_schedule_seq.token_ids = seq.token_ids if keep_full_token_ids else None
169+
if not keep_full_token_ids:
170+
post_schedule_seq.to_compute_tokens = seq[
171+
seq.computed_token_num : seq.seq_len
172+
]
167173
post_schedule_seqs.append(post_schedule_seq)
168174
else:
169175
post_schedule_seqs.append(seq)

0 commit comments

Comments
 (0)