Skip to content

Commit 7288fd1

Browse files
committed
fix(sample): satisfy mypy for sampling context
Signed-off-by: freyfwt <freytian1996@163.com>
1 parent 2ee8ea3 commit 7288fd1

2 files changed

Lines changed: 14 additions & 15 deletions

File tree

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2298,6 +2298,7 @@ def _maybe_build_v1_sampling_context(
22982298
return None
22992299

23002300
num_reqs = self.input_batch.num_reqs
2301+
req_indices_np: np.ndarray
23012302
if spec_decode_metadata is None:
23022303
req_indices_np = np.arange(num_reqs, dtype=np.int32)
23032304
else:

vllm_ascend/worker/v1/sample/sampling_context.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,42 +110,40 @@ def from_model_runner_inputs(
110110

111111
if idx_mapping_np is None:
112112
expanded_idx_mapping = req_indices_at_logits.to(device=device, dtype=torch.int32)
113-
idx_mapping_np = expanded_idx_mapping.detach().cpu().numpy().astype(np.int32)
113+
idx_mapping = expanded_idx_mapping.detach().cpu().numpy().astype(np.int32)
114114
else:
115-
idx_mapping_np = np.asarray(idx_mapping_np, dtype=np.int32)
116-
if int(idx_mapping_np.shape[0]) != num_logits:
115+
idx_mapping = np.asarray(idx_mapping_np, dtype=np.int32)
116+
if int(idx_mapping.shape[0]) != num_logits:
117117
raise ValueError("idx_mapping_np must have one entry per logits row")
118118
if req_indices_at_logits.device.type == "cpu":
119119
req_indices_np = req_indices_at_logits.detach().numpy().astype(np.int32, copy=False)
120-
if not np.array_equal(idx_mapping_np, req_indices_np):
120+
if not np.array_equal(idx_mapping, req_indices_np):
121121
raise ValueError("idx_mapping_np must match req_indices_at_logits")
122-
expanded_idx_mapping = torch.from_numpy(idx_mapping_np).to(device=device, dtype=torch.int32)
123-
if idx_mapping_np.size and (idx_mapping_np.min() < 0 or idx_mapping_np.max() >= num_reqs):
122+
expanded_idx_mapping = torch.from_numpy(idx_mapping).to(device=device, dtype=torch.int32)
123+
if idx_mapping.size and (idx_mapping.min() < 0 or idx_mapping.max() >= num_reqs):
124124
raise ValueError("req_indices_at_logits contains an out-of-range request index")
125125

126126
if expanded_local_pos is None:
127-
local_pos_np = np.empty(num_logits, dtype=np.int64)
128-
counters = np.zeros(num_reqs, dtype=np.int64)
129-
for row, req_idx in enumerate(idx_mapping_np):
127+
local_pos_np: np.ndarray = np.empty(num_logits, dtype=np.int64)
128+
counters: np.ndarray = np.zeros(num_reqs, dtype=np.int64)
129+
for row, req_idx in enumerate(idx_mapping):
130130
local_pos_np[row] = counters[req_idx]
131131
counters[req_idx] += 1
132132
expanded_local_pos = torch.from_numpy(local_pos_np).to(device=device)
133133
else:
134134
expanded_local_pos = expanded_local_pos.to(device=device, dtype=torch.int64)
135135

136-
expanded_logits = num_logits != num_reqs or not np.array_equal(
137-
idx_mapping_np, np.arange(num_reqs, dtype=np.int32)
138-
)
139-
if expanded_logits and not V1SamplingContext._is_grouped_by_request(idx_mapping_np):
136+
expanded_logits = num_logits != num_reqs or not np.array_equal(idx_mapping, np.arange(num_reqs, dtype=np.int32))
137+
if expanded_logits and not V1SamplingContext._is_grouped_by_request(idx_mapping):
140138
raise ValueError("expanded logits rows must be grouped by request")
141139
if cu_num_logits_np is None and expanded_logits:
142140
cu_num_logits_np = np.concatenate(
143-
(np.array([0], dtype=np.int32), np.cumsum(np.bincount(idx_mapping_np, minlength=num_reqs)))
141+
(np.array([0], dtype=np.int32), np.cumsum(np.bincount(idx_mapping, minlength=num_reqs)))
144142
).astype(np.int32)
145143

146144
return V1SamplingContext(
147145
expanded_idx_mapping=expanded_idx_mapping,
148-
idx_mapping_np=idx_mapping_np,
146+
idx_mapping_np=idx_mapping,
149147
pos=positions_at_logits.to(device=device),
150148
input_ids=input_ids_at_logits.to(device=device),
151149
expanded_local_pos=expanded_local_pos,

0 commit comments

Comments
 (0)