11#!/usr/bin/env python3
22import sys
33import os
4+
45# Add the parent directory to sys.path to allow importing torchsparsegradutils
5- sys .path .insert (0 , os .path .join (os .path .dirname (__file__ ), '..' , '..' ))
6+ sys .path .insert (0 , os .path .join (os .path .dirname (__file__ ), ".." , ".." ))
67
78import time
89import torch
1920
2021# problem sizes: (label, N, M, nnz)
2122SIZES = [
22- ("small" , 2_000 , 128 , 4_000 ),
23- ("medium" , 5_000 , 256 , 10_000 ),
24- ("large" , 10_000 , 512 , 20_000 ),
23+ ("small" , 2_000 , 128 , 4_000 ),
24+ ("medium" , 5_000 , 256 , 10_000 ),
25+ ("large" , 10_000 , 512 , 20_000 ),
2526]
2627
2728INDEX_DTYPES = [torch .int32 , torch .int64 ]
3334 ("dense.mm" , lambda A , B : torch .matmul (A .to_dense (), B )),
3435]
3536
37+
3638def measure_op (op , A , B ):
3739 """
3840 Measure forward/backward times and peak mem.
@@ -64,6 +66,7 @@ def measure_op(op, A, B):
6466
6567 return (t1 - t0 ), mem_fwd , (t3 - t2 ), mem_bwd
6668
69+
6770def main ():
6871 records = []
6972 for size_label , N , M , nnz in SIZES :
@@ -74,10 +77,7 @@ def main():
7477 for val_dt in VALUE_DTYPES :
7578 # build one sparse COO for all algos
7679 A_coo = rand_sparse (
77- A_shape , nnz , torch .sparse_coo ,
78- indices_dtype = idx_dt ,
79- values_dtype = val_dt ,
80- device = device
80+ A_shape , nnz , torch .sparse_coo , indices_dtype = idx_dt , values_dtype = val_dt , device = device
8181 ).coalesce ()
8282 B = torch .randn (B_shape , dtype = val_dt , device = device )
8383
@@ -89,28 +89,41 @@ def main():
8989 # run
9090 t_fwd , mem_fwd , t_bwd , mem_bwd = measure_op (alg_fn , A , B )
9191
92- records .append ({
93- "size" : size_label ,
94- "layout" : layout_name ,
95- "algo" : alg_name ,
96- "index_dt" : str (idx_dt ).split ("." )[- 1 ],
97- "value_dt" : str (val_dt ).split ("." )[- 1 ],
98- "N" : N ,
99- "M" : M ,
100- "nnz" : nnz ,
101- "fwd_time_s" : f"{ t_fwd :.3f} " ,
102- "fwd_mem_MB" : f"{ mem_fwd :.1f} " ,
103- "bwd_time_s" : f"{ t_bwd :.3f} " ,
104- "bwd_mem_MB" : f"{ mem_bwd :.1f} " ,
105- })
92+ records .append (
93+ {
94+ "size" : size_label ,
95+ "layout" : layout_name ,
96+ "algo" : alg_name ,
97+ "index_dt" : str (idx_dt ).split ("." )[- 1 ],
98+ "value_dt" : str (val_dt ).split ("." )[- 1 ],
99+ "N" : N ,
100+ "M" : M ,
101+ "nnz" : nnz ,
102+ "fwd_time_s" : f"{ t_fwd :.3f} " ,
103+ "fwd_mem_MB" : f"{ mem_fwd :.1f} " ,
104+ "bwd_time_s" : f"{ t_bwd :.3f} " ,
105+ "bwd_mem_MB" : f"{ mem_bwd :.1f} " ,
106+ }
107+ )
106108
107109 df = pd .DataFrame .from_records (records )
108110 # reorder columns for clarity
109- df = df [[
110- "size" , "layout" , "algo" , "index_dt" , "value_dt" ,
111- "N" , "M" , "nnz" ,
112- "fwd_time_s" , "fwd_mem_MB" , "bwd_time_s" , "bwd_mem_MB"
113- ]]
111+ df = df [
112+ [
113+ "size" ,
114+ "layout" ,
115+ "algo" ,
116+ "index_dt" ,
117+ "value_dt" ,
118+ "N" ,
119+ "M" ,
120+ "nnz" ,
121+ "fwd_time_s" ,
122+ "fwd_mem_MB" ,
123+ "bwd_time_s" ,
124+ "bwd_mem_MB" ,
125+ ]
126+ ]
114127
115128 md = df .to_markdown (index = False )
116129 with open ("torchsparsegradutils/tests/benchmark_results_sparse_mm.md" , "w" ) as f :
@@ -120,5 +133,6 @@ def main():
120133
121134 print ("Written results to benchmark_results_sparse_mm.md" )
122135
136+
123137if __name__ == "__main__" :
124138 main ()
0 commit comments