@@ -50,8 +50,8 @@ class InferOutput:
50
50
def _tensorlize_block_offsets (block_offsets , dtype = torch .int32 ):
51
51
"""tensorlize block_offsets."""
52
52
from torch .nn .utils .rnn import pad_sequence
53
- block_offsets = [torch .from_numpy (off ). to ( dtype ) for off in block_offsets ]
54
- block_offsets = pad_sequence (block_offsets , batch_first = True )
53
+ block_offsets = [torch .from_numpy (off ) for off in block_offsets ]
54
+ block_offsets = pad_sequence (block_offsets , batch_first = True ). to ( dtype )
55
55
return block_offsets
56
56
57
57
@@ -563,6 +563,7 @@ def __update_max_new_tokens(msg):
563
563
req .data ['token_ids' ],
564
564
multimodals = req .data .get ('input_multimodals' ),
565
565
embeddings = req .data .get ('input_embeddings' ),
566
+ append_tokens = True ,
566
567
)
567
568
msg .num_new_tokens = 0
568
569
msg .sampling_param = sampling_param
@@ -721,8 +722,6 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped
721
722
msg .update_token_ids (update_token , model_meta = model_meta )
722
723
msg .num_new_tokens += 1
723
724
if stop :
724
- update_token = _EMPTY_TOKEN
725
- msg .update_token_ids (update_token , model_meta = model_meta )
726
725
msg .status = MessageStatus .STOPPED
727
726
728
727
def update_running_migration (self , running : SeqList , next_token_ids : np .ndarray , stopped : torch .Tensor ,
0 commit comments