77import zipfile
88
99
10- def nsys_jax_with_result (command ):
10+ def nsys_jax_with_result (command , * , out_dir ):
1111 """
1212 Helper to run nsys-jax with a unique output file that will be automatically
1313 cleaned up on destruction. Explicitly returns the `subprocess.CompletedProcess`
1414 instance.
1515 """
16- output = tempfile .NamedTemporaryFile (suffix = ".zip" )
16+ output = tempfile .NamedTemporaryFile (delete = False , dir = out_dir , suffix = ".zip" )
1717 result = subprocess .run (
1818 ["nsys-jax" , "--force-overwrite" , "--output" , output .name ] + command ,
1919 )
2020 return output , result
2121
2222
23- def nsys_jax (command ):
23+ def nsys_jax (command , * , out_dir ):
2424 """
2525 Helper to run nsys-jax with a unique output file that will be automatically
2626 cleaned up on destruction. Throws if running `nsys-jax` does not succeed.
2727 """
28- output , result = nsys_jax_with_result (command )
28+ output , result = nsys_jax_with_result (command , out_dir = out_dir )
2929 result .check_returncode ()
3030 return output
3131
@@ -42,11 +42,11 @@ def extract(archive):
4242 return tmpdir
4343
4444
45- def nsys_jax_archive (command ):
45+ def nsys_jax_archive (command , * , out_dir ):
4646 """
4747 Helper to run nsys-jax and automatically extract the output, yielding a directory.
4848 """
49- archive = nsys_jax (command )
49+ archive = nsys_jax (command , out_dir = out_dir )
5050 tmpdir = extract (archive )
5151 # Make sure the protobuf bindings can be imported, the generated .py will go into
5252 # a temporary directory that is not currently cleaned up. The bindings cannot be
@@ -58,13 +58,17 @@ def nsys_jax_archive(command):
5858
5959
6060def multi_process_nsys_jax (
61- num_processes : int , command : typing .Callable [[int ], list [str ]]
61+ num_processes : int ,
62+ command : typing .Callable [[int ], list [str ]],
63+ * ,
64+ out_dir ,
6265):
6366 """
6467 Helper to run a multi-process test under nsys-jax and yield several .zip
6568 """
6669 child_outputs = [
67- tempfile .NamedTemporaryFile (suffix = ".zip" ) for _ in range (num_processes )
70+ tempfile .NamedTemporaryFile (delete = False , dir = out_dir , suffix = ".zip" )
71+ for _ in range (num_processes )
6872 ]
6973 children = [
7074 subprocess .Popen (
0 commit comments