Skip to content

Commit fb5e940

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Export input/output data to directory (#168)
Summary: We need to pickle input/output data to directory for numeric checks across runs, such as CUDA version and Triton version. Pull Request resolved: #168 Test Plan: ``` python run.py --op flash_attention --only triton_tutorial_flash_v2 --num-inputs 1 --export both --export-dir $PWD/flash_attention python run.py --op flash_attention --only triton_tutorial_flash_v2 --num-inputs 1 --export both --export-dir .data/flash_attention/ --bwd --causal ``` ``` ls .data/flash_attention -rw-r--r-- 1 xzhao9 users 9438227 Feb 28 07:40 'x_(4, 48, 128, 64)-input.pkl' -rw-r--r-- 1 xzhao9 users 9438257 Feb 28 07:39 'x_(4, 48, 128, 64)-triton_tutorial_flash_v2-bwd-grad.pkl' -rw-r--r-- 1 xzhao9 users 3146151 Feb 28 07:40 'x_(4, 48, 128, 64)-triton_tutorial_flash_v2-fwd-output.pkl' ``` Fixes #84 Reviewed By: adamomainz Differential Revision: D70394911 Pulled By: xuzhao9 fbshipit-source-id: 468980b0b50684a278f309afc8af0d0a55f8008f
1 parent b093486 commit fb5e940

File tree

4 files changed

+73
-0
lines changed

4 files changed

+73
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .export import export_data
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
Serialize pickled tensors to directory.
3+
"""
4+
5+
import pickle
6+
from pathlib import Path
7+
8+
from typing import Any, Callable
9+
10+
from tritonbench.utils.input import input_cast
11+
12+
13+
def get_input_gradients(inputs):
14+
all_input_grads = []
15+
input_cast(lambda x: True, lambda y: all_input_grads.append(y.grad), inputs)
16+
return all_input_grads
17+
18+
19+
def export_data(
20+
x_val: str,
21+
inputs: Any,
22+
fn_mode: str,
23+
fn: Callable,
24+
export_type: str,
25+
export_dir: str,
26+
):
27+
# pickle naming convention
28+
# x_<x_val>-input.pkl
29+
# x_<x_val>-<fn_name>-fwd-output.pkl
30+
# x_<x_val>-<fn_name>-bwd-grad.pkl
31+
assert export_dir, "Export dir must be specified."
32+
export_path = Path(export_dir)
33+
assert export_path.exists(), f"Export path {export_dir} must exist."
34+
if export_type == "input" or export_type == "both":
35+
input_file_name = f"x_{x_val}-input.pkl"
36+
input_file_path = export_path.joinpath(input_file_name)
37+
with open(input_file_path, "wb") as ifp:
38+
pickle.dump(inputs, ifp)
39+
if export_type == "output" or export_type == "both":
40+
if fn_mode == "fwd":
41+
output_type = "output"
42+
output = fn()
43+
elif fn_mode == "bwd":
44+
output_type = "grad"
45+
# output of the backward pass are the input gradients
46+
output = get_input_gradients(inputs)
47+
output_file_name = f"x_{x_val}-{fn._name}-{fn_mode}-{output_type}.pkl"
48+
output_file_path = export_path.joinpath(output_file_name)
49+
with open(output_file_path, "wb") as ofp:
50+
pickle.dump(output, ofp)

tritonbench/utils/parser.py

+12
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,18 @@ def get_parser(args=None):
180180
action="store_true",
181181
help="Include cold start time in compile_time and compile_trace.",
182182
)
183+
parser.add_argument(
184+
"--export",
185+
default=None,
186+
choices=["in", "out", "both"],
187+
help="Export input or output. Must be used together with --export-dir.",
188+
)
189+
parser.add_argument(
190+
"--export-dir",
191+
default=None,
192+
type=str,
193+
help="The directory to store input or output.",
194+
)
183195

184196
if IS_FBCODE:
185197
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")

tritonbench/utils/triton_op.py

+10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import triton
2525

2626
from tritonbench.components.do_bench import do_bench_wrapper, Latency
27+
from tritonbench.components.export import export_data
2728
from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
2829
from tritonbench.utils.env_utils import apply_precision, set_env, set_random_seed
2930
from tritonbench.utils.input import input_cast
@@ -1371,6 +1372,15 @@ def _init_extra_metrics() -> Dict[str, Any]:
13711372
use_cuda_profiler_range=True,
13721373
)
13731374
metrics.extra_metrics["_nsys_rep_in_task"] = "success"
1375+
if self.tb_args.export:
1376+
export_data(
1377+
x_val=self.get_x_val(self.example_inputs),
1378+
inputs=self.example_inputs,
1379+
fn_mode=self.mode.value,
1380+
fn=fn,
1381+
export_type=self.tb_args.export,
1382+
export_dir=self.tb_args.export_dir,
1383+
)
13741384
# generate customized metrics
13751385
if self.name in REGISTERED_METRICS:
13761386
for metric_name in REGISTERED_METRICS[self.name]:

0 commit comments

Comments
 (0)