@@ -175,7 +175,6 @@ def generate(self, src: 'torch.LongTensor') -> 'torch.LongTensor':
        return finished_predictions.unsqueeze(1)  # (B, 1, Lg)


-
class TranslationInferenceBeamSearchSpeculativeBatchedWithoutLeftPads:
    def __init__(self,
                 model,  # TranslationModel
@@ -224,29 +223,28 @@ def __init__(self,
    def __str__(self):
        return f"SpeculativeSampling decoding (n_best={self.n_best}, max_len={self.max_len}, max_num_of_drafts={self.max_drafts_num}, draft_len={self.draft_len})"

-    def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts, b_size, bool_idx, n_accepted):
+    def sample(self, curr_lines, curr_log_probs, pred_logits, chosen_drafts, b_size, draft_place_bool, n_accepted):
        """
        This function samples all possible sequences within a selected draft. Each draft can
        produce (self.max_num_positions_for_sampling - 1) * num_of_approved_tokens + self.max_num_positions_for_sampling
        new sequences at most.

        :param curr_lines: tensor (n_candidates, drafted_len),
-        :param curr_log_probs_history: tensor (n_candidates, max_len),
+        :param curr_log_probs: tensor (n_candidates, 1),
        :param pred_logits: tensor (n_candidates, draft_len + 1, vocab_size),
        :param chosen_drafts: tensor (n_candidates, draft_len),
        :param b_size: int,
-        :param bool_idx: tensor (n_candidates, max_len), it contains true where the draft supposed to be in curr_lines,
+        :param draft_place_bool: tensor (n_candidates, drafted_len), it contains True where the draft is supposed to go in curr_lines;
        in each line there are draft_len Trues
        :param n_accepted: tensor (n_candidates)
        :return:
-          ->  new_lines: tensor (num_lines, max_len),
-              new_log_probs_history: tensor (num_lines, max_len)
+          ->  new_lines: tensor (num_lines, drafted_len),
+              new_log_probs: tensor (num_lines, 1)
              num_of_new_seqs_for_each_in_batch: tensor (b_size)
              token_postn: tensor (num_lines), to calculate the number of accepted tokens in the next top n sequences
               later; self.acceptance_rate_pad_for_already_finished_seqs means that the given sequence already had the
               eos token and so didn't need subsequent tokens
        """
-        drafted_len = curr_lines.shape[1]
        n_candidates, draft_len_plus_one, vocab_size = pred_logits.size()

        draft_len = draft_len_plus_one - 1
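The bound quoted in the docstring is easier to see with concrete numbers. A toy check (the values below are made up for illustration and are not taken from this diff):

# With max_num_positions_for_sampling = 3 candidate tokens kept per position and
# num_of_approved_tokens = 2 accepted draft tokens, one draft presumably spawns at most
# two alternative continuations at each accepted position plus three candidates at the
# first rejected position:
m, k = 3, 2
assert (m - 1) * k + m == 7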
@@ -298,8 +296,8 @@ def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts,
        previous_roots = curr_lines[candts_inds]  # (num, drafted_len)
        already_finished_given_seqs = (previous_roots == self.eos_token_idx).sum(-1).bool()  # -> (num)

-        log_prob_history_of_roots = curr_log_probs_history[candts_inds]  # (num, max_len)
-        bool_idx = bool_idx[candts_inds]  # (num, max_len)
+        log_prob_of_roots = curr_log_probs[candts_inds]  # (num, 1)
+        draft_place_bool = draft_place_bool[candts_inds]  # (num, drafted_len)

        drafts = chosen_drafts[candts_inds]  # (num, draft_len)
        tail = torch.full((num, 1), 0.).type_as(drafts)  # -> (num, 1)
@@ -313,28 +311,45 @@ def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts,

        new_seqs_log_probs = torch.gather(predicted_log_probs, dim=2, index=new_seqs.unsqueeze(-1)).squeeze(-1)
        # -> (num, draft_len + 1)
+        new_seqs_log_probs.masked_fill_(mask_for_tokens_after_the_sampled, 0.)
+        # -> (num, draft_len + 1)
        new_seqs_log_probs = new_seqs_log_probs.cumsum(dim=-1)  # -> (num, draft_len + 1)

-        last_log_prob_from_roots = torch.min(log_prob_history_of_roots, dim=-1, keepdim=True).values
+        last_log_prob_from_roots = torch.min(log_prob_of_roots, dim=-1, keepdim=True).values
        # (num, 1)
-        new_seqs_log_probs = last_log_prob_from_roots + new_seqs_log_probs
-        # -> (num, draft_len + 1)
+        new_seqs_log_probs = last_log_prob_from_roots + new_seqs_log_probs[:, -1:]
+        # -> (num, 1)
        new_seqs.masked_fill_(mask_for_tokens_after_the_sampled, self.pad_token_idx)
        # -> (num, draft_len + 1)
-        new_seqs_log_probs.masked_fill_(mask_for_tokens_after_the_sampled, self.log_prob_pad)
-        # -> (num, draft_len + 1)

-        tmp = torch.logical_or(bool_idx, torch.roll(bool_idx, 1, 1))
-        # -> (num, max_len)
-        previous_roots = torch.cat((previous_roots, tail), dim=-1)  # (num, drafted_len + 1)
-        previous_roots[tmp[:, :drafted_len + 1]] = new_seqs.reshape(
-            -1)  # it is new sequences sampled from the chosen drafts
-        log_prob_history_of_roots[tmp] = new_seqs_log_probs.reshape(-1)
+        new_seqs_place_bool = torch.logical_or(draft_place_bool, torch.roll(draft_place_bool, 1, 1))
+        # -> (num, drafted_len) It contains draft_len + 1 Trues in each line
+        previous_roots[new_seqs_place_bool] = new_seqs.reshape(-1)

        token_postn[already_finished_given_seqs] = self.acceptance_rate_pad_for_alredy_finished_seqs
        # the given sequences with eos didn't need the draft tokens. We
        # don't take pads into account calculating the acceptance rate
-        return previous_roots, log_prob_history_of_roots, num_of_new_seqs_for_each_in_batch, token_postn
+        return previous_roots, new_seqs_log_probs, num_of_new_seqs_for_each_in_batch, token_postn
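The place-bool widening and the boolean scatter above are easier to see on a toy tensor. A minimal sketch with made-up shapes (draft_len = 2, pad = 0), relying only on standard PyTorch semantics:

import torch

draft_place_bool = torch.tensor([[False, False, True, True, False],
                                 [False, True, True, False, False]])   # draft_len = 2 Trues per row
new_seqs_place_bool = torch.logical_or(draft_place_bool, torch.roll(draft_place_bool, 1, 1))
# each row now holds draft_len + 1 = 3 Trues: the draft slots plus the slot right after them
previous_roots = torch.tensor([[1, 5, 0, 0, 0],
                               [1, 0, 0, 0, 0]])                        # 0 = pad
new_seqs = torch.tensor([[7, 8, 9],
                         [4, 6, 2]])                                    # (num, draft_len + 1)
previous_roots[new_seqs_place_bool] = new_seqs.reshape(-1)
# previous_roots is now tensor([[1, 5, 7, 8, 9],
#                               [1, 4, 6, 2, 0]])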
+
+    def get_vocab_tokens_bool_lib(self, draft_lib):
+        """
+        :param draft_lib: tensor (b_size, n_drafts, draft_len),
+
+        :return:
+          ->  vocab_tokens_bool_lib: tensor (b_sz, vocab_size, n_drafts),
+        """
+
+        draft_start_tokens = draft_lib[:, :, 0]
+        # -> (b_sz, n_drafts)
+        b_sz, n_drafts = draft_start_tokens.size()
+        vocab_tokens = torch.arange(self.vocab_size).unsqueeze(0).unsqueeze(-1).expand(b_sz, self.vocab_size, n_drafts)
+        # -> (b_sz, vocab_size, n_drafts)
+        vocab_tokens_bool = draft_start_tokens.unsqueeze(1).expand(b_sz, self.vocab_size, n_drafts) == vocab_tokens.type_as(draft_lib)
+        # -> (b_sz, vocab_size, n_drafts)
+        t = vocab_tokens_bool.view(-1, n_drafts)
+        t[t.sum(-1) == 0, 0] = True
+        t[t.cumsum(-1) > self.requested_drafts_num] = False
+        return vocab_tokens_bool
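For illustration, a standalone toy run of the same selection logic (vocab_size = 5 and requested_drafts_num = 1 are made-up values; in the class they come from the constructor):

import torch

vocab_size, requested_drafts_num = 5, 1
draft_lib = torch.tensor([[[2, 4],     # draft 0 starts with token 2
                           [2, 3],     # draft 1 also starts with token 2
                           [4, 1]]])   # draft 2 starts with token 4
b_sz, n_drafts, _ = draft_lib.size()
draft_start_tokens = draft_lib[:, :, 0]
vocab_tokens = torch.arange(vocab_size).unsqueeze(0).unsqueeze(-1).expand(b_sz, vocab_size, n_drafts)
vocab_tokens_bool = draft_start_tokens.unsqueeze(1).expand(b_sz, vocab_size, n_drafts) == vocab_tokens.type_as(draft_lib)
t = vocab_tokens_bool.view(-1, n_drafts)
t[t.sum(-1) == 0, 0] = True                       # tokens that start no draft fall back to draft 0
t[t.cumsum(-1) > requested_drafts_num] = False    # keep at most requested_drafts_num drafts per token
# vocab_tokens_bool[0] is now
# tensor([[ True, False, False],   # token 0: fallback
#         [ True, False, False],   # token 1: fallback
#         [ True, False, False],   # token 2: drafts 0 and 1 match, capped to the first one
#         [ True, False, False],   # token 3: fallback
#         [False, False,  True]])  # token 4: draft 2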

    def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
        # we don't need the bos token in drafts
@@ -355,19 +370,24 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

        iters = -1

-        generated_tokens = torch.full((1, 1), self.bos_token_idx).type_as(src).long().repeat(b_size, 1)
-        # -> (b_size, 1)
+        generated_tokens = torch.full((b_size, 1), self.bos_token_idx, device=src.device)
+        # -> (b_size, 1)

-        log_probs_history = torch.full((1, self.max_len), self.log_prob_pad).type_as(src).float().repeat(b_size, 1)
-        # -> (b_size, max_len)
-        log_probs_history[:, 0] = 0.
+        log_probs = torch.full((b_size, 1), 0., device=src.device)
+        # -> (b_size, 1)

-        possible_draft_len = self.max_len - 2
+        num_of_empty_columns = ((generated_tokens == self.pad_token_idx).sum(0) == b_size).sum().item()
+        # -> (1,)
+        postn_of_last_meaning_token = generated_tokens.shape[1] - num_of_empty_columns
+        # -> (1,)
+        possible_draft_len = self.max_len - postn_of_last_meaning_token - 1
+        # -> (1,)
+        beam_size = 1

        logits_base = torch.full((b_size * n_drafts, draft_len + 1, self.vocab_size), 0., device=src.device)
        # -> (b_s * n_drafts, draft_len + 1, vocab_size)

-        while possible_draft_len > 1 and iters < self.max_len:
+        while possible_draft_len >= 1 and postn_of_last_meaning_token <= self.max_len:
            iters += 1
            logits_base = logits_base * 0.
            # We use artificial logits to avoid computing the obvious pad predictions after eos
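The bookkeeping that sizes the next draft is easy to sanity-check on a toy batch. A sketch with made-up values (pad_token_idx = 0, max_len = 10):

import torch

pad_token_idx, max_len = 0, 10
generated_tokens = torch.tensor([[1, 5, 7, 0, 0],
                                 [1, 4, 0, 0, 0]])
b_size = generated_tokens.shape[0]
num_of_empty_columns = ((generated_tokens == pad_token_idx).sum(0) == b_size).sum().item()  # 2: only the last two columns are pad in every row
postn_of_last_meaning_token = generated_tokens.shape[1] - num_of_empty_columns              # 5 - 2 = 3
possible_draft_len = max_len - postn_of_last_meaning_token - 1                              # 10 - 3 - 1 = 6
assert (num_of_empty_columns, postn_of_last_meaning_token, possible_draft_len) == (2, 3, 6)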
@@ -393,34 +413,35 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

            pads_num = (generated_tokens == self.pad_token_idx).sum(-1)
            # -> (n_candidates)
-            pad_base_len = draft_len - torch.min(pads_num).item()
-            if pad_base_len > 0:
-                draft_base = torch.full((n_candidates, pad_base_len), self.pad_token_idx, device=src.device)
-                generated_tokens = torch.cat((generated_tokens, draft_base), dim=-1)
+            draft_place_len = draft_len + 1 - num_of_empty_columns
+            if draft_place_len > 0:
+                draft_place = torch.full((n_candidates, draft_place_len), self.pad_token_idx, device=src.device)
+                generated_tokens = torch.cat((generated_tokens, draft_place), dim=-1)
            # -> (n_candidates, drafted_len)

            logits_base = logits_base[:, :draft_len + 1, :]

            self.model_calls_num += 1
-            log_prob_pad_t_bool = log_probs_history == self.log_prob_pad
+            pad_place_bool = generated_tokens == self.pad_token_idx
+            # -> (n_candidates, drafted_len)
+            draft_place_bool = torch.logical_and(pad_place_bool,
+                                                 pad_place_bool.cumsum(-1) <= draft_len)
+            # -> (n_candidates, drafted_len)

-            bool_idx = torch.logical_and(log_prob_pad_t_bool,
-                                         log_prob_pad_t_bool.cumsum(-1) <= draft_len)
-            # -> (b_s * bm_sz, max_len)
-            bool_idx_input = bool_idx[:, :generated_tokens.shape[1]].unsqueeze(1).repeat(1, n_drafts, 1)
+            draft_place_bool_idx_input = draft_place_bool.unsqueeze(1).repeat(1, n_drafts, 1)
            # -> (b_s * bm_sz, n_drafts, drafted_len)
            generated_tokens_input = generated_tokens.unsqueeze(1).repeat(1, n_drafts, 1)
            # -> (b_s * bm_sz, n_drafts, drafted_len)

-            generated_tokens_input[bool_idx_input] = draft_tokens.reshape(-1)
-            bool_idx_input = bool_idx_input.flatten(end_dim=1)
+            generated_tokens_input[draft_place_bool_idx_input] = draft_tokens.reshape(-1)
+            draft_place_bool_idx_input = draft_place_bool_idx_input.flatten(end_dim=1)
            # -> (b_s * bm_sz * n_drafts, drafted_len)
            generated_tokens_input = generated_tokens_input.flatten(end_dim=1)
            # -> (b_s * bm_sz * n_drafts, drafted_len)

            bool_idx_of_unfinished = bool_idx_of_unfinished.unsqueeze(-1).repeat(1, n_drafts).flatten(end_dim=1)
            # -> (b_s * bm_sz * n_drafts)
-            bool_idx_input = bool_idx_input[bool_idx_of_unfinished]
+            draft_place_bool_idx_input = draft_place_bool_idx_input[bool_idx_of_unfinished]
            # -> (num_of_unfinished, drafted_len)
            pred_logits = self.model.decode_tgt(generated_tokens_input[bool_idx_of_unfinished],
                                                memory[bool_idx_of_unfinished],
@@ -429,7 +450,7 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

            vocab_size = pred_logits.shape[-1]

-            pred_logits = pred_logits[torch.logical_or(bool_idx_input, torch.roll(bool_idx_input, -1, 1))].reshape(
+            pred_logits = pred_logits[torch.logical_or(draft_place_bool_idx_input, torch.roll(draft_place_bool_idx_input, -1, 1))].reshape(
                -1, draft_len + 1, vocab_size)
            # -> (num_of_unfinished, draft_len + 1, vocab_size)

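A toy view of how the draft slots and the matching logit rows are picked out (draft_len = 2, pad = 0; values made up, and it assumes the usual convention that the logits at position p score the token at position p + 1):

import torch

draft_len, pad_token_idx = 2, 0
generated_tokens = torch.tensor([[1, 5, 7, 0, 0, 0],
                                 [1, 4, 0, 0, 0, 0]])
pad_place_bool = generated_tokens == pad_token_idx
draft_place_bool = torch.logical_and(pad_place_bool, pad_place_bool.cumsum(-1) <= draft_len)
# tensor([[False, False, False,  True,  True, False],
#         [False, False,  True,  True, False, False]])  -> the first draft_len pad slots of each row
keep_logits = torch.logical_or(draft_place_bool, torch.roll(draft_place_bool, -1, 1))
# tensor([[False, False,  True,  True,  True, False],
#         [False,  True,  True,  True, False, False]])  -> draft_len + 1 logit rows per sequence:
# the position of the last meaningful token plus every draft position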
@@ -441,7 +462,7 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
            # approved tokens is the best draft for the given candidate. #########################################

            # All unapproved tokens in masked_probs have zero probability
-            # We use nucleus=0.9975 and max_num_of_unmasked_positions=5 to avoid sampling of low probable sequences
+            # We use nucleus=0.9975 and max_num_of_unmasked_positions=n_best to avoid sampling low-probability sequences
            # and reduce calculation
            masked_probs = mask_with_num_logits_according_nucleus(pred_logits, nucleus=0.9975,
                                                                  max_num_of_unmasked_positions=self.n_best,
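mask_with_num_logits_according_nucleus itself is not shown in this diff. As a rough, assumed sketch of the intended effect (the real helper may differ in signature, fill values, and return type; the names below are hypothetical), a nucleus plus top-k style mask looks roughly like this:

import torch

def nucleus_mask_sketch(logits, nucleus=0.9975, max_unmasked=5):
    # illustrative only, not the project's implementation
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    keep_sorted = (cumulative - sorted_probs) < nucleus      # tokens inside the nucleus
    keep_sorted[..., max_unmasked:] = False                  # keep at most max_unmasked tokens per position
    keep = torch.zeros_like(probs).scatter(-1, sorted_idx, keep_sorted.float())
    return probs * keep                                      # every unapproved token gets zero probability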
@@ -473,10 +494,9 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

            # Sample all possible lines within the chosen drafts:
            # new_candidates have the initial tokens and the new ones
-
-            new_candidates, new_log_probs_history, num_of_new_seqs_for_each_in_batch, accepted_tokens_num = \
-                self.sample(generated_tokens, log_probs_history, pred_logits,
-                            chosen_drafts, b_size, bool_idx, n_accepted.squeeze(-1))
+            new_candidates, new_log_probs, num_of_new_seqs_for_each_in_batch, accepted_tokens_num = \
+                self.sample(generated_tokens, log_probs, pred_logits,
+                            chosen_drafts, b_size, draft_place_bool, n_accepted.squeeze(-1))

            ###########################################################################################################
            max_num_of_new_seqs = torch.max(num_of_new_seqs_for_each_in_batch).item()
@@ -501,26 +521,23 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
                -1)  # -> (b_size * max_num_of_new_seqs)
            new_candidates = new_candidates[inds]
            # -> (b_size * max_num_of_new_seqs, drafted_len + 1)
-            new_log_probs_history = new_log_probs_history[inds]
-            # -> (b_size * max_num_of_new_seqs, max_len)
+            new_log_probs = new_log_probs[inds]
+            # -> (b_size * max_num_of_new_seqs, 1)
            accepted_tokens_num = accepted_tokens_num[inds]
            # -> (b_size * max_num_of_new_seqs)
            new_candidates[mask_for_fake_seqs, 1] = self.eos_token_idx  # fake sequences
-            new_log_probs_history[mask_for_fake_seqs, 1] = -float("inf")  # fake probabilities
+            new_log_probs[mask_for_fake_seqs, 0] = -float("inf")  # fake probabilities
            accepted_tokens_num[mask_for_fake_seqs] = self.acceptance_rate_pad_for_fake_seqs  # fake
            #############################################################################################

-            new_log_probs = torch.min(new_log_probs_history, dim=1).values
-            # -> (b_size * max_num_of_new_seqs)
            new_log_probs = new_log_probs.reshape(b_size, max_num_of_new_seqs)
            # -> (b_size, max_num_of_new_seqs)
-            v, top_inds = new_log_probs.topk(k=self.n_best, axis=-1, sorted=True)
+            new_log_probs, top_inds = new_log_probs.topk(k=self.n_best, axis=-1, sorted=True)
            # -> (b_size, beam_size)

            new_candidates = new_candidates.reshape(b_size, max_num_of_new_seqs, -1)
-            # -> (b_size, max_num_of_new_seqs, drafted_len + 1)
-            new_log_probs_history = new_log_probs_history.reshape(b_size, max_num_of_new_seqs, -1)
-            # -> (b_size, max_num_of_new_seqs, max_len)
+            # -> (b_size, max_num_of_new_seqs, drafted_len)
+

            accepted_tokens_num = accepted_tokens_num.reshape(b_size, max_num_of_new_seqs)
            # -> (b_size, max_num_of_new_seqs)
@@ -534,25 +551,29 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
            self.accepted_tokens_num += curr_accepted_tokens_num
            self.produced_non_pad_tokens += curr_accepted_tokens_num + accepted_tokens_num.size(0)

-            top_inds = top_inds.unsqueeze(-1).repeat(1, 1, new_log_probs_history.shape[-1])
-            # -> (b_size, beam_size, max_len)
-            new_log_probs_history = torch.gather(new_log_probs_history, 1, top_inds)
-            # -> (b_size, beam_size, max_len)
-            new_candidates = torch.gather(new_candidates, 1, top_inds[:, :, :new_candidates.shape[-1]])
-            # -> (b_size, beam_size, drafted_len + 1)
+            top_inds = top_inds.unsqueeze(-1).repeat(1, 1, new_candidates.shape[-1])
+            # -> (b_size, beam_size, drafted_len)
+
+            new_candidates = torch.gather(new_candidates, 1, top_inds)
+            # -> (b_size, beam_size, drafted_len)

            if (new_candidates[not_fake_bool] == self.eos_token_idx).sum(-1).bool().sum() == b_size * self.n_best:
                break

            generated_tokens = new_candidates.reshape(b_size * self.n_best, -1)
-            # -> (b_size * beam_size, drafted_len + 1)
-            new_log_probs_history = new_log_probs_history.reshape(b_size * self.n_best, -1)
-            # -> (b_size * beam_size, max_len)
+            # -> (b_size * beam_size, drafted_len)
+            log_probs = new_log_probs.reshape(b_size * self.n_best, 1)
+            # -> (b_size * beam_size, 1)
            not_fake_bool = not_fake_bool.reshape(b_size * self.n_best)
            # -> (b_size * beam_size)
-            log_probs_history = new_log_probs_history

-            possible_draft_len = torch.min((new_log_probs_history[not_fake_bool] == self.log_prob_pad).sum(-1)).item() - 1
+            num_of_empty_columns = torch.min((generated_tokens[not_fake_bool] == self.pad_token_idx).sum(-1)).item()
+            # -> (1,)
+            postn_of_last_meaning_token = generated_tokens[not_fake_bool].shape[1] - num_of_empty_columns
+            # -> (1,)
+            possible_draft_len = self.max_len - postn_of_last_meaning_token - 1
+            # -> (1,)
+
        return new_candidates
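The final top-k selection and the gather that keeps the matching candidate rows can be checked on a toy batch. A sketch with made-up values (b_size = 1, n_best = 2, three candidate sequences of length 4):

import torch

new_log_probs = torch.tensor([[-1.2, -0.3, -4.0]])             # (b_size, max_num_of_new_seqs)
new_candidates = torch.tensor([[[1, 5, 7, 0],
                                [1, 5, 8, 9],
                                [1, 6, 2, 0]]])                 # (b_size, max_num_of_new_seqs, drafted_len)
new_log_probs, top_inds = new_log_probs.topk(k=2, dim=-1, sorted=True)
# top_inds == tensor([[1, 0]]): the second and the first sequences are the two best
top_inds = top_inds.unsqueeze(-1).repeat(1, 1, new_candidates.shape[-1])
new_candidates = torch.gather(new_candidates, 1, top_inds)
# new_candidates is now tensor([[[1, 5, 8, 9],
#                                [1, 5, 7, 0]]])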

    def calculate_n_accepted_in_drafts(self, draft_tokens, masked_probs):