Skip to content

Commit 0299d5a

Browse files
committed
Skip multiprocessing.Pool if no parallelism is available; makes Python profiles clearer
1 parent f22ba1d commit 0299d5a

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

.github/container/nsys_jax/nsys_jax/data_loaders.py

Lines changed: 33 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def _load_nvtx_gpu_proj_trace_single(
129129
file: pathlib.Path,
130130
meta_file: pathlib.Path,
131131
frames: set[str],
132-
):
132+
) -> dict[str, pd.DataFrame]:
133133
# Load the thread metadata used to map module/thunk executions to global device IDs
134134
meta_df = _load_parquet_file(meta_file)
135135
# Match XLA's launcher thread name. These threads launch work if >1 GPU is being
@@ -440,22 +440,28 @@ def _load_nvtx_gpu_proj_trace(
440440
filenames = [path]
441441
meta_filenames = [meta_path]
442442

443-
tmp = defaultdict(list)
444-
with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
445-
for single_trace in pool.starmap(
446-
_load_nvtx_gpu_proj_trace_single,
447-
zip(
448-
itertools.repeat(prefix),
449-
filenames,
450-
meta_filenames,
451-
itertools.repeat(frames),
452-
),
453-
):
454-
for k, v in single_trace.items():
455-
tmp[k].append(v)
456-
output = {}
457-
for k, v in tmp.items():
458-
output[k] = pd.concat(v, verify_integrity=True).sort_index()
443+
if len(filenames) > 1:
444+
tmp = defaultdict(list)
445+
with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
446+
for single_trace in pool.starmap(
447+
_load_nvtx_gpu_proj_trace_single,
448+
zip(
449+
itertools.repeat(prefix),
450+
filenames,
451+
meta_filenames,
452+
itertools.repeat(frames),
453+
),
454+
):
455+
for k, v in single_trace.items():
456+
tmp[k].append(v)
457+
output = {}
458+
for k, v in tmp.items():
459+
output[k] = pd.concat(v, verify_integrity=True).sort_index()
460+
else:
461+
output = _load_nvtx_gpu_proj_trace_single(
462+
prefix, filenames[0], meta_filenames[0], frames
463+
)
464+
output = {k: v.sort_index() for k, v in output.items()}
459465
return output
460466

461467

@@ -644,12 +650,16 @@ def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataF
644650
filenames = [path]
645651
keys = [prefix.name]
646652

647-
with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
648-
return pd.concat(
649-
pool.map(_load_nvtx_pushpop_trace_single, filenames),
650-
keys=keys,
651-
names=["ProfileName", "RangeId"],
652-
)
653+
if len(filenames) > 1:
654+
with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
655+
chunks = pool.map(_load_nvtx_pushpop_trace_single, filenames)
656+
else:
657+
chunks = [_load_nvtx_pushpop_trace_single(filenames[0])]
658+
return pd.concat(
659+
chunks,
660+
keys=keys,
661+
names=["ProfileName", "RangeId"],
662+
)
653663

654664

655665
def load_profiler_data(

0 commit comments

Comments
 (0)