Skip to content

Commit 21652bb

Browse files
committed
Fix tests
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
1 parent 68316d0 commit 21652bb

1 file changed

Lines changed: 1 addition & 19 deletions

File tree

tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,6 @@
2828
CUDA_GRAPH_TRANSCRIBE_NUM_WORKERS = 0
2929

3030

31-
def test_full_graph_compile_fallback_handles_accelerator_error():
32-
from nemo.collections.asr.parts.submodules.transducer_decoding.rnnt_label_looping import (
33-
GreedyBatchedRNNTLabelLoopingComputer,
34-
)
35-
36-
accelerator_error = getattr(torch, "AcceleratorError", RuntimeError)
37-
computer = GreedyBatchedRNNTLabelLoopingComputer.__new__(GreedyBatchedRNNTLabelLoopingComputer)
38-
computer.cuda_graphs_mode = computer.CudaGraphsMode.FULL_GRAPH
39-
computer.cuda_graphs_allow_fallback = True
40-
partial_graph_compile_calls = []
41-
computer._partial_graphs_compile = lambda: partial_graph_compile_calls.append(True)
42-
43-
computer._fallback_to_no_while_loop_cuda_graphs(accelerator_error("CUDA error: invalid argument"))
44-
45-
assert computer.cuda_graphs_mode == computer.CudaGraphsMode.NO_WHILE_LOOPS
46-
assert partial_graph_compile_calls == [True]
47-
48-
4931
def test_forced_full_graph_compile_does_not_fallback():
5032
from nemo.collections.asr.parts.submodules.transducer_decoding.rnnt_label_looping import (
5133
GreedyBatchedRNNTLabelLoopingComputer,
@@ -56,7 +38,7 @@ def test_forced_full_graph_compile_does_not_fallback():
5638
computer.cuda_graphs_allow_fallback = False
5739

5840
with pytest.raises(RuntimeError, match="Full CUDA graph decoding failed"):
59-
computer._fallback_to_no_while_loop_cuda_graphs(accelerator_error("CUDA error: invalid argument"))
41+
computer._raise_or_warn_no_while_loop_cuda_graphs(accelerator_error("CUDA error: invalid argument"))
6042

6143

6244
@pytest.mark.with_downloads

0 commit comments

Comments
 (0)