Commit 5a7cab9

The SBS class has been changed. It no longer keeps the history of sequence probabilities.
1 parent d65164c commit 5a7cab9
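
In short, the commit stops tracking a per-position log-probability history for each candidate and keeps only one cumulative log-probability per candidate. A minimal sketch of the before/after shapes (illustrative only; `b_size`, `max_len` and the pad value are stand-ins, not code from the repository):

```python
import torch

b_size, max_len = 4, 128        # example sizes, not taken from the repo
log_prob_pad = 7.0              # stand-in for the padding value self.log_prob_pad

# Before: one log-prob slot per generated position, padded out to max_len.
log_probs_history = torch.full((b_size, max_len), log_prob_pad)
log_probs_history[:, 0] = 0.    # shape (b_size, max_len)

# After: only the running sum of log-probs survives.
log_probs = torch.zeros((b_size, 1))  # shape (b_size, 1)
```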

File tree: 1 file changed, +87 −66 lines

src/decoding/speculative_decoding.py

Lines changed: 87 additions & 66 deletions
@@ -175,7 +175,6 @@ def generate(self, src: 'torch.LongTensor') -> 'torch.LongTensor':
         return finished_predictions.unsqueeze(1)  # (B, 1, Lg)
 
 
-
 class TranslationInferenceBeamSearchSpeculativeBatchedWithoutLeftPads:
     def __init__(self,
                  model,  # TranslationModel
@@ -224,29 +223,28 @@ def __init__(self,
     def __str__(self):
         return f"SpeculativeSampling decoding (n_best={self.n_best}, max_len={self.max_len}, max_num_of_drafts={self.max_drafts_num}, draft_len={self.draft_len})"
 
-    def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts, b_size, bool_idx, n_accepted):
+    def sample(self, curr_lines, curr_log_probs, pred_logits, chosen_drafts, b_size, draft_place_bool, n_accepted):
         """
         This function samples all possible sequences within a selected draft. Each draft can
         produce (self.max_num_positions_for_sampling - 1) * num_of_approved_tokens + self.max_num_positions_for_sampling
         at most.
 
         :param curr_lines: tensor (n_candidates, drafted_len),
-        :param curr_log_probs_history: tensor (n_candidates, max_len),
+        :param curr_log_probs: tensor (n_candidates, 1),
         :param pred_logits: tensor (n_candidates, draft_len + 1, vocab_size),
         :param chosen_drafts: tensor (n_candidates, draft_len),
         :param b_size: int,
-        :param bool_idx: tensor (n_candidates, max_len), it contains true where the draft supposed to be in curr_lines,
+        :param draft_place_bool: tensor (n_candidates, drafted_len), it contains true where the draft supposed to be in curr_lines,
            in each line there are draft_len trues
         :param n_accepted: tensor (n_candidates)
         :return:
-          -> new_lines: tensor (num_lines, max_len),
-             new_log_probs_history: tensor (num_lines, max_len)
+          -> new_lines: tensor (num_lines, len),
+             new_log_probs: tensor (num_lines, 1)
             num_of_new_seqs_for_each_in_batch: tensor (b_size)
            token_postn: tensor (num_lines), to calculate the number of accepted tokens in the next top n sequences
            later; self.acceptance_rate_pad_for_already_finished_seqs means that the given sequence had already the
            eos token and so didn't need subsequent tokens
         """
-        drafted_len = curr_lines.shape[1]
         n_candidates, draft_len_plus_one, vocab_size = pred_logits.size()
 
         draft_len = draft_len_plus_one - 1
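
The docstring's bound on the number of sequences a single draft can produce can be checked with a small helper; this is a hedged sketch of the arithmetic only (the function name is a placeholder, not code from the file):

```python
def max_seqs_per_draft(max_num_positions_for_sampling: int, num_of_approved_tokens: int) -> int:
    # Upper bound from the docstring: every approved position can branch into
    # (max_num_positions_for_sampling - 1) alternatives, plus the positions
    # sampled right after the last approved token.
    return (max_num_positions_for_sampling - 1) * num_of_approved_tokens + max_num_positions_for_sampling

# e.g. with 5 candidate positions per step and 3 approved draft tokens:
assert max_seqs_per_draft(5, 3) == 17
```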
@@ -298,8 +296,8 @@ def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts,
         previous_roots = curr_lines[candts_inds]  # (num, drafted_len)
         already_finished_given_seqs = (previous_roots == self.eos_token_idx).sum(-1).bool()  # -> (num)
 
-        log_prob_history_of_roots = curr_log_probs_history[candts_inds]  # (num, max_len)
-        bool_idx = bool_idx[candts_inds]  # (num, max_len)
+        log_prob_of_roots = curr_log_probs[candts_inds]  # (num, 1)
+        draft_place_bool = draft_place_bool[candts_inds]  # (num, max_len)
 
         drafts = chosen_drafts[candts_inds]  # (num, draft_len)
         tail = torch.full((num, 1), 0.).type_as(drafts)  # -> (num, 1)
@@ -313,28 +311,45 @@ def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts,
 
         new_seqs_log_probs = torch.gather(predicted_log_probs, dim=2, index=new_seqs.unsqueeze(-1)).squeeze(-1)
         # -> (num, draft_len + 1)
+        new_seqs_log_probs.masked_fill_(mask_for_tokens_after_the_sampled, 0.)
+        # -> (num, draft_len + 1)
         new_seqs_log_probs = new_seqs_log_probs.cumsum(dim=-1)  # -> (num, draft_len + 1)
 
-        last_log_prob_from_roots = torch.min(log_prob_history_of_roots, dim=-1, keepdim=True).values
+        last_log_prob_from_roots = torch.min(log_prob_of_roots, dim=-1, keepdim=True).values
         # (num, 1)
-        new_seqs_log_probs = last_log_prob_from_roots + new_seqs_log_probs
-        # -> (num, draft_len + 1)
+        new_seqs_log_probs = last_log_prob_from_roots + new_seqs_log_probs[:, -1:]
+        # -> (num, 1)
         new_seqs.masked_fill_(mask_for_tokens_after_the_sampled, self.pad_token_idx)
         # -> (num, draft_len + 1)
-        new_seqs_log_probs.masked_fill_(mask_for_tokens_after_the_sampled, self.log_prob_pad)
-        # -> (num, draft_len + 1)
 
-        tmp = torch.logical_or(bool_idx, torch.roll(bool_idx, 1, 1))
-        # -> (num, max_len)
-        previous_roots = torch.cat((previous_roots, tail), dim=-1)  # (num, drafted_len + 1)
-        previous_roots[tmp[:, :drafted_len + 1]] = new_seqs.reshape(
-            -1)  # it is new sequences sampled from the chosen drafts
-        log_prob_history_of_roots[tmp] = new_seqs_log_probs.reshape(-1)
+        new_seqs_place_bool = torch.logical_or(draft_place_bool, torch.roll(draft_place_bool, 1, 1))
+        # -> (num, drafted_len) It contains draft_len+1 Trues in each line
+        previous_roots[new_seqs_place_bool] = new_seqs.reshape(-1)
 
         token_postn[already_finished_given_seqs] = self.acceptance_rate_pad_for_alredy_finished_seqs
         # the given sequences with eos didn't need the draft tokens. We
         # don't take pads into account calculating the acceptance rate
-        return previous_roots, log_prob_history_of_roots, num_of_new_seqs_for_each_in_batch, token_postn
+        return previous_roots, new_seqs_log_probs, num_of_new_seqs_for_each_in_batch, token_postn
+
+    def get_vocab_tokens_bool_lib(self, draft_lib):
+        """
+        :param draft_lib: tensor (b_size, n_drafts, draft_len),
+
+        :return:
+          -> vocab_tokens_bool_lib: tensor (b_sz, vocab_size, n_drafts),
+        """
+
+        draft_start_tokens = draft_lib[:, :, 0]
+        # -> (b_sz, n_drafts)
+        b_sz, n_drafts = draft_start_tokens.size()
+        vocab_tokens = torch.arange(self.vocab_size).unsqueeze(0).unsqueeze(-1).expand(b_sz, self.vocab_size, n_drafts)
+        # -> (b_sz, vocab_size, n_drafts)
+        vocab_tokens_bool = draft_start_tokens.unsqueeze(1).expand(b_sz, self.vocab_size, n_drafts) == vocab_tokens.type_as(draft_lib)
+        # -> (b_sz, vocab_size, n_drafts)
+        t = vocab_tokens_bool.view(-1, n_drafts)
+        t[t.sum(-1) == 0, 0] = True
+        t[t.cumsum(-1) > self.requested_drafts_num] = False
+        return vocab_tokens_bool
 
     def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
         # we don't need the bos token in drafts
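
The new `get_vocab_tokens_bool_lib` helper builds, for every vocabulary token, a boolean mask over the draft library marking which drafts start with that token, capped at `requested_drafts_num` drafts per token and falling back to the first draft when none match. A self-contained sketch of the same masking logic outside the class, with made-up sizes:

```python
import torch

b_sz, n_drafts, draft_len, vocab_size = 2, 4, 3, 10   # toy sizes, not from the repo
requested_drafts_num = 2                              # assumed cap per start token
draft_lib = torch.randint(0, vocab_size, (b_sz, n_drafts, draft_len))

draft_start_tokens = draft_lib[:, :, 0]                                # (b_sz, n_drafts)
vocab_tokens = torch.arange(vocab_size).view(1, vocab_size, 1).expand(b_sz, vocab_size, n_drafts)
vocab_tokens_bool = draft_start_tokens.unsqueeze(1).expand(b_sz, vocab_size, n_drafts) == vocab_tokens
# -> (b_sz, vocab_size, n_drafts): True where a draft starts with the given token

t = vocab_tokens_bool.view(-1, n_drafts)
t[t.sum(-1) == 0, 0] = True                            # no matching draft: fall back to draft 0
t[t.cumsum(-1) > requested_drafts_num] = False         # keep at most requested_drafts_num drafts per token
```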
@@ -355,19 +370,24 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
 
         iters = -1
 
-        generated_tokens = torch.full((1, 1), self.bos_token_idx).type_as(src).long().repeat(b_size, 1)
-        # -> (b_size, 1)
+        generated_tokens = torch.full((b_size, 1), self.bos_token_idx, device=src.device)
+        # -> (b_size, 1)
 
-        log_probs_history = torch.full((1, self.max_len), self.log_prob_pad).type_as(src).float().repeat(b_size, 1)
-        # -> (b_size, max_len)
-        log_probs_history[:, 0] = 0.
+        log_probs = torch.full((b_size, 1), 0., device=src.device)
+        # -> (b_size, 1)
 
-        possible_draft_len = self.max_len - 2
+        num_of_empty_columns = ((generated_tokens == self.pad_token_idx).sum(0) == b_size).sum().item()
+        # -> (1,)
+        postn_of_last_meaning_token = generated_tokens.shape[1] - num_of_empty_columns
+        # -> (1,)
+        possible_draft_len = self.max_len - postn_of_last_meaning_token - 1
+        # -> (b_size, 1)
+        beam_size = 1
 
         logits_base = torch.full((b_size * n_drafts, draft_len + 1, self.vocab_size), 0., device=src.device)
         # -> (b_s * n_drafts, draft_len + 1, vocab_size)
 
-        while possible_draft_len > 1 and iters < self.max_len:
+        while possible_draft_len >= 1 and postn_of_last_meaning_token <= self.max_len:
             iters += 1
             logits_base = logits_base * 0.
             # We use artificial logits to avoid calculation of obvious pad predicting after eos
@@ -393,34 +413,35 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
 
             pads_num = (generated_tokens == self.pad_token_idx).sum(-1)
             # -> (n_candidates)
-            pad_base_len = draft_len - torch.min(pads_num).item()
-            if pad_base_len > 0:
-                draft_base = torch.full((n_candidates, pad_base_len), self.pad_token_idx, device=src.device)
-                generated_tokens = torch.cat((generated_tokens, draft_base), dim=-1)
+            draft_place_len = draft_len + 1 - num_of_empty_columns
+            if draft_place_len > 0:
+                draft_place = torch.full((n_candidates, draft_place_len), self.pad_token_idx, device=src.device)
+                generated_tokens = torch.cat((generated_tokens, draft_place), dim=-1)
             # -> (n_candidates, drafted_len)
 
             logits_base = logits_base[:, :draft_len + 1, :]
 
             self.model_calls_num += 1
-            log_prob_pad_t_bool = log_probs_history == self.log_prob_pad
+            pad_place_bool = generated_tokens == self.pad_token_idx
+            # -> (n_candidates, drafted_len)
+            draft_place_bool = torch.logical_and(pad_place_bool,
+                                                 pad_place_bool.cumsum(-1) <= draft_len)
+            # -> (n_candidates, drafted_len)
 
-            bool_idx = torch.logical_and(log_prob_pad_t_bool,
-                                         log_prob_pad_t_bool.cumsum(-1) <= draft_len)
-            # -> (b_s * bm_sz, max_len)
-            bool_idx_input = bool_idx[:, :generated_tokens.shape[1]].unsqueeze(1).repeat(1, n_drafts, 1)
+            draft_place_bool_idx_input = draft_place_bool.unsqueeze(1).repeat(1, n_drafts, 1)
             # -> (b_s * bm_sz, n_drafts, drafted_len)
             generated_tokens_input = generated_tokens.unsqueeze(1).repeat(1, n_drafts, 1)
             # -> (b_s * bm_sz, n_drafts, drafted_len)
 
-            generated_tokens_input[bool_idx_input] = draft_tokens.reshape(-1)
-            bool_idx_input = bool_idx_input.flatten(end_dim=1)
+            generated_tokens_input[draft_place_bool_idx_input] = draft_tokens.reshape(-1)
+            draft_place_bool_idx_input = draft_place_bool_idx_input.flatten(end_dim=1)
             # -> (b_s * bm_sz * n_drafts, drafted_len)
             generated_tokens_input = generated_tokens_input.flatten(end_dim=1)
             # # -> (b_s * bm_sz * n_drafts, drafted_len, vocab_size)
 
             bool_idx_of_unfinished = bool_idx_of_unfinished.unsqueeze(-1).repeat(1, n_drafts).flatten(end_dim=1)
             # -> (b_s * bm_sz * n_drafts)
-            bool_idx_input = bool_idx_input[bool_idx_of_unfinished]  #
+            draft_place_bool_idx_input = draft_place_bool_idx_input[bool_idx_of_unfinished]
             # -> (num_of_unfinished, drafted_len)
             pred_logits = self.model.decode_tgt(generated_tokens_input[bool_idx_of_unfinished],
                                                 memory[bool_idx_of_unfinished],
@@ -429,7 +450,7 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
 
             vocab_size = pred_logits.shape[-1]
 
-            pred_logits = pred_logits[torch.logical_or(bool_idx_input, torch.roll(bool_idx_input, -1, 1))].reshape(
+            pred_logits = pred_logits[torch.logical_or(draft_place_bool_idx_input, torch.roll(draft_place_bool_idx_input, -1, 1))].reshape(
                 -1, draft_len + 1, vocab_size)
             # -> (num_of_unfinished, draft_len + 1, vocab_size)
 
@@ -441,7 +462,7 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
             # approved tokens is the best draft for the given candidate. #########################################
 
             # All unapproved tokens in masked_probs have zero probability
-            # We use nucleus=0.9975 and max_num_of_unmasked_positions=5 to avoid sampling of low probable sequences
+            # We use nucleus=0.9975 and max_num_of_unmasked_positions=n_best to avoid sampling of low probable sequences
             # and reduce calculation
             masked_probs = mask_with_num_logits_according_nucleus(pred_logits, nucleus=0.9975,
                                                                   max_num_of_unmasked_positions=self.n_best,
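
For context, `mask_with_num_logits_according_nucleus` is not shown in this diff; the sketch below illustrates the general idea of nucleus-style masking (keep the smallest set of tokens whose probability mass reaches the nucleus, capped at a maximum count) under assumed semantics, not the project's exact implementation:

```python
import torch

def nucleus_mask(logits: torch.Tensor, nucleus: float = 0.9975, max_positions: int = 5) -> torch.Tensor:
    """Boolean mask of kept tokens per position (assumed semantics, illustrative only)."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    keep_sorted = (cum - sorted_probs) < nucleus          # smallest set reaching the nucleus mass
    keep_sorted[..., max_positions:] = False              # never keep more than max_positions tokens
    # Scatter the sorted-order decisions back to vocabulary order.
    keep = torch.zeros_like(probs).scatter(-1, sorted_idx, keep_sorted.float()).bool()
    return keep
```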
@@ -473,10 +494,9 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
 
             # Sample all possible lines within the chosen drafts:
             # new_candidates have the initial tokens and the new ones
-
-            new_candidates, new_log_probs_history, num_of_new_seqs_for_each_in_batch, accepted_tokens_num = \
-                self.sample(generated_tokens, log_probs_history, pred_logits,
-                            chosen_drafts, b_size, bool_idx, n_accepted.squeeze(-1))
+            new_candidates, new_log_probs, num_of_new_seqs_for_each_in_batch, accepted_tokens_num = \
+                self.sample(generated_tokens, log_probs, pred_logits,
+                            chosen_drafts, b_size, draft_place_bool, n_accepted.squeeze(-1))
 
             ###########################################################################################################
             max_num_of_new_seqs = torch.max(num_of_new_seqs_for_each_in_batch).item()
@@ -501,26 +521,23 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
                 -1)  # -> (b_size * max_num_of_new_seqs)
             new_candidates = new_candidates[inds]
             # -> (b_size * max_num_of_new_seqs, drafted_len + 1)
-            new_log_probs_history = new_log_probs_history[inds]
-            # -> (b_size * max_num_of_new_seqs, max_len)
+            new_log_probs = new_log_probs[inds]
+            # -> (b_size * max_num_of_new_seqs, 1)
             accepted_tokens_num = accepted_tokens_num[inds]
             # -> (b_size * max_num_of_new_seqs)
             new_candidates[mask_for_fake_seqs, 1] = self.eos_token_idx  # fake sequences
-            new_log_probs_history[mask_for_fake_seqs, 1] = -float("inf")  # fake probabilities
+            new_log_probs[mask_for_fake_seqs, 0] = -float("inf")  # fake probabilities
             accepted_tokens_num[mask_for_fake_seqs] = self.acceptance_rate_pad_for_fake_seqs  # fake
             #############################################################################################
 
-            new_log_probs = torch.min(new_log_probs_history, dim=1).values
-            # -> (b_size * max_num_of_new_seqs)
             new_log_probs = new_log_probs.reshape(b_size, max_num_of_new_seqs)
             # -> (b_size, max_num_of_new_seqs)
-            v, top_inds = new_log_probs.topk(k=self.n_best, axis=-1, sorted=True)
+            new_log_probs, top_inds = new_log_probs.topk(k=self.n_best, axis=-1, sorted=True)
             # -> (b_size, beam_size)
 
             new_candidates = new_candidates.reshape(b_size, max_num_of_new_seqs, -1)
-            # -> (b_size, max_num_of_new_seqs, drafted_len + 1)
-            new_log_probs_history = new_log_probs_history.reshape(b_size, max_num_of_new_seqs, -1)
-            # -> (b_size, max_num_of_new_seqs, max_len)
+            # -> (b_size, max_num_of_new_seqs, drafted_len)
+
             accepted_tokens_num = accepted_tokens_num.reshape(b_size, max_num_of_new_seqs)
             # -> (b_size, max_num_of_new_seqs)
 
@@ -534,25 +551,29 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
             self.accepted_tokens_num += curr_accepted_tokens_num
             self.produced_non_pad_tokens += curr_accepted_tokens_num + accepted_tokens_num.size(0)
 
-            top_inds = top_inds.unsqueeze(-1).repeat(1, 1, new_log_probs_history.shape[-1])
-            # -> (b_size, beam_size, max_len)
-            new_log_probs_history = torch.gather(new_log_probs_history, 1, top_inds)
-            # -> (b_size, beam_size, max_len)
-            new_candidates = torch.gather(new_candidates, 1, top_inds[:, :, :new_candidates.shape[-1]])
-            # -> (b_size, beam_size, drafted_len + 1)
+            top_inds = top_inds.unsqueeze(-1).repeat(1, 1, new_candidates.shape[-1])
+            # -> (b_size, beam_size, drafted_len)
+
+            new_candidates = torch.gather(new_candidates, 1, top_inds)
+            # -> (b_size, beam_size, drafted_len)
 
             if (new_candidates[not_fake_bool] == self.eos_token_idx).sum(-1).bool().sum() == b_size * self.n_best:
                 break
 
             generated_tokens = new_candidates.reshape(b_size * self.n_best, -1)
-            # -> (b_size * beam_size, drafted_len + 1)
-            new_log_probs_history = new_log_probs_history.reshape(b_size * self.n_best, -1)
-            # -> (b_size * beam_size, max_len)
+            # -> (b_size * beam_size, drafted_len)
+            log_probs = new_log_probs.reshape(b_size * self.n_best, 1)
+            # -> (b_size * beam_size, 1)
             not_fake_bool = not_fake_bool.reshape(b_size * self.n_best)
             # -> (b_size * beam_size)
-            log_probs_history = new_log_probs_history
 
-            possible_draft_len = torch.min((new_log_probs_history[not_fake_bool] == self.log_prob_pad).sum(-1)).item() - 1
+            num_of_empty_columns = torch.min((generated_tokens[not_fake_bool] == self.pad_token_idx).sum(-1)).item()
+            # -> (1,)
+            postn_of_last_meaning_token = generated_tokens[not_fake_bool].shape[1] - num_of_empty_columns
+            # -> (1,)
+            possible_draft_len = self.max_len - postn_of_last_meaning_token - 1
+            # -> (b_size, 1)
+
         return new_candidates
 
     def calculate_n_accepted_in_drafts(self, draft_tokens, masked_probs):
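
The replacement loop bookkeeping in `generate` can be summarised as follows: the number of trailing pad-only columns now comes from `generated_tokens` itself rather than from the removed log-probability history, and the remaining draft budget is derived from it. A toy re-statement of that arithmetic (assumed values, not repository code):

```python
import torch

pad_token_idx, max_len = 0, 10                          # toy values
generated_tokens = torch.tensor([[1, 5, 7, 0, 0],
                                 [1, 3, 0, 0, 0]])      # two candidates, right-padded

# Columns that are pad in every row are "empty" and can host the next draft.
num_of_empty_columns = ((generated_tokens == pad_token_idx).sum(0) == generated_tokens.size(0)).sum().item()
postn_of_last_meaning_token = generated_tokens.shape[1] - num_of_empty_columns
possible_draft_len = max_len - postn_of_last_meaning_token - 1

assert num_of_empty_columns == 2
assert postn_of_last_meaning_token == 3
assert possible_draft_len == 6
```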
