2828CUDA_GRAPH_TRANSCRIBE_NUM_WORKERS = 0
2929
3030
31- def test_full_graph_compile_fallback_handles_accelerator_error ():
32- from nemo .collections .asr .parts .submodules .transducer_decoding .rnnt_label_looping import (
33- GreedyBatchedRNNTLabelLoopingComputer ,
34- )
35-
36- accelerator_error = getattr (torch , "AcceleratorError" , RuntimeError )
37- computer = GreedyBatchedRNNTLabelLoopingComputer .__new__ (GreedyBatchedRNNTLabelLoopingComputer )
38- computer .cuda_graphs_mode = computer .CudaGraphsMode .FULL_GRAPH
39- computer .cuda_graphs_allow_fallback = True
40- partial_graph_compile_calls = []
41- computer ._partial_graphs_compile = lambda : partial_graph_compile_calls .append (True )
42-
43- computer ._fallback_to_no_while_loop_cuda_graphs (accelerator_error ("CUDA error: invalid argument" ))
44-
45- assert computer .cuda_graphs_mode == computer .CudaGraphsMode .NO_WHILE_LOOPS
46- assert partial_graph_compile_calls == [True ]
47-
48-
4931def test_forced_full_graph_compile_does_not_fallback ():
5032 from nemo .collections .asr .parts .submodules .transducer_decoding .rnnt_label_looping import (
5133 GreedyBatchedRNNTLabelLoopingComputer ,
@@ -56,7 +38,7 @@ def test_forced_full_graph_compile_does_not_fallback():
5638 computer .cuda_graphs_allow_fallback = False
5739
5840 with pytest .raises (RuntimeError , match = "Full CUDA graph decoding failed" ):
59- computer ._fallback_to_no_while_loop_cuda_graphs (accelerator_error ("CUDA error: invalid argument" ))
41+ computer ._raise_or_warn_no_while_loop_cuda_graphs (accelerator_error ("CUDA error: invalid argument" ))
6042
6143
6244@pytest .mark .with_downloads
0 commit comments