Skip to content

Commit 74a3d94

Browse files
committed
Convert .csv to .parquet in nsys-jax to avoid compressing a large .csv with Python's lzma.
1 parent 9c56b12 commit 74a3d94

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

.github/container/nsys_jax/nsys_jax/data_loaders.py

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -593,15 +593,25 @@ def _drop_non_tsl(compile_df: pd.DataFrame) -> pd.DataFrame:
593593

594594

595595
def _read_nvtx_pushpop_trace_file(file: pathlib.Path) -> pd.DataFrame:
596-
def keep_column(name):
597-
return name not in {"PID", "Lvl", "NameTree"}
598-
599-
return pd.read_csv(
600-
lzma.open(file, "rt", newline=""),
601-
dtype={"RangeId": np.int32},
602-
index_col="RangeId",
603-
usecols=keep_column,
604-
)
596+
# `file` follows one of two patterns, depending on whether we are loading the
597+
# results from a single profile or from multiple merged profiles:
598+
# - nsys-jax: /path/to/report_nvtx_pushpop_trace.parquet
599+
# - nsys-jax-combine: /path/to/report_nvtx_pushpop_trace.parquet/rank5
600+
new_name = "report_nvtx_pushpop_trace.parquet"
601+
if file.name == new_name or file.parent.name == new_name:
602+
# New mode; the .csv to .parquet conversion is done in nsys-jax
603+
return pd.read_parquet(file)
604+
else:
605+
606+
def keep_column(name):
607+
return name not in {"PID", "Lvl", "NameTree"}
608+
609+
return pd.read_csv(
610+
lzma.open(file, "rt", newline=""),
611+
dtype={"RangeId": np.int32},
612+
index_col="RangeId",
613+
usecols=keep_column,
614+
)
605615

606616

607617
def _load_nvtx_pushpop_trace_single(name: pathlib.Path) -> pd.DataFrame:
@@ -640,7 +650,9 @@ def remove_program_id_and_name(row):
640650

641651

642652
def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame:
643-
path = prefix / "report_nvtx_pushpop_trace.csv.xz"
653+
new_path = prefix / "report_nvtx_pushpop_trace.parquet"
654+
legacy_path = prefix / "report_nvtx_pushpop_trace.csv.xz"
655+
path = new_path if new_path.exists() else legacy_path
644656
if path.is_dir():
645657
# We're looking at the output of nsys-jax-combine
646658
filenames = sorted(path.iterdir())

.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from contextlib import contextmanager
44
from glob import glob, iglob
55
import lzma
6+
import numpy as np
67
import os
78
import os.path as osp
89
import pandas as pd # type: ignore
@@ -369,7 +370,9 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
369370
if osp.isdir(full_path) or not osp.exists(full_path):
370371
continue
371372
output_queue.put((ofile, full_path, COMPRESS_NONE))
372-
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
373+
print(
374+
f"{archive_name}: recipe post-processing finished in {time.time()-start:.2f}s"
375+
)
373376

374377
def compress_and_archive(prefix, file, output_queue):
375378
"""
@@ -401,9 +404,25 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
401404
],
402405
check=True,
403406
)
404-
for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
405-
compress_and_archive(tmp_dir, ofile, output_queue)
406-
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
407+
output_path = osp.join(tmp_dir, f"report_{report}.csv")
408+
409+
# TODO: avoid the .csv indirection
410+
def keep_column(name):
411+
return name not in {"PID", "Lvl", "NameTree"}
412+
413+
df = pd.read_csv(
414+
output_path,
415+
dtype={"RangeId": np.int32},
416+
index_col="RangeId",
417+
usecols=keep_column,
418+
)
419+
parquet_name = f"report_{report}.parquet"
420+
parquet_path = osp.join(tmp_dir, parquet_name)
421+
df.to_parquet(parquet_path)
422+
output_queue.put((parquet_name, parquet_path, COMPRESS_NONE))
423+
print(
424+
f"{archive_name}: stats post-processing finished in {time.time()-start:.2f}s"
425+
)
407426

408427
def save_device_stream_thread_names(tmp_dir, report, output_queue):
409428
"""

0 commit comments

Comments
 (0)