Skip to content

Commit 0500c12

Browse files
committed
run black and flake8
1 parent 4780db0 commit 0500c12

16 files changed

Lines changed: 415 additions & 209 deletions

.devcontainer/devcontainer.json

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
"torchsparsegradutils/tests"
2424
],
2525
"python.testing.unittestEnabled": false,
26-
"python.testing.pytestEnabled": true
26+
"python.testing.pytestEnabled": true,
27+
"editor.formatOnSave": true,
28+
"python.formatting.provider": "black",
29+
"python.linting.enabled": true,
30+
"python.linting.flake8Enabled": true,
31+
"python.linting.pylintEnabled": false
2732
},
2833
"extensions": [
2934
"dbaeumer.vscode-eslint",
@@ -34,7 +39,9 @@
3439
"GitHub.vscode-pull-request-github",
3540
"GitHub.vscode-github-actions",
3641
"mhutchie.git-graph",
37-
"waderyan.gitblame"
42+
"waderyan.gitblame",
43+
"ms-python.black-formatter",
44+
"ms-python.flake8"
3845
]
3946
}
4047
},

torchsparsegradutils/tests/benchmark_sparse_mm.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#!/usr/bin/env python3
22
import sys
33
import os
4+
45
# Add the parent directory to sys.path to allow importing torchsparsegradutils
5-
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
6+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
67

78
import time
89
import torch
@@ -19,9 +20,9 @@
1920

2021
# problem sizes: (label, N, M, nnz)
2122
SIZES = [
22-
("small", 2_000, 128, 4_000),
23-
("medium", 5_000, 256, 10_000),
24-
("large", 10_000, 512, 20_000),
23+
("small", 2_000, 128, 4_000),
24+
("medium", 5_000, 256, 10_000),
25+
("large", 10_000, 512, 20_000),
2526
]
2627

2728
INDEX_DTYPES = [torch.int32, torch.int64]
@@ -33,6 +34,7 @@
3334
("dense.mm", lambda A, B: torch.matmul(A.to_dense(), B)),
3435
]
3536

37+
3638
def measure_op(op, A, B):
3739
"""
3840
Measure forward/backward times and peak mem.
@@ -64,6 +66,7 @@ def measure_op(op, A, B):
6466

6567
return (t1 - t0), mem_fwd, (t3 - t2), mem_bwd
6668

69+
6770
def main():
6871
records = []
6972
for size_label, N, M, nnz in SIZES:
@@ -74,10 +77,7 @@ def main():
7477
for val_dt in VALUE_DTYPES:
7578
# build one sparse COO for all algos
7679
A_coo = rand_sparse(
77-
A_shape, nnz, torch.sparse_coo,
78-
indices_dtype=idx_dt,
79-
values_dtype=val_dt,
80-
device=device
80+
A_shape, nnz, torch.sparse_coo, indices_dtype=idx_dt, values_dtype=val_dt, device=device
8181
).coalesce()
8282
B = torch.randn(B_shape, dtype=val_dt, device=device)
8383

@@ -89,28 +89,41 @@ def main():
8989
# run
9090
t_fwd, mem_fwd, t_bwd, mem_bwd = measure_op(alg_fn, A, B)
9191

92-
records.append({
93-
"size": size_label,
94-
"layout": layout_name,
95-
"algo": alg_name,
96-
"index_dt": str(idx_dt).split(".")[-1],
97-
"value_dt": str(val_dt).split(".")[-1],
98-
"N": N,
99-
"M": M,
100-
"nnz": nnz,
101-
"fwd_time_s": f"{t_fwd:.3f}",
102-
"fwd_mem_MB": f"{mem_fwd:.1f}",
103-
"bwd_time_s": f"{t_bwd:.3f}",
104-
"bwd_mem_MB": f"{mem_bwd:.1f}",
105-
})
92+
records.append(
93+
{
94+
"size": size_label,
95+
"layout": layout_name,
96+
"algo": alg_name,
97+
"index_dt": str(idx_dt).split(".")[-1],
98+
"value_dt": str(val_dt).split(".")[-1],
99+
"N": N,
100+
"M": M,
101+
"nnz": nnz,
102+
"fwd_time_s": f"{t_fwd:.3f}",
103+
"fwd_mem_MB": f"{mem_fwd:.1f}",
104+
"bwd_time_s": f"{t_bwd:.3f}",
105+
"bwd_mem_MB": f"{mem_bwd:.1f}",
106+
}
107+
)
106108

107109
df = pd.DataFrame.from_records(records)
108110
# reorder columns for clarity
109-
df = df[[
110-
"size", "layout", "algo", "index_dt", "value_dt",
111-
"N", "M", "nnz",
112-
"fwd_time_s", "fwd_mem_MB", "bwd_time_s", "bwd_mem_MB"
113-
]]
111+
df = df[
112+
[
113+
"size",
114+
"layout",
115+
"algo",
116+
"index_dt",
117+
"value_dt",
118+
"N",
119+
"M",
120+
"nnz",
121+
"fwd_time_s",
122+
"fwd_mem_MB",
123+
"bwd_time_s",
124+
"bwd_mem_MB",
125+
]
126+
]
114127

115128
md = df.to_markdown(index=False)
116129
with open("torchsparsegradutils/tests/benchmark_results_sparse_mm.md", "w") as f:
@@ -120,5 +133,6 @@ def main():
120133

121134
print("Written results to benchmark_results_sparse_mm.md")
122135

136+
123137
if __name__ == "__main__":
124138
main()

torchsparsegradutils/tests/profile_sparse_ops.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torchsparsegradutils import sparse_mm
44
from torchsparsegradutils.utils import rand_sparse
55

6+
67
def profile_sparse_mm(
78
A_shape,
89
B_shape,
@@ -11,11 +12,11 @@ def profile_sparse_mm(
1112
device=torch.device("cuda"),
1213
value_dtype=torch.float32,
1314
index_dtype=torch.int64,
14-
row_limit: int = 20
15+
row_limit: int = 20,
1516
):
1617
"""
1718
Profiles forward+backward of sparse_mm on a random sparse matrix.
18-
19+
1920
Args:
2021
A_shape (tuple): shape of A, e.g. (N, M) or (batch, N, M)
2122
B_shape (tuple): shape of B, must be compatible with A: (M, P) or (batch, M, P)
@@ -30,12 +31,7 @@ def profile_sparse_mm(
3031
raise RuntimeError("Profiler memory breakdown requires CUDA device")
3132

3233
# 1) build random sparse A and dense B
33-
A = rand_sparse(
34-
A_shape, A_nnz, layout,
35-
indices_dtype=index_dtype,
36-
values_dtype=value_dtype,
37-
device=device
38-
)
34+
A = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device)
3935
if layout == torch.sparse_coo:
4036
A = A.coalesce()
4137
B = torch.randn(*B_shape, dtype=value_dtype, device=device)
@@ -49,15 +45,12 @@ def profile_sparse_mm(
4945
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
5046
record_shapes=True,
5147
profile_memory=True,
52-
with_stack=False
48+
with_stack=False,
5349
) as prof:
5450
with record_function("sparse_mm_forward_and_backward"):
5551
out = sparse_mm(A, B)
5652
# you can sum or pick any scalar reduction
5753
out.sum().backward()
5854

