@@ -129,7 +129,7 @@ def _load_nvtx_gpu_proj_trace_single(
     file: pathlib.Path,
     meta_file: pathlib.Path,
     frames: set[str],
-):
+) -> dict[str, pd.DataFrame]:
     # Load the thread metadata used to map module/thunk executions to global device IDs
     meta_df = _load_parquet_file(meta_file)
     # Match XLA's launcher thread name. These threads launch work if >1 GPU is being
@@ -440,22 +440,28 @@ def _load_nvtx_gpu_proj_trace(
         filenames = [path]
         meta_filenames = [meta_path]
 
-    tmp = defaultdict(list)
-    with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
-        for single_trace in pool.starmap(
-            _load_nvtx_gpu_proj_trace_single,
-            zip(
-                itertools.repeat(prefix),
-                filenames,
-                meta_filenames,
-                itertools.repeat(frames),
-            ),
-        ):
-            for k, v in single_trace.items():
-                tmp[k].append(v)
-    output = {}
-    for k, v in tmp.items():
-        output[k] = pd.concat(v, verify_integrity=True).sort_index()
+    if len(filenames) > 1:
+        tmp = defaultdict(list)
+        with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
+            for single_trace in pool.starmap(
+                _load_nvtx_gpu_proj_trace_single,
+                zip(
+                    itertools.repeat(prefix),
+                    filenames,
+                    meta_filenames,
+                    itertools.repeat(frames),
+                ),
+            ):
+                for k, v in single_trace.items():
+                    tmp[k].append(v)
+        output = {}
+        for k, v in tmp.items():
+            output[k] = pd.concat(v, verify_integrity=True).sort_index()
+    else:
+        output = _load_nvtx_gpu_proj_trace_single(
+            prefix, filenames[0], meta_filenames[0], frames
+        )
+        output = {k: v.sort_index() for k, v in output.items()}
     return output
 
 
@@ -644,12 +650,16 @@ def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame:
         filenames = [path]
         keys = [prefix.name]
 
-    with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
-        return pd.concat(
-            pool.map(_load_nvtx_pushpop_trace_single, filenames),
-            keys=keys,
-            names=["ProfileName", "RangeId"],
-        )
+    if len(filenames) > 1:
+        with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
+            chunks = pool.map(_load_nvtx_pushpop_trace_single, filenames)
+    else:
+        chunks = [_load_nvtx_pushpop_trace_single(filenames[0])]
+    return pd.concat(
+        chunks,
+        keys=keys,
+        names=["ProfileName", "RangeId"],
+    )
 
 
 def load_profiler_data(