Skip to content

Commit d7e8456

Browse files
authored
fix stopwords kv cache (#3494)
1 parent 7bbec66 commit d7e8456

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

lmdeploy/pytorch/engine/engine.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ class InferOutput:
5050
def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
5151
"""tensorlize block_offsets."""
5252
from torch.nn.utils.rnn import pad_sequence
53-
block_offsets = [torch.from_numpy(off).to(dtype) for off in block_offsets]
54-
block_offsets = pad_sequence(block_offsets, batch_first=True)
53+
block_offsets = [torch.from_numpy(off) for off in block_offsets]
54+
block_offsets = pad_sequence(block_offsets, batch_first=True).to(dtype)
5555
return block_offsets
5656

5757

@@ -563,6 +563,7 @@ def __update_max_new_tokens(msg):
563563
req.data['token_ids'],
564564
multimodals=req.data.get('input_multimodals'),
565565
embeddings=req.data.get('input_embeddings'),
566+
append_tokens=True,
566567
)
567568
msg.num_new_tokens = 0
568569
msg.sampling_param = sampling_param
@@ -721,8 +722,6 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped
721722
msg.update_token_ids(update_token, model_meta=model_meta)
722723
msg.num_new_tokens += 1
723724
if stop:
724-
update_token = _EMPTY_TOKEN
725-
msg.update_token_ids(update_token, model_meta=model_meta)
726725
msg.status = MessageStatus.STOPPED
727726

728727
def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor,

lmdeploy/pytorch/messages.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -592,11 +592,15 @@ def update_token_ids(self,
592592
token_ids: Tensor,
593593
multimodals: MultiModalInputs = None,
594594
embeddings: List[InputEmbeddings] = None,
595-
model_meta: Dict[str, Any] = None):
595+
model_meta: Dict[str, Any] = None,
596+
append_tokens: bool = False):
596597
"""Update token ids, old token ids will be added to history."""
597598
old_num_history_ids = self._num_history_ids
598599

599-
self._num_history_ids += self._num_token_ids
600+
# update history
601+
if not append_tokens:
602+
self._num_history_ids += self._num_token_ids
603+
600604
# update history image nums
601605
self._num_history_images += self._num_images
602606
self._num_images = 0
@@ -626,7 +630,10 @@ def update_token_ids(self,
626630
token_ids = np.array(token_ids)
627631
if token_ids.ndim == 0:
628632
token_ids = token_ids[None]
629-
self._num_token_ids = len(token_ids)
633+
if append_tokens:
634+
self._num_token_ids += len(token_ids)
635+
else:
636+
self._num_token_ids = len(token_ids)
630637
self.history_cache.append(token_ids)
631638
self.random_offsets += 1
632639
self.arrive_time = time.time()

0 commit comments

Comments
 (0)