5955
# 3) print a table sorted by the CUDA memory usage of each op
60-
print(prof.key_averages().table(
61-
sort_by="self_cuda_memory_usage",
62-
row_limit=row_limit
63-
))
56+
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=row_limit))

torchsparsegradutils/tests/test_bicgstab.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,20 @@
33
from torchsparsegradutils.utils import bicgstab
44

55
# Device fixture
6-
DEVICES = [torch.device('cpu')]
6+
DEVICES = [torch.device("cpu")]
77
if torch.cuda.is_available():
8-
DEVICES.append(torch.device('cuda:0'))
8+
DEVICES.append(torch.device("cuda:0"))
9+
10+
11+
def _id_device(d):
12+
return str(d)
913

10-
def _id_device(d): return str(d)
1114

1215
@pytest.fixture(params=DEVICES, ids=_id_device)
1316
def device(request):
1417
return request.param
1518

19+
1620
def test_bicgstab(device):
1721
# setup SPD test problem
1822
size = 100

torchsparsegradutils/tests/test_cupy_bindings.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88

99
# Device fixture
10-
DEVICES = [torch.device('cpu')]
10+
DEVICES = [torch.device("cpu")]
1111
if torch.cuda.is_available():
12-
DEVICES.append(torch.device('cuda:0'))
12+
DEVICES.append(torch.device("cuda:0"))
13+
14+
15+
def _id_device(d):
16+
return str(d)
1317

14-
def _id_device(d): return str(d)
1518

1619
@pytest.fixture(params=DEVICES, ids=_id_device)
1720
def device(request):
@@ -23,24 +26,27 @@ def device(request):
2326
import cupy as cp
2427
import cupyx.scipy.sparse as csp
2528

29+
2630
# Helper to convert Cupy array to NumPy
2731
def _c2n(x_cupy):
2832
return cp.asnumpy(x_cupy) if tsgucupy.have_cupy else np.asarray(x_cupy)
2933

34+
3035
# I/O setup fixture
3136
@pytest.fixture
3237
def cupy_bindings_io(device):
3338
x_shape = (4, 2)
3439
x_t = torch.randn(x_shape, dtype=torch.float64, device=device)
3540
rng = np.random.default_rng()
3641
x_n = rng.standard_normal(x_shape, dtype=np.float64)
37-
if tsgucupy.have_cupy and device.type == 'cuda':
42+
if tsgucupy.have_cupy and device.type == "cuda":
3843
xp, xsp = cp, csp
3944
else:
4045
xp, xsp = np, nsp
4146
x_c = xp.asarray(x_n)
4247
return x_t, x_c, xsp
4348

