Skip to content

Commit 1a77c9d

Browse files
committed
Fix confidence for TDT with duration confidence
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
1 parent d6e3db7 commit 1a77c9d

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

nemo/collections/asr/parts/submodules/rnnt_decoding.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,12 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes
843843
hyp.non_blank_step_confidence_precomputed is not None for hyp in hypotheses_list
844844
):
845845
for hyp in hypotheses_list:
846-
hyp.token_confidence = hyp.non_blank_step_confidence_precomputed
846+
if self.tdt_include_duration_confidence:
847+
hyp.token_confidence = [
848+
self._aggregate_confidence(c) for c in hyp.non_blank_step_confidence_precomputed
849+
]
850+
else:
851+
hyp.token_confidence = hyp.non_blank_step_confidence_precomputed
847852
else:
848853
maybe_pre_aggregate = (
849854
(lambda x: self._aggregate_confidence(x))

0 commit comments

Comments
 (0)