|
# Minimum-WER (mWER / MBR-style) expected-risk computation per utterance.
# For each utterance b in the batch: normalize the n-best hypothesis scores
# into a probability distribution, compute each hypothesis's WER against the
# reference, and accumulate the expected WER and the per-utterance gradient
# scalar. Finally, average the expected WER over the batch.
# NOTE(review): this fragment relies on names from the enclosing scope
# (bs, nbest_hyps_id, nbest_hyps_id_batch, scores, eouts, scaling_factor,
# ys_ref, nbest, idx2token, compute_wer, np2tensor, grad_list, exp_wer)
# — confirm against the full function in las.py.
for b in range(bs):
    # n-best hypothesis token-id sequences for utterance b as int64 arrays,
    # also collected into the flat batch-level list.
    nbest_hyps_id_b = [np.fromiter(y, dtype=np.int64) for y in nbest_hyps_id[b]]
    nbest_hyps_id_batch += nbest_hyps_id_b

    # Hypothesis scores -> normalized probabilities over the n-best list.
    scores_b = np2tensor(np.array(scores[b], dtype=np.float32), eouts.device)
    probs_b_norm = torch.softmax(scaling_factor * scores_b, dim=-1)  # `[nbest]`

    # WER of each hypothesis vs. the reference transcript; compute_wer
    # returns a percentage, so divide by 100 to scale into [0, 1].
    wers_b = np2tensor(np.array([
        compute_wer(ref=idx2token(ys_ref[b]).split(' '),
                    hyp=idx2token(nbest_hyps_id_b[n]).split(' '))[0] / 100
        for n in range(nbest)], dtype=np.float32), eouts.device)

    # Expected WER under the n-best distribution (the mWER risk term).
    exp_wer_b = (probs_b_norm * wers_b).sum()
    # Per-utterance gradient scalar: sum_n p_n * (WER_n - E[WER]).
    grad_list += [(probs_b_norm * (wers_b - exp_wer_b)).sum()]
    exp_wer += exp_wer_b

# Average the expected WER over the batch.
exp_wer /= bs
|
|
I don't know much about MBR, but based on these lines, this looks like a
mWER (minimum word error rate) loss and its gradient to me.
neural_sp/neural_sp/models/seq2seq/decoders/las.py
Lines 535 to 548 in 2b10b9c
I don't know much about MBR, but based on these lines, this looks like a mWER (minimum word error rate) loss and its gradient to me.