49+
4450
# Test torch -> CuPy/NumPy COO conversion and back
4551
def test_t2c_and_c2t_coo(device, cupy_bindings_io):
4652
x_t, x_c, xsp = cupy_bindings_io
@@ -54,6 +60,7 @@ def test_t2c_and_c2t_coo(device, cupy_bindings_io):
5460
assert x_t2.shape == x_t_coo.shape
5561
assert np.allclose(x_t2.to_dense().cpu().numpy(), x_t_coo.to_dense().cpu().numpy())
5662

63+
5764
# Test NumPy/CuPy -> torch COO conversion and back
5865
def test_c2t_and_t2c_coo(device, cupy_bindings_io):
5966
x_t, x_c, xsp = cupy_bindings_io
@@ -65,6 +72,7 @@ def test_c2t_and_t2c_coo(device, cupy_bindings_io):
6572
assert x_c2.shape == x_c_coo.shape
6673
assert np.allclose(_c2n(x_c2.todense()), _c2n(x_c_coo.todense()))
6774

75+
6876
# Test torch -> CuPy/NumPy CSR conversion and back
6977
def test_t2c_and_c2t_csr(device, cupy_bindings_io):
7078
x_t, x_c, xsp = cupy_bindings_io
@@ -76,6 +84,7 @@ def test_t2c_and_c2t_csr(device, cupy_bindings_io):
7684
assert x_t2.shape == x_t_csr.shape
7785
assert np.allclose(x_t2.to_dense().cpu().numpy(), x_t_csr.to_dense().cpu().numpy())
7886

87+
7988
# Test NumPy/CuPy -> torch CSR conversion and back
8089
def test_c2t_and_t2c_csr(device, cupy_bindings_io):
8190
x_t, x_c, xsp = cupy_bindings_io

torchsparsegradutils/tests/test_cupy_sparse_solve.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,24 @@
2222
)
2323

2424
# Device fixture
25-
DEVICES = [torch.device('cpu')]
25+
DEVICES = [torch.device("cpu")]
2626
if torch.cuda.is_available():
27-
DEVICES.append(torch.device('cuda:0'))
27+
DEVICES.append(torch.device("cuda:0"))
28+
29+
30+
def _id_device(d):
31+
return str(d)
2832

29-
def _id_device(d): return str(d)
3033

3134
@pytest.fixture(params=DEVICES, ids=_id_device)
3235
def device(request):
3336
return request.param
3437

38+
3539
# relative tolerance for comparisons
3640
RTOL = 1e-3
3741

42+
3843
def _setup(device):
3944
# common setup
4045
A = torch.randn((4, 4), dtype=torch.float64, device=device)
@@ -44,21 +49,27 @@ def _setup(device):
4449
x_ref = torch.linalg.solve(A, B)
4550
return A_csr, B, x_ref
4651

52+
4753
def test_solver_c4t(device):
4854
A_csr, B, x_ref = _setup(device)
4955
x = tsgucupy.sparse_solve_c4t(A_csr.to(torch.float32), B.to(torch.float32))
5056
assert torch.allclose(x, x_ref.to(torch.float32), rtol=RTOL)
5157

58+
5259
def test_solver_gradient_c4t(device):
5360
A_csr, B, _ = _setup(device)
5461
# sparse solver gradient
55-
As1 = A_csr.to(torch.float32).detach().clone(); As1.requires_grad_(True)
56-
Bd1 = B.to(torch.float32).detach().clone(); Bd1.requires_grad_(True)
62+
As1 = A_csr.to(torch.float32).detach().clone()
63+
As1.requires_grad_(True)
64+
Bd1 = B.to(torch.float32).detach().clone()
65+
Bd1.requires_grad_(True)
5766
x = tsgucupy.sparse_solve_c4t(As1, Bd1)
5867
x.sum().backward()
5968
# dense reference gradient
60-
A = As1.to_dense().detach().clone(); A.requires_grad_(True)
61-
Bd2 = Bd1.detach().clone(); Bd2.requires_grad_(True)
69+
A = As1.to_dense().detach().clone()
70+
A.requires_grad_(True)
71+
Bd2 = Bd1.detach().clone()
72+
Bd2.requires_grad_(True)
6273
x2 = torch.linalg.solve(A, Bd2)
6374
x2.sum().backward()
6475
# compare results

torchsparsegradutils/tests/test_distributions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def set_seed():
9494

9595
# Convenience functions:
9696

97+
9798
def construct_distribution(sizes, layout, var, value_dtype, index_dtype, device, requires_grad=False):
9899
_, batch_size, event_size, sparsity = sizes
99100
loc = torch.randn(event_size, device=device, dtype=value_dtype, requires_grad=requires_grad)
@@ -160,6 +161,7 @@ def check_covariance_within_tolerance(
160161

161162
# Define Tests
162163

164+
163165
def test_rsample_forward_cov(device, layout, sizes, value_dtype, index_dtype):
164166
if layout == torch.sparse_coo and index_dtype == torch.int32:
165167
pytest.skip("Sparse COO with int32 indices is not supported")

0 commit comments

Comments
 (0)