@@ -35,8 +35,6 @@ class EmbeddingInfo:
3535 embedding : torch .Tensor = None
3636 prompt_positions : torch .Tensor = None
3737 mrope_position_delta : torch .Tensor = None
38- stale : bool = False
39-
4038
4139class ModelRunner :
4240 def __init__ (
@@ -257,7 +255,6 @@ def mm_prepare_inputs(self, seqs: List[Sequence]):
257255 position = None
258256 if seq .computed_prompt :
259257 embedding_info = self .embedding_cache [seq .seq_id ]
260- assert embedding_info .stale
261258 embedding = self .model .embed_input_ids (
262259 torch .tensor (seq .to_compute_tokens )
263260 )
@@ -269,10 +266,7 @@ def mm_prepare_inputs(self, seqs: List[Sequence]):
269266 position = torch .tensor (position , device = "cpu" )
270267 else :
271268 embedding_info = None
272- if (
273- seq .seq_id not in self .embedding_cache
274- or self .embedding_cache [seq .seq_id ].stale
275- ):
269+ if seq .seq_id not in self .embedding_cache :
276270 mm_embeddings = None
277271 image_grid_thw : torch .Tensor = None
278272 video_grid_thw : torch .Tensor = None
@@ -338,7 +332,6 @@ def mm_prepare_inputs(self, seqs: List[Sequence]):
338332 ]
339333 if seq .seq_len == seq .prompt_len :
340334 # invalidate embedding_cache
341- embedding_info .stale = True
342335 embedding_info .embedding = None
343336 batch_embeddings .append (embedding )
344337 batch_positions .append (position )
@@ -482,3 +475,5 @@ def step_once(self):
482475
483476 def free (self , seq : Sequence ):
484477 self .memory_manager .free (seq )
478+ if self .use_mm :
479+ self .embedding_cache .pop (seq .seq_id )
0 commit comments