Skip to content

Commit d715895

Browse files
rebel-eunji and rebel-jonghewk
authored and committed
fix(core): fix handling logprobs when use rbln_sampler (#184)
* fix indexing logprobs * fix comment
1 parent 94c9bec commit d715895

2 files changed

Lines changed: 18 additions & 1 deletion

File tree

vllm_rbln/v1/sample/rbln_sampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ def get_bucket_sizes(max_num_seqs: int) -> list[int]:
241241
[1, 2, 4] + list(range(8, 256, 8)) + list(
242242
range(256, max_num_seqs + 1, 16))
243243
"""
244+
# FIXME(eunji.lee)
245+
# Not used. To be removed.
244246
bucket_sizes = [i for i in [1, 2, 4] if i <= max_num_seqs]
245247
if max_num_seqs >= 8:
246248
# Step size 8 for small batch sizes, up to 256(not included)

vllm_rbln/v1/worker/optimum_model_runner.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,10 @@ def execute_model(
321321
sampler_output.sampled_token_ids[:num_reqs]
322322
if sampler_output.logprobs_tensors is not None:
323323
sampler_output.logprobs_tensors = \
324-
sampler_output.logprobs_tensors[:num_reqs]
324+
self.post_process_logprobs_tensors(
325+
sampler_output.logprobs_tensors,
326+
num_reqs
327+
)
325328

326329
with record_function_or_nullcontext("Bookkeep"):
327330
(
@@ -1223,3 +1226,15 @@ def get_bucket_sizes(max_num_seqs: int) -> list[int]:
12231226
# Step size 16 for larger batch sizes
12241227
bucket_sizes += list(range(256, max_num_seqs + 1, 16))
12251228
return bucket_sizes
1229+
1230+
def post_process_logprobs_tensors(self, logprobs_tensors: LogprobsTensors,
                                  num_reqs: int) -> LogprobsTensors:
    """Return a copy of ``logprobs_tensors`` truncated to ``num_reqs`` rows.

    Each field listed in ``logprobs_tensors._fields`` is sliced with
    ``[:num_reqs]`` and the results are packed into a new
    ``LogprobsTensors``. The input object itself is not reassigned.

    Args:
        logprobs_tensors: Container whose fields are sliced individually.
            Presumably a NamedTuple of tensors (it exposes ``_fields``);
            verify against the ``LogprobsTensors`` definition.
        num_reqs: Number of leading entries to keep from every field.

    Returns:
        A new ``LogprobsTensors`` whose fields are the sliced values.
    """
    # NOTE(eunji.lee):
    # This implementation is not efficient but kept for debugging purposes.
    # TODO: Modify this code in the next version when the shape of
    # logprobs_tensors changes.
    #
    # Slice field by field: indexing the container directly would select
    # among its fields rather than trimming each field's leading dimension
    # (assuming a tuple-like container -- TODO confirm).
    sliced_fields = {
        field_name: getattr(logprobs_tensors, field_name)[:num_reqs]
        for field_name in logprobs_tensors._fields
    }
    return LogprobsTensors(**sliced_fields)

0 commit comments

Comments
 (0)