Skip to content

Commit 493eed8

Browse files
author
sambit-giri
committed
verbose output of Euler Characteristic with torch guarded
1 parent 60a827f commit 493eed8

3 files changed

Lines changed: 99 additions & 6 deletions

File tree

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import numpy as np
2+
import time
3+
import os
4+
import gc
5+
6+
from tools21cm import topology
7+
8+
def create_test_data(box_dim=128, seed=None):
    """Create a reasonably complex 3D binary array for testing.

    The cube contains a solid cubic core (side ``box_dim // 4``, centered)
    plus sparse random noise voxels (~0.5% of cells), so its topology is
    non-trivial for Euler-characteristic benchmarking.

    Parameters
    ----------
    box_dim : int, optional
        Edge length of the cubic array. Default is 128.
    seed : int or None, optional
        Seed for the noise generator. Pass an int for reproducible data;
        the default ``None`` keeps the original non-deterministic behavior.

    Returns
    -------
    numpy.ndarray
        ``(box_dim, box_dim, box_dim)`` int32 array containing only 0s and 1s.
    """
    print(f"Generating a {box_dim}x{box_dim}x{box_dim} test data cube...")
    arr = np.zeros((box_dim, box_dim, box_dim), dtype=np.int32)

    # A solid core
    size = box_dim // 4
    start = box_dim // 2 - size // 2
    end = start + size
    arr[start:end, start:end, start:end] = 1

    # Random noise to make it less trivial: draws are in [0, 200), so only
    # the value 199 passes the > 198 test (~0.5% of voxels become active).
    rng = np.random.default_rng(seed)
    noise = rng.integers(0, 200, size=arr.shape)
    arr[noise > 198] = 1
    print(f"Test data generated with {arr.sum()} active cells.\n")
    return arr
24+
25+
if __name__ == "__main__":
    # --- Setup ---
    BOX_DIM = 32  # 128
    test_data = create_test_data(BOX_DIM)
    timings = {}

    # Remember any pre-existing thread setting so it can be restored later;
    # deleting it unconditionally would clobber the user's configuration.
    omp_prev = os.environ.get('OMP_NUM_THREADS')

    # Define the backends to test
    backends_to_test = ['python', 'numba', 'cython', 'torch']
    chi_results = {}
    for backend in backends_to_test:
        print(f"--- Benchmarking '{backend}' Backend ---")

        # Check if the backend is available (numba/cython/torch are optional)
        available = False
        if backend == 'python': available = True
        elif backend == 'numba' and topology.VB.numba_available: available = True
        elif backend == 'cython' and topology.VB.cython_available: available = True
        elif backend == 'torch' and topology.VB.torch_available: available = True

        if not available:
            print(" Backend not available. Skipping.\n")
            continue

        # For parallel Cython, set thread count to max
        if backend == 'cython':
            n_threads = os.cpu_count()
            os.environ['OMP_NUM_THREADS'] = str(n_threads)
            print(f" (Using {n_threads} threads for Cython/OpenMP)")

        # Perform a warm-up run for JIT or GPU backends so compilation /
        # device initialization is not counted in the timed run
        if backend in ['numba', 'torch']:
            print(" (Warm-up run...)")
            topology.EulerCharacteristic(test_data, speed_up=backend, verbose=False)

        # Run the actual benchmark. perf_counter() is monotonic and has
        # higher resolution than time.time() for interval timing.
        print(" (Benchmarking run...)")
        t_start = time.perf_counter()
        chi_value = topology.EulerCharacteristic(test_data, speed_up=backend, verbose=True)
        t_end = time.perf_counter()

        duration = t_end - t_start
        timings[backend] = duration

        print(f" Result Chi = {chi_value}, Time = {duration:.4f} seconds\n")
        chi_results[backend] = chi_value
        gc.collect()

    # Restore (or remove) the environment variable we may have overwritten
    if omp_prev is None:
        os.environ.pop('OMP_NUM_THREADS', None)
    else:
        os.environ['OMP_NUM_THREADS'] = omp_prev

    # --- Final Summary ---
    print("="*65)
    print(" Backend Benchmark Summary")
    print("="*65)
    if timings:
        sorted_results = sorted(timings.items(), key=lambda item: item[1])
        baseline_time = timings.get('python', 1e-9)

        # Add the 'Chi Value' column to the header
        print(f"{'Implementation':<20} | {'Chi Value':<12} | {'Time (s)':<15} | {'Speedup'}")
        print("-"*65)

        for name, t in sorted_results:
            chi_val = chi_results.get(name, 'N/A')
            # Format defensively: the 'N/A' fallback is a string and would
            # raise ValueError under a numeric format spec like '<12.0f'.
            try:
                chi_str = f"{chi_val:<12.0f}"
            except (ValueError, TypeError):
                chi_str = f"{chi_val!s:<12}"
            speedup = baseline_time / t
            print(f"{name:<20} | {chi_str} | {t:<15.4f} | {speedup:.2f}x")
    else:
        print("No backends were benchmarked.")

src/tools21cm/ViteBetti.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def jit(func, *args, **kwargs):
3434
try:
3535
import torch
3636
torch_available = True
37+
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
3738
except ImportError:
3839
torch_available = False
3940
torch = None
@@ -203,20 +204,17 @@ def CubeMap_joblib(arr, multi_marker=True, n_jobs=-1):
203204
return result_cubemap
204205

205206
# --- PyTorch GPU Implementation ---
206-
def CubeMap_torch(arr, multi_marker=True):
207+
def CubeMap_torch(arr, multi_marker=True, verbose=False):
207208
"""
208209
Generates a cubical complex map using PyTorch for GPU acceleration.
209210
"""
210211
if not torch_available:
211212
raise ImportError("PyTorch is not installed. Cannot use 'torch' backend.")
212213

213-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
214-
print(f"Using PyTorch on device: {device}")
215-
216-
arr_tensor = torch.tensor(arr, dtype=torch.int32, device=device)
214+
arr_tensor = torch.tensor(arr, dtype=torch.int32, device=torch_device)
217215
nx, ny, nz = arr_tensor.shape
218216
Nx, Ny, Nz = 2 * nx, 2 * ny, 2 * nz
219-
cubemap = torch.zeros((Nx, Ny, Nz), dtype=torch.int32, device=device)
217+
cubemap = torch.zeros((Nx, Ny, Nz), dtype=torch.int32, device=torch_device)
220218

221219
markers = (1, 1, 1, 1)
222220
if multi_marker:

src/tools21cm/topology.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def EulerCharacteristic(data, thres=0.5, neighbors=6, speed_up='cython', verbose
4444
elif speed_up.lower() == 'numba' and VB.numba_available:
4545
C = VB.CubeMap_numba(A)
4646
elif speed_up.lower() == 'torch' and VB.torch_available:
47+
if verbose:
48+
print(f"device={VB.torch_device}...", end="")
4749
C = VB.CubeMap_torch(A)
4850
else:
4951
if speed_up.lower() not in ['numpy', 'python']:

0 commit comments

Comments
 (0)