
Commit 5a7cab9

The SBS class has been changed. It doesn't keep the history of probabilities of sequences anymore.
AndroNata committed Dec 22, 2024
1 parent d65164c commit 5a7cab9
Showing 1 changed file with 87 additions and 66 deletions.
153 changes: 87 additions & 66 deletions src/decoding/speculative_decoding.py
@@ -175,7 +175,6 @@ def generate(self, src: 'torch.LongTensor') -> 'torch.LongTensor':
return finished_predictions.unsqueeze(1) # (B, 1, Lg)



class TranslationInferenceBeamSearchSpeculativeBatchedWithoutLeftPads:
def __init__(self,
model, # TranslationModel
@@ -224,29 +223,28 @@ def __init__(self,
def __str__(self):
return f"SpeculativeSampling decoding (n_best={self.n_best}, max_len={self.max_len}, max_num_of_drafts={self.max_drafts_num}, draft_len={self.draft_len})"

- def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts, b_size, bool_idx, n_accepted):
+ def sample(self, curr_lines, curr_log_probs, pred_logits, chosen_drafts, b_size, draft_place_bool, n_accepted):
"""
This function samples all possible sequences within a selected draft. Each draft can
produce at most (self.max_num_positions_for_sampling - 1) * num_of_approved_tokens + self.max_num_positions_for_sampling
new sequences.
:param curr_lines: tensor (n_candidates, drafted_len),
- :param curr_log_probs_history: tensor (n_candidates, max_len),
+ :param curr_log_probs: tensor (n_candidates, 1),
:param pred_logits: tensor (n_candidates, draft_len + 1, vocab_size),
:param chosen_drafts: tensor (n_candidates, draft_len),
:param b_size: int,
- :param bool_idx: tensor (n_candidates, max_len), it contains True where the draft is supposed to go in curr_lines,
+ :param draft_place_bool: tensor (n_candidates, drafted_len), it contains True where the draft is supposed to go in curr_lines;
  in each line there are draft_len Trues
:param n_accepted: tensor (n_candidates)
:return:
- -> new_lines: tensor (num_lines, max_len),
-    new_log_probs_history: tensor (num_lines, max_len)
+ -> new_lines: tensor (num_lines, drafted_len),
+    new_log_probs: tensor (num_lines, 1)
num_of_new_seqs_for_each_in_batch: tensor (b_size)
token_postn: tensor (num_lines), used later to calculate the number of accepted tokens in the next top-n
sequences; self.acceptance_rate_pad_for_already_finished_seqs means that the given sequence already had the
eos token and so didn't need subsequent tokens
"""
drafted_len = curr_lines.shape[1]
n_candidates, draft_len_plus_one, vocab_size = pred_logits.size()

draft_len = draft_len_plus_one - 1
@@ -298,8 +296,8 @@ def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts,
previous_roots = curr_lines[candts_inds] # (num, drafted_len)
already_finished_given_seqs = (previous_roots == self.eos_token_idx).sum(-1).bool() # -> (num)

- log_prob_history_of_roots = curr_log_probs_history[candts_inds]  # (num, max_len)
- bool_idx = bool_idx[candts_inds]  # (num, max_len)
+ log_prob_of_roots = curr_log_probs[candts_inds]  # (num, 1)
+ draft_place_bool = draft_place_bool[candts_inds]  # (num, drafted_len)

drafts = chosen_drafts[candts_inds] # (num, draft_len)
tail = torch.full((num, 1), 0.).type_as(drafts) # -> (num, 1)
@@ -313,28 +311,45 @@ def sample(self, curr_lines, curr_log_probs_history, pred_logits, chosen_drafts,

new_seqs_log_probs = torch.gather(predicted_log_probs, dim=2, index=new_seqs.unsqueeze(-1)).squeeze(-1)
# -> (num, draft_len + 1)
+ new_seqs_log_probs.masked_fill_(mask_for_tokens_after_the_sampled, 0.)
+ # -> (num, draft_len + 1)
new_seqs_log_probs = new_seqs_log_probs.cumsum(dim=-1) # -> (num, draft_len + 1)

- last_log_prob_from_roots = torch.min(log_prob_history_of_roots, dim=-1, keepdim=True).values
+ last_log_prob_from_roots = torch.min(log_prob_of_roots, dim=-1, keepdim=True).values
  # (num, 1)
- new_seqs_log_probs = last_log_prob_from_roots + new_seqs_log_probs
- # -> (num, draft_len + 1)
+ new_seqs_log_probs = last_log_prob_from_roots + new_seqs_log_probs[:, -1:]
+ # -> (num, 1)
new_seqs.masked_fill_(mask_for_tokens_after_the_sampled, self.pad_token_idx)
# -> (num, draft_len + 1)
- new_seqs_log_probs.masked_fill_(mask_for_tokens_after_the_sampled, self.log_prob_pad)
- # -> (num, draft_len + 1)

- tmp = torch.logical_or(bool_idx, torch.roll(bool_idx, 1, 1))
- # -> (num, max_len)
- previous_roots = torch.cat((previous_roots, tail), dim=-1)  # (num, drafted_len + 1)
- previous_roots[tmp[:, :drafted_len + 1]] = new_seqs.reshape(-1)  # it is new sequences sampled from the chosen drafts
- log_prob_history_of_roots[tmp] = new_seqs_log_probs.reshape(-1)
+ new_seqs_place_bool = torch.logical_or(draft_place_bool, torch.roll(draft_place_bool, 1, 1))
+ # -> (num, drafted_len)  It contains draft_len + 1 Trues in each line
+ previous_roots[new_seqs_place_bool] = new_seqs.reshape(-1)

token_postn[already_finished_given_seqs] = self.acceptance_rate_pad_for_alredy_finished_seqs
# the given sequences with eos didn't need the draft tokens; we don't
# take pads into account when calculating the acceptance rate
- return previous_roots, log_prob_history_of_roots, num_of_new_seqs_for_each_in_batch, token_postn
+ return previous_roots, new_seqs_log_probs, num_of_new_seqs_for_each_in_batch, token_postn
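The torch.roll trick above is the easiest step to misread: OR-ing the placement mask with a copy of itself shifted one step to the right widens each run of Trues by exactly one position. A minimal standalone sketch, with toy shapes rather than the class's real tensors:

import torch

# Toy mask: 2 rows, drafted_len = 6, draft_len = 3 Trues per row.
draft_place_bool = torch.tensor([[False, False, True, True, True, False],
                                 [False, True, True, True, False, False]])
new_seqs_place_bool = torch.logical_or(draft_place_bool, torch.roll(draft_place_bool, 1, 1))
print(new_seqs_place_bool.sum(-1))  # tensor([4, 4]), i.e. draft_len + 1 slots per row

Note that torch.roll wraps around; this is safe only while the mask's Trues never reach the last column, which the possible_draft_len bookkeeping in generate() is there to guarantee.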

+ def get_vocab_tokens_bool_lib(self, draft_lib):
+     """
+     :param draft_lib: tensor (b_size, n_drafts, draft_len),
+     :return:
+         -> vocab_tokens_bool_lib: tensor (b_sz, vocab_size, n_drafts),
+     """
+
+     draft_start_tokens = draft_lib[:, :, 0]
+     # -> (b_sz, n_drafts)
+     b_sz, n_drafts = draft_start_tokens.size()
+     vocab_tokens = torch.arange(self.vocab_size).unsqueeze(0).unsqueeze(-1).expand(b_sz, self.vocab_size, n_drafts)
+     # -> (b_sz, vocab_size, n_drafts)
+     vocab_tokens_bool = draft_start_tokens.unsqueeze(1).expand(b_sz, self.vocab_size, n_drafts) == vocab_tokens.type_as(draft_lib)
+     # -> (b_sz, vocab_size, n_drafts)
+     t = vocab_tokens_bool.view(-1, n_drafts)
+     t[t.sum(-1) == 0, 0] = True
+     t[t.cumsum(-1) > self.requested_drafts_num] = False
+     return vocab_tokens_bool
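A toy run of the same lookup logic outside the class may help; vocab_size = 4 and requested_drafts_num = 2 are illustrative assumptions, not values from this commit:

import torch

vocab_size, requested_drafts_num = 4, 2
draft_lib = torch.tensor([[[2, 3], [2, 1], [0, 3]]])  # (b_sz=1, n_drafts=3, draft_len=2)
draft_start_tokens = draft_lib[:, :, 0]               # first token of every draft
b_sz, n_drafts = draft_start_tokens.size()
vocab_tokens = torch.arange(vocab_size).unsqueeze(0).unsqueeze(-1).expand(b_sz, vocab_size, n_drafts)
vocab_tokens_bool = draft_start_tokens.unsqueeze(1).expand(b_sz, vocab_size, n_drafts) == vocab_tokens.type_as(draft_lib)
t = vocab_tokens_bool.view(-1, n_drafts)        # the view shares storage, so the writes below land in vocab_tokens_bool
t[t.sum(-1) == 0, 0] = True                     # a token that starts no draft falls back to draft 0
t[t.cumsum(-1) > requested_drafts_num] = False  # keep at most requested_drafts_num drafts per token
print(vocab_tokens_bool[0, 2])  # drafts 0 and 1 start with token 2 -> tensor([True, True, False])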

def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
# we don't need the bos token in drafts
@@ -355,19 +370,24 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

iters = -1

- generated_tokens = torch.full((1, 1), self.bos_token_idx).type_as(src).long().repeat(b_size, 1)
- # -> (b_size, 1)
+ generated_tokens = torch.full((b_size, 1), self.bos_token_idx, device=src.device)
+ # -> (b_size, 1)

- log_probs_history = torch.full((1, self.max_len), self.log_prob_pad).type_as(src).float().repeat(b_size, 1)
- # -> (b_size, max_len)
- log_probs_history[:, 0] = 0.
+ log_probs = torch.full((b_size, 1), 0., device=src.device)
+ # -> (b_size, 1)

- possible_draft_len = self.max_len - 2
+ num_of_empty_columns = ((generated_tokens == self.pad_token_idx).sum(0) == b_size).sum().item()
+ # -> (1,)
+ postn_of_last_meaning_token = generated_tokens.shape[1] - num_of_empty_columns
+ # -> (1,)
+ possible_draft_len = self.max_len - postn_of_last_meaning_token - 1
+ # -> (1,)
beam_size = 1

logits_base = torch.full((b_size * n_drafts, draft_len + 1, self.vocab_size), 0., device=src.device)
# -> (b_s * n_drafts, draft_len + 1, vocab_size)

- while possible_draft_len > 1 and iters < self.max_len:
+ while possible_draft_len >= 1 and postn_of_last_meaning_token <= self.max_len:
iters += 1
logits_base = logits_base * 0.
# We use artificial logits to avoid computing the obvious pad predictions that follow eos
@@ -393,34 +413,35 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

pads_num = (generated_tokens == self.pad_token_idx).sum(-1)
# -> (n_candidates)
- pad_base_len = draft_len - torch.min(pads_num).item()
- if pad_base_len > 0:
-     draft_base = torch.full((n_candidates, pad_base_len), self.pad_token_idx, device=src.device)
-     generated_tokens = torch.cat((generated_tokens, draft_base), dim=-1)
+ draft_place_len = draft_len + 1 - num_of_empty_columns
+ if draft_place_len > 0:
+     draft_place = torch.full((n_candidates, draft_place_len), self.pad_token_idx, device=src.device)
+     generated_tokens = torch.cat((generated_tokens, draft_place), dim=-1)
  # -> (n_candidates, drafted_len)

logits_base = logits_base[:, :draft_len + 1, :]

self.model_calls_num += 1
- log_prob_pad_t_bool = log_probs_history == self.log_prob_pad
- bool_idx = torch.logical_and(log_prob_pad_t_bool,
-                              log_prob_pad_t_bool.cumsum(-1) <= draft_len)
- # -> (b_s * bm_sz, max_len)
- bool_idx_input = bool_idx[:, :generated_tokens.shape[1]].unsqueeze(1).repeat(1, n_drafts, 1)
+ pad_place_bool = generated_tokens == self.pad_token_idx
+ # -> (n_candidates, drafted_len)
+ draft_place_bool = torch.logical_and(pad_place_bool,
+                                      pad_place_bool.cumsum(-1) <= draft_len)
+ # -> (n_candidates, drafted_len)
+ draft_place_bool_idx_input = draft_place_bool.unsqueeze(1).repeat(1, n_drafts, 1)
  # -> (b_s * bm_sz, n_drafts, drafted_len)
generated_tokens_input = generated_tokens.unsqueeze(1).repeat(1, n_drafts, 1)
# -> (b_s * bm_sz, n_drafts, drafted_len)

- generated_tokens_input[bool_idx_input] = draft_tokens.reshape(-1)
- bool_idx_input = bool_idx_input.flatten(end_dim=1)
+ generated_tokens_input[draft_place_bool_idx_input] = draft_tokens.reshape(-1)
+ draft_place_bool_idx_input = draft_place_bool_idx_input.flatten(end_dim=1)
# -> (b_s * bm_sz * n_drafts, drafted_len)
generated_tokens_input = generated_tokens_input.flatten(end_dim=1)
# -> (b_s * bm_sz * n_drafts, drafted_len)

bool_idx_of_unfinished = bool_idx_of_unfinished.unsqueeze(-1).repeat(1, n_drafts).flatten(end_dim=1)
# -> (b_s * bm_sz * n_drafts)
- bool_idx_input = bool_idx_input[bool_idx_of_unfinished]
+ draft_place_bool_idx_input = draft_place_bool_idx_input[bool_idx_of_unfinished]
# -> (num_of_unfinished, drafted_len)
pred_logits = self.model.decode_tgt(generated_tokens_input[bool_idx_of_unfinished],
memory[bool_idx_of_unfinished],
@@ -429,7 +450,7 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

vocab_size = pred_logits.shape[-1]

- pred_logits = pred_logits[torch.logical_or(bool_idx_input, torch.roll(bool_idx_input, -1, 1))].reshape(
+ pred_logits = pred_logits[torch.logical_or(draft_place_bool_idx_input, torch.roll(draft_place_bool_idx_input, -1, 1))].reshape(
-1, draft_len + 1, vocab_size)
# -> (num_of_unfinished, draft_len + 1, vocab_size)

@@ -441,7 +462,7 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
# approved tokens is the best draft for the given candidate. #########################################

# All unapproved tokens in masked_probs have zero probability
- # We use nucleus=0.9975 and max_num_of_unmasked_positions=5 to avoid sampling of low probable sequences
+ # We use nucleus=0.9975 and max_num_of_unmasked_positions=n_best to avoid sampling low-probability sequences
# and reduce calculation
masked_probs = mask_with_num_logits_according_nucleus(pred_logits, nucleus=0.9975,
max_num_of_unmasked_positions=self.n_best,
@@ -473,10 +494,9 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:

# Sample all possible lines within the chosen drafts:
# new_candidates have the initial tokens and the new ones

- new_candidates, new_log_probs_history, num_of_new_seqs_for_each_in_batch, accepted_tokens_num = \
-     self.sample(generated_tokens, log_probs_history, pred_logits,
-                 chosen_drafts, b_size, bool_idx, n_accepted.squeeze(-1))
+ new_candidates, new_log_probs, num_of_new_seqs_for_each_in_batch, accepted_tokens_num = \
+     self.sample(generated_tokens, log_probs, pred_logits,
+                 chosen_drafts, b_size, draft_place_bool, n_accepted.squeeze(-1))

###########################################################################################################
max_num_of_new_seqs = torch.max(num_of_new_seqs_for_each_in_batch).item()
@@ -501,26 +521,23 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
-1) # -> (b_size * max_num_of_new_seqs)
new_candidates = new_candidates[inds]
# -> (b_size * max_num_of_new_seqs, drafted_len + 1)
- new_log_probs_history = new_log_probs_history[inds]
- # -> (b_size * max_num_of_new_seqs, max_len)
+ new_log_probs = new_log_probs[inds]
+ # -> (b_size * max_num_of_new_seqs, 1)
accepted_tokens_num = accepted_tokens_num[inds]
# -> (b_size * max_num_of_new_seqs)
new_candidates[mask_for_fake_seqs, 1] = self.eos_token_idx # fake sequences
- new_log_probs_history[mask_for_fake_seqs, 1] = -float("inf")  # fake probabilities
+ new_log_probs[mask_for_fake_seqs, 0] = -float("inf")  # fake probabilities
accepted_tokens_num[mask_for_fake_seqs] = self.acceptance_rate_pad_for_fake_seqs # fake
#############################################################################################

- new_log_probs = torch.min(new_log_probs_history, dim=1).values
- # -> (b_size * max_num_of_new_seqs)
  new_log_probs = new_log_probs.reshape(b_size, max_num_of_new_seqs)
  # -> (b_size, max_num_of_new_seqs)
- v, top_inds = new_log_probs.topk(k=self.n_best, axis=-1, sorted=True)
+ new_log_probs, top_inds = new_log_probs.topk(k=self.n_best, axis=-1, sorted=True)
# -> (b_size, beam_size)

new_candidates = new_candidates.reshape(b_size, max_num_of_new_seqs, -1)
- # -> (b_size, max_num_of_new_seqs, drafted_len + 1)
- new_log_probs_history = new_log_probs_history.reshape(b_size, max_num_of_new_seqs, -1)
- # -> (b_size, max_num_of_new_seqs, max_len)
+ # -> (b_size, max_num_of_new_seqs, drafted_len)

accepted_tokens_num = accepted_tokens_num.reshape(b_size, max_num_of_new_seqs)
# -> (b_size, max_num_of_new_seqs)

@@ -534,25 +551,29 @@ def generate(self, src: 'torch.LongTensor') -> list['torch.LongTensor']:
self.accepted_tokens_num += curr_accepted_tokens_num
self.produced_non_pad_tokens += curr_accepted_tokens_num + accepted_tokens_num.size(0)

- top_inds = top_inds.unsqueeze(-1).repeat(1, 1, new_log_probs_history.shape[-1])
- # -> (b_size, beam_size, max_len)
- new_log_probs_history = torch.gather(new_log_probs_history, 1, top_inds)
- # -> (b_size, beam_size, max_len)
- new_candidates = torch.gather(new_candidates, 1, top_inds[:, :, :new_candidates.shape[-1]])
- # -> (b_size, beam_size, drafted_len + 1)
+ top_inds = top_inds.unsqueeze(-1).repeat(1, 1, new_candidates.shape[-1])
+ # -> (b_size, beam_size, drafted_len)
+ new_candidates = torch.gather(new_candidates, 1, top_inds)
+ # -> (b_size, beam_size, drafted_len)

if (new_candidates[not_fake_bool] == self.eos_token_idx).sum(-1).bool().sum() == b_size * self.n_best:
break

  generated_tokens = new_candidates.reshape(b_size * self.n_best, -1)
- # -> (b_size * beam_size, drafted_len + 1)
- new_log_probs_history = new_log_probs_history.reshape(b_size * self.n_best, -1)
- # -> (b_size * beam_size, max_len)
+ # -> (b_size * beam_size, drafted_len)
+ log_probs = new_log_probs.reshape(b_size * self.n_best, 1)
+ # -> (b_size * beam_size, 1)
  not_fake_bool = not_fake_bool.reshape(b_size * self.n_best)
  # -> (b_size * beam_size)
- log_probs_history = new_log_probs_history

- possible_draft_len = torch.min((new_log_probs_history[not_fake_bool] == self.log_prob_pad).sum(-1)).item() - 1
+ num_of_empty_columns = torch.min((generated_tokens[not_fake_bool] == self.pad_token_idx).sum(-1)).item()
+ # -> (1,)
+ postn_of_last_meaning_token = generated_tokens[not_fake_bool].shape[1] - num_of_empty_columns
+ # -> (1,)
+ possible_draft_len = self.max_len - postn_of_last_meaning_token - 1
+ # -> (1,)

return new_candidates
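Each batch element can yield a different number of new sequences, so the pools are padded up to max_num_of_new_seqs with fake entries whose log probability is -inf; a fake can then enter the top-k only when a pool holds fewer than k real sequences, and not_fake_bool filters it out afterwards. A self-contained sketch of the trick, with toy scores:

import torch

pools = [torch.tensor([-0.1, -0.5, -2.0]), torch.tensor([-0.3])]  # 3 vs 1 candidates
max_num_of_new_seqs = max(p.numel() for p in pools)
padded = torch.full((len(pools), max_num_of_new_seqs), -float("inf"))
for i, p in enumerate(pools):
    padded[i, :p.numel()] = p            # real scores; the rest stay -inf (fake rows)
values, top_inds = padded.topk(k=2, dim=-1, sorted=True)
not_fake_bool = values != -float("inf")  # False only where a pool had fewer than k real sequences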

def calculate_n_accepted_in_drafts(self, draft_tokens, masked_probs):
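The body of calculate_n_accepted_in_drafts is folded in this view. Judging only from its call site above (draft tokens are checked against the nucleus-masked probabilities, in which unapproved tokens have zero probability), a plausible sketch is the following; it is an assumption, not the committed implementation:

import torch

def calculate_n_accepted_in_drafts_sketch(draft_tokens, masked_probs):
    # draft_tokens: (n, n_drafts, draft_len) long; masked_probs: (n, n_drafts, draft_len, vocab_size)
    token_probs = torch.gather(masked_probs, dim=-1, index=draft_tokens.unsqueeze(-1)).squeeze(-1)
    # -> (n, n_drafts, draft_len); zero means the draft token was not approved at that position
    approved = (token_probs > 0).long()
    # cumprod zeroes everything after the first rejection, so the sum counts
    # only the leading run of accepted tokens
    return approved.cumprod(dim=-1).sum(dim=-1)  # -> (n, n_drafts)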
