@@ -228,6 +228,12 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
 
+        num_sequences = input_ids.shape[0]
+
+        # Pre-allocate everything
+        token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device)
+        beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device)
+
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
 
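For orientation, here is a minimal sketch of the pre-allocation pattern introduced above: fixed (num_sequences, num_beams, 1) buffers are filled with eos_idx / zeros once, overwritten in place for the live hypotheses of sequence i at each step, and reset afterwards. Shapes and values are illustrative; num_sequences, num_beams, eos_idx, and device stand in for the function's real arguments.

    import torch

    num_sequences, num_beams, eos_idx = 2, 4, 1  # illustrative values
    device = "cpu"

    # Allocate once, up front.
    token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx, dtype=torch.long, device=device)
    beam_idxs = torch.zeros((num_sequences, num_beams, 1), dtype=torch.long, device=device)

    # Per decoding step for sequence i: write only the live hypotheses ...
    i = 0
    live_tokens = torch.tensor([5, 7], dtype=torch.long, device=device)  # e.g. two surviving beams
    token_idxs[i, : len(live_tokens), 0] = live_tokens
    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)  # remaining rows stay eos_idx

    # ... and reset the slots so they can be reused on the next step.
    token_idxs[i, :, 0] = eos_idx
    beam_idxs[i, :, 0] = 0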
@@ -236,16 +242,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
-                prev_step_model_states = [
-                    create_emitting_model_state(
-                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
-                    )
-                ]
 
             encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
-            prev_model_state_sequences = [
-                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
-            ]
             out_probs, model_states = [], []
 
             start = 0
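The start/end bookkeeping around this point batches hypotheses in fixed-size chunks before each forward pass; a rough sketch of that loop shape, assuming a hypothetical per-step budget max_inference_batch_size (the diff itself only shows the clamp of end to curr_beam_size):

    curr_beam_size = 10           # number of live hypotheses this step
    max_inference_batch_size = 4  # hypothetical per-forward-pass budget

    start = 0
    while start < curr_beam_size:
        end = start + max_inference_batch_size
        if end > curr_beam_size:
            end = curr_beam_size
        # run the model on hypotheses[start:end] here
        start = end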
@@ -261,66 +259,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if end > curr_beam_size:
                     end = curr_beam_size
 
-                num_samples = end - start
-
                 if prev_step_token_idxs != [-1]:
-                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = (
-                        torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=device)
-                        .reshape(num_samples, 1)
-                    )
-
-                    state_and_tokens = torch.cat(
-                        [state_sequences, token_indices], dim=-1
-                    )  # [batch_size x (timestep + 1)]
-                    assert state_and_tokens.shape == (
-                        num_samples,
-                        timestep + 1,
-                    ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device)
+                    token_idxs[i, : len(token_indices), 0] = token_indices
+                    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
                 else:
-                    assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
-                        num_beams, -1
-                    )  # TODO: Make this more robust
-
-                # Cleanup -- combine this with the above
-                if self.is_encoder_decoder:
-                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
-                    # This is a view-only operation and doesn't copy
-                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else num_beams, -1, -1
-                    )
+                    if self.is_encoder_decoder:
+                        # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                        # This is a view-only operation and doesn't copy
+                        model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                            num_beams, -1, -1
+                        )
+                    curr_token_idxs = torch.zeros((num_beams, 1)).to(dtype=torch.long, device=device)
+
 
                 # Preprocess inputs for generation
                 model_inputs = self.model.prepare_inputs_for_generation(
-                    token_indices, **model_kwargs
+                    curr_token_idxs, **model_kwargs
                 )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
                     if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
-                        beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
-
-                        # We could store this in model_kwargs
-                        num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
-
-                        num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
-                        if num_finished_hyps_in_step > 0:
-                            beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
-
-                        beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
-
-                        reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
-
-                        if num_finished_hyps_in_step > 0:
-                            sliced_cache = ()
-                            for states in reordered_cached:
-                                sliced_state = ()
-                                for state in states:
-                                    sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
-                                sliced_cache = sliced_cache + (sliced_state,)
-                            reordered_cached = sliced_cache
+                        beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
+                        beam_idxs[i, : len(prev_step_hyp_idxs), 0] = beam_indices
+                        curr_beam_idxs = beam_idxs[i, :, 0]
 
+                        reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs)
                         model_inputs["past_key_values"] = reordered_cached
 
                 # Forward pass
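As background for the _reorder_cache call above: for typical HuggingFace decoder models the cached key/value tensors form a tuple of per-layer tuples, and reordering amounts to an index_select over the hypothesis (batch) dimension. A sketch under that assumption; reorder_cache below is a stand-in for illustration, not the model's own method:

    import torch

    def reorder_cache(past_key_values, beam_idx):
        # Pick, for every layer's cached tensors, the rows (dim 0) that belong
        # to the surviving hypotheses, in their new order.
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past_key_values
        )

    # Toy cache: 2 layers, each (key, value) of shape (num_hyps, heads, seq_len, head_dim).
    past = tuple((torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8)) for _ in range(2))
    beam_idx = torch.tensor([2, 0, 0, 1])  # hypothesis j continues from old row beam_idx[j]
    reordered = reorder_cache(past, beam_idx)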
@@ -334,18 +298,21 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if self.is_huggingface_model:
                     self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
+                # Reset
+                token_idxs[i, :, 0] = eos_idx
+                beam_idxs[i, :, 0] = 0
+
                 # Keep track of probabilities over vocab for this pairing
-                # TODO: fix how we track the number here?
-                for i in range(lm_scores.shape[0]):
+                for i in range(num_beams):
                     sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
                     # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
                             Seq2SeqModelState(
                                 timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
+                                sequence=[],
+                                lm_scores=0,
                             )
                         )
                     )
@@ -391,10 +358,6 @@ def is_not_neg_one(elem: int) -> bool:
             if not self.is_encoder_decoder:
                 final_tokens = input_ids[timestep].tolist() + final_tokens
 
-            # Makeshift padding so that we can stack the tensors
-            while len(final_tokens) < max_len:
-                final_tokens += [0]
-
             # Convert from list to tensors
             final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)
 
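The makeshift padding loop removed in this last hunk existed only so the per-hypothesis token lists could be stacked into one tensor; if that padding is still needed downstream, a compact alternative (a sketch, not necessarily what this PR does elsewhere) is torch.nn.utils.rnn.pad_sequence:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Hypothetical finished hypotheses of different lengths.
    finished = [
        torch.tensor([5, 8, 2], dtype=torch.long),
        torch.tensor([5, 9, 11, 2], dtype=torch.long),
        torch.tensor([5, 2], dtype=torch.long),
    ]

    # Right-pads with padding_value so the tensors can be stacked into one batch.
    stacked = pad_sequence(finished, batch_first=True, padding_value=0)
    print(stacked.shape)  # torch.Size([3, 4])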