@@ -110,42 +110,40 @@ def from_model_runner_inputs(
110110
111111 if idx_mapping_np is None :
112112 expanded_idx_mapping = req_indices_at_logits .to (device = device , dtype = torch .int32 )
113- idx_mapping_np = expanded_idx_mapping .detach ().cpu ().numpy ().astype (np .int32 )
113+ idx_mapping = expanded_idx_mapping .detach ().cpu ().numpy ().astype (np .int32 )
114114 else :
115- idx_mapping_np = np .asarray (idx_mapping_np , dtype = np .int32 )
116- if int (idx_mapping_np .shape [0 ]) != num_logits :
115+ idx_mapping = np .asarray (idx_mapping_np , dtype = np .int32 )
116+ if int (idx_mapping .shape [0 ]) != num_logits :
117117 raise ValueError ("idx_mapping_np must have one entry per logits row" )
118118 if req_indices_at_logits .device .type == "cpu" :
119119 req_indices_np = req_indices_at_logits .detach ().numpy ().astype (np .int32 , copy = False )
120- if not np .array_equal (idx_mapping_np , req_indices_np ):
120+ if not np .array_equal (idx_mapping , req_indices_np ):
121121 raise ValueError ("idx_mapping_np must match req_indices_at_logits" )
122- expanded_idx_mapping = torch .from_numpy (idx_mapping_np ).to (device = device , dtype = torch .int32 )
123- if idx_mapping_np .size and (idx_mapping_np .min () < 0 or idx_mapping_np .max () >= num_reqs ):
122+ expanded_idx_mapping = torch .from_numpy (idx_mapping ).to (device = device , dtype = torch .int32 )
123+ if idx_mapping .size and (idx_mapping .min () < 0 or idx_mapping .max () >= num_reqs ):
124124 raise ValueError ("req_indices_at_logits contains an out-of-range request index" )
125125
126126 if expanded_local_pos is None :
127- local_pos_np = np .empty (num_logits , dtype = np .int64 )
128- counters = np .zeros (num_reqs , dtype = np .int64 )
129- for row , req_idx in enumerate (idx_mapping_np ):
127+ local_pos_np : np . ndarray = np .empty (num_logits , dtype = np .int64 )
128+ counters : np . ndarray = np .zeros (num_reqs , dtype = np .int64 )
129+ for row , req_idx in enumerate (idx_mapping ):
130130 local_pos_np [row ] = counters [req_idx ]
131131 counters [req_idx ] += 1
132132 expanded_local_pos = torch .from_numpy (local_pos_np ).to (device = device )
133133 else :
134134 expanded_local_pos = expanded_local_pos .to (device = device , dtype = torch .int64 )
135135
136- expanded_logits = num_logits != num_reqs or not np .array_equal (
137- idx_mapping_np , np .arange (num_reqs , dtype = np .int32 )
138- )
139- if expanded_logits and not V1SamplingContext ._is_grouped_by_request (idx_mapping_np ):
136+ expanded_logits = num_logits != num_reqs or not np .array_equal (idx_mapping , np .arange (num_reqs , dtype = np .int32 ))
137+ if expanded_logits and not V1SamplingContext ._is_grouped_by_request (idx_mapping ):
140138 raise ValueError ("expanded logits rows must be grouped by request" )
141139 if cu_num_logits_np is None and expanded_logits :
142140 cu_num_logits_np = np .concatenate (
143- (np .array ([0 ], dtype = np .int32 ), np .cumsum (np .bincount (idx_mapping_np , minlength = num_reqs )))
141+ (np .array ([0 ], dtype = np .int32 ), np .cumsum (np .bincount (idx_mapping , minlength = num_reqs )))
144142 ).astype (np .int32 )
145143
146144 return V1SamplingContext (
147145 expanded_idx_mapping = expanded_idx_mapping ,
148- idx_mapping_np = idx_mapping_np ,
146+ idx_mapping_np = idx_mapping ,
149147 pos = positions_at_logits .to (device = device ),
150148 input_ids = input_ids_at_logits .to (device = device ),
151149 expanded_local_pos = expanded_local_pos ,
0 commit comments