1010import re
1111
1212from .analysis import calculate_collective_metrics
13- from .protobuf import xla_module_metadata
13+ from .protobuf import _hlo_cache , _remap_program_id , xla_module_metadata
14+ from .protobuf_utils import ensure_compiled_protos_are_importable
1415from .utils import default_data_prefix , make_child_mask , ProfilerData
1516
1617pd .options .mode .copy_on_write = True
2021def _is_communication (
2122 program_id : int , prefix : pathlib .Path , instruction_name : str
2223) -> bool :
23- if program_id == - 1 :
24+ if program_id == "unknown" :
2425 # Assume this is an autotuning execution.
2526 return False
2627 try :
@@ -143,10 +144,11 @@ def _sort_thunk_frame(df: pd.DataFrame) -> pd.DataFrame:
143144
144145def _load_nvtx_gpu_proj_trace_single (
145146 prefix : pathlib .Path ,
147+ replica : str | None ,
146148 file : pathlib .Path ,
147149 meta_file : pathlib .Path ,
148150 frames : set [str ],
149- ) -> dict [str , pd .DataFrame ]:
151+ ) -> tuple [ dict [str , pd .DataFrame ], dict [ tuple [ pathlib . Path , str ], set [ pathlib . Path ]] ]:
150152 # Load the thread metadata used to map module/thunk executions to global device IDs
151153 meta_df = _load_parquet_file (meta_file )
152154 # Match XLA's launcher thread name. These threads launch work if >1 GPU is being
@@ -299,22 +301,25 @@ def _load_nvtx_gpu_proj_trace_single(
299301 # The classic example where it is not set is during autotuning, where ops
300302 # to be autotuned are extracted into new HloModule instances, which are not
301303 # propagated to the GpuExecutable that emits the XlaModule annotation.
302- # Those are probably not interesting, so setting the ProgramId to -1 in
303- # such cases is acceptable.
304+ # Those are probably not interesting, so setting the ProgramId to
305+ # "unknown" in such cases is acceptable.
304306 module_re = (
305307 "^"
306308 + tsl_prefix
307309 + r"XlaModule:#(?:prefix=(.*?),|)hlo_module=([a-z0-9._-]+)(?:,program_id=(\d+)|)#$"
308310 )
309- mod_program_ids = (
310- df .loc [mod_ids , "Name" ]
311- .str .replace (
312- pat = module_re ,
313- repl = lambda m : "-1" if m .group (3 ) is None else m .group (3 ),
314- n = 1 ,
315- regex = True ,
316- )
317- .astype (np .int32 )
311+ # Apply a transformation to the program IDs to handle the case where profiles are
312+ # being combined from multiple processes, but the distributed application was not
313+ # strictly SPMD - so the IDs collected from different processes do not match for
314+ # "the same" program. The multi_process_program.py test in the nsys_jax test suite
315+ # explicitly constructs this scenario.
316+ mod_program_ids = df .loc [mod_ids , "Name" ].str .replace (
317+ pat = module_re ,
318+ repl = lambda m : _remap_program_id (
319+ old_id_str = m .group (3 ), name = m .group (2 ), prefix = prefix , replica = replica
320+ ),
321+ n = 1 ,
322+ regex = True ,
318323 )
319324 # Update each module and thunk row with the program ID it corresponds to
320325 df .loc [mod_ids , "ProgramId" ] = mod_program_ids
@@ -385,7 +390,7 @@ def clean_data_frame(d):
385390 "RangeStack" ,
386391 "TID" ,
387392 ]
388- ).astype ({"ProgramExecution" : np .int32 , "ProgramId" : np . int32 })
393+ ).astype ({"ProgramExecution" : np .int32 })
389394
390395 output = {}
391396 if "thunk" in frames :
@@ -427,7 +432,7 @@ def clean_data_frame(d):
427432 ["ProgramId" , "ProgramExecution" , "Device" ]
428433 )
429434
430- return output
435+ return output , _hlo_cache
431436
432437
433438def _enough_processes (work_items : int ) -> int :
@@ -440,33 +445,42 @@ def _load_nvtx_gpu_proj_trace(
440445 prefix : pathlib .Path ,
441446 frames : set [str ],
442447):
448+ # _remap_program_id needs to load protos
449+ ensure_compiled_protos_are_importable (prefix = prefix )
443450 path = prefix / "nvtx_gpu_proj_trace" / "trace.parquet"
444451 meta_path = prefix / "thread-metadata.parquet"
452+ replica_slugs : list [str | None ]
445453 if path .is_dir ():
446454 # We're looking at the output of nsys-jax-combine
447455 assert meta_path .is_dir ()
448456 filenames = sorted (path .iterdir ())
457+ replica_slugs = [fname .name for fname in filenames ]
449458 meta_filenames = sorted (meta_path .iterdir ())
450459 else :
451460 # We're looking at the output of nsys-jax
452461 assert not meta_path .is_dir ()
453462 filenames = [path ]
463+ replica_slugs = [None ]
454464 meta_filenames = [meta_path ]
455465
456466 if len (filenames ) > 1 :
457467 tmp = defaultdict (list )
458468 with multiprocessing .Pool (processes = _enough_processes (len (filenames ))) as pool :
459- for single_trace in pool .starmap (
469+ for single_trace , hlo_cache in pool .starmap (
460470 _load_nvtx_gpu_proj_trace_single ,
461471 zip (
462472 itertools .repeat (prefix ),
473+ replica_slugs ,
463474 filenames ,
464475 meta_filenames ,
465476 itertools .repeat (frames ),
466477 ),
467478 ):
468479 for k , v in single_trace .items ():
469480 tmp [k ].append (v )
481+ # Merge the caches from the pool worker processes into the main one.
482+ for k2 , v2 in hlo_cache .items ():
483+ _hlo_cache [k2 ] |= v2
470484 output = {}
471485 for k , v in tmp .items ():
472486 output [k ] = pd .concat (v , verify_integrity = True )
@@ -477,8 +491,9 @@ def _load_nvtx_gpu_proj_trace(
477491 if "thunk" in output :
478492 output ["thunk" ] = _sort_thunk_frame (output ["thunk" ])
479493 else :
480- output = _load_nvtx_gpu_proj_trace_single (
481- prefix , filenames [0 ], meta_filenames [0 ], frames
494+ # No explicit handling of the HLO cache, everything is in one process
495+ output , _ = _load_nvtx_gpu_proj_trace_single (
496+ prefix , None , filenames [0 ], meta_filenames [0 ], frames
482497 )
483498 if "module" in output :
484499 output ["module" ] = output ["module" ].sort_index ()
0 commit comments