Skip to content

Commit d8056e0

Browse files
committed
Convert .csv to .parquet in nsys-jax to avoid compressing a large .csv with Python's lzma.
1 parent 9c56b12 commit d8056e0

File tree

3 files changed

+53
-15
lines changed

3 files changed

+53
-15
lines changed

.github/container/nsys_jax/nsys_jax/data_loaders.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -593,15 +593,25 @@ def _drop_non_tsl(compile_df: pd.DataFrame) -> pd.DataFrame:
593593

594594

595595
def _read_nvtx_pushpop_trace_file(file: pathlib.Path) -> pd.DataFrame:
596-
def keep_column(name):
597-
return name not in {"PID", "Lvl", "NameTree"}
598-
599-
return pd.read_csv(
600-
lzma.open(file, "rt", newline=""),
601-
dtype={"RangeId": np.int32},
602-
index_col="RangeId",
603-
usecols=keep_column,
604-
)
596+
# `file` follows one of two patterns, depending on whether we are loading the
597+
# results from a single profile or from multiple merged profiles:
598+
# - nsys-jax: /path/to/report_nvtx_pushpop_trace.parquet
599+
# - nsys-jax-combine: /path/to/report_nvtx_pushpop_trace.parquet/rank5
600+
new_name = "report_nvtx_pushpop_trace.parquet"
601+
if file.name == new_name or file.parent.name == new_name:
602+
# New mode; the .csv to .parquet conversion is done in nsys-jax
603+
return pd.read_parquet(file)
604+
else:
605+
606+
def keep_column(name):
607+
return name not in {"PID", "Lvl", "NameTree"}
608+
609+
return pd.read_csv(
610+
lzma.open(file, "rt", newline=""),
611+
dtype={"RangeId": np.int32},
612+
index_col="RangeId",
613+
usecols=keep_column,
614+
)
605615

606616

607617
def _load_nvtx_pushpop_trace_single(name: pathlib.Path) -> pd.DataFrame:
@@ -640,7 +650,9 @@ def remove_program_id_and_name(row):
640650

641651

642652
def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame:
643-
path = prefix / "report_nvtx_pushpop_trace.csv.xz"
653+
new_path = prefix / "report_nvtx_pushpop_trace.parquet"
654+
legacy_path = prefix / "report_nvtx_pushpop_trace.csv.xz"
655+
path = new_path if new_path.exists() else legacy_path
644656
if path.is_dir():
645657
# We're looking at the output of nsys-jax-combine
646658
filenames = sorted(path.iterdir())

.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from contextlib import contextmanager
44
from glob import glob, iglob
55
import lzma
6+
import numpy as np
67
import os
78
import os.path as osp
89
import pandas as pd # type: ignore
@@ -369,7 +370,9 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
369370
if osp.isdir(full_path) or not osp.exists(full_path):
370371
continue
371372
output_queue.put((ofile, full_path, COMPRESS_NONE))
372-
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
373+
print(
374+
f"{archive_name}: recipe post-processing finished in {time.time()-start:.2f}s"
375+
)
373376

374377
def compress_and_archive(prefix, file, output_queue):
375378
"""
@@ -401,9 +404,29 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
401404
],
402405
check=True,
403406
)
404-
for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
405-
compress_and_archive(tmp_dir, ofile, output_queue)
406-
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
407+
output_path = osp.join(tmp_dir, f"report_{report}.csv")
408+
409+
# TODO: avoid the .csv indirection
410+
def keep_column(name):
411+
return name not in {"PID", "Lvl", "NameTree"}
412+
413+
try:
414+
df = pd.read_csv(
415+
output_path,
416+
dtype={"RangeId": np.int32},
417+
index_col="RangeId",
418+
usecols=keep_column,
419+
)
420+
parquet_name = f"report_{report}.parquet"
421+
parquet_path = osp.join(tmp_dir, parquet_name)
422+
df.to_parquet(parquet_path)
423+
output_queue.put((parquet_name, parquet_path, COMPRESS_NONE))
424+
except pd.errors.EmptyDataError:
425+
# If there's no data, don't write a file to the output at all
426+
pass
427+
print(
428+
f"{archive_name}: stats post-processing finished in {time.time()-start:.2f}s"
429+
)
407430

408431
def save_device_stream_thread_names(tmp_dir, report, output_queue):
409432
"""

.github/workflows/_ci.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,9 @@ jobs:
322322
set -o pipefail
323323
num_tests=0
324324
num_failures=0
325-
# Run the pytest-driven tests
325+
# Run the pytest-driven tests; failure is explicitly handled below so set +e to
326+
# avoid an early abort here.
327+
set +e
326328
docker run -i --shm-size=1g --gpus all \
327329
-v $PWD:/opt/output \
328330
${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
@@ -333,6 +335,7 @@ jobs:
333335
test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())')
334336
pytest --report-log=/opt/output/pytest-report.jsonl "${test_path}"
335337
EOF
338+
set -e
336339
GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU')
337340
for mode in 1-process 2-process process-per-gpu; do
338341
DOCKER="docker run --shm-size=1g --gpus all --env XLA_FLAGS=--xla_gpu_enable_command_buffer= --env XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 -v ${PWD}:/opt/output ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}"

0 commit comments

Comments
 (0)