Skip to content

Commit 40e1dd9

Browse files
committed
fix for true bfloat16 inference
Signed-off-by: naymaraq <dkaramyan@nvidia.com>
1 parent 160a742 commit 40e1dd9

5 files changed

Lines changed: 13 additions & 3 deletions

File tree

examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ asr:
66
device: cuda # Device for inference: 'cuda' or 'cpu'
77
device_id: 0 # GPU device ID
88
compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'
9-
use_amp: true # Enable Automatic Mixed Precision
9+
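use_amp: false # Enable Automatic Mixed Precision; if false, the model and input features are cast to compute_dtype (if true, they are kept in float32)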
1010

1111

1212
# ==========================================

examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ asr:
66
device: cuda # Device for inference: 'cuda' or 'cpu'
77
device_id: 0 # GPU device ID
88
compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'
9-
use_amp: true # Enable Automatic Mixed Precision
9+
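use_amp: false # Enable Automatic Mixed Precision; if false, the model and input features are cast to compute_dtype (if true, they are kept in float32)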
1010
decoding:
1111
strategy: "greedy_batch"
1212
preserve_alignments: false

nemo/collections/asr/inference/model_wrappers/cache_aware_asr_inference_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def get_initial_cache_state(self, batch_size: int) -> tuple[Tensor, Tensor, Tens
5555
Returns:
5656
(tuple[Tensor, Tensor, Tensor]) the initial cache state of the encoder.
5757
"""
58-
return self.asr_model.encoder.get_initial_cache_state(batch_size=batch_size)
58+
return self.asr_model.encoder.get_initial_cache_state(batch_size=batch_size, dtype=self.cast_dtype)
5959

6060
def get_drop_extra_pre_encoded(self) -> int:
6161
"""

nemo/collections/asr/inference/model_wrappers/cache_aware_ctc_inference_wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def __post_init__(self) -> None:
5757

5858
self.drop_extra_pre_encoded = self.get_drop_extra_pre_encoded()
5959

60+
self.cast_dtype = torch.float32 if self.use_amp else self.compute_dtype
61+
self.asr_model.to(self.cast_dtype)
62+
6063
def get_blank_id(self) -> int:
6164
"""
6265
Returns id of the blank token.
@@ -180,6 +183,8 @@ def stream_step(
180183
if processed_signal_length.device != self.device:
181184
processed_signal_length = processed_signal_length.to(self.device)
182185

186+
processed_signal = processed_signal.to(self.cast_dtype)
187+
183188
if context is None:
184189
# create a dummy context
185190
context = CacheAwareContext()

nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def __post_init__(self) -> None:
5656

5757
self.drop_extra_pre_encoded = self.get_drop_extra_pre_encoded()
5858

59+
self.cast_dtype = torch.float32 if self.use_amp else self.compute_dtype
60+
self.asr_model.to(self.cast_dtype)
61+
5962
def get_blank_id(self) -> int:
6063
"""
6164
Returns id of the blank token.
@@ -170,6 +173,8 @@ def stream_step(
170173
if processed_signal_length.device != self.device:
171174
processed_signal_length = processed_signal_length.to(self.device)
172175

176+
processed_signal = processed_signal.to(self.cast_dtype)
177+
173178
if context is None:
174179
# create a dummy context
175180
context = CacheAwareContext()

0 commit comments

Comments
 (0)