-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Open
Description
inference_bistream():
...
next_fill_index = -1
...
while True:
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
top_ids = self.speech_token_size + 2
next_fill_index += (self.mix_ratio[1] + 1)
else:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
if top_ids == self.speech_token_size + 2:
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size + 2:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
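这段代码在第一次遇到 fill_token 时没有施加 N:M 的约束:next_fill_index 初始值为 -1,因此第一个 fill_token 出现的位置完全由采样决定,只有在它出现之后才会按 mix_ratio[1] + 1 的间隔强制插入 fill_token。请问这里是否需要修改?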
Metadata
Metadata
Assignees
Labels
No labels