Skip to content

Commit 289927f

Browse files
author
sambit-giri
committed
torch backend for Chi calc
1 parent 0bd0544 commit 289927f

2 files changed

Lines changed: 68 additions & 1 deletion

File tree

src/tools21cm/ViteBetti.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ def jit(func, *args, **kwargs):
3030
joblib_available = False
3131
Parallel, delayed, shared_memory = None, None, None
3232

33+
# --- Optional PyTorch (GPU) Support ---
34+
try:
35+
import torch
36+
torch_available = True
37+
except ImportError:
38+
torch_available = False
39+
torch = None
40+
3341
# --- Core Algorithm (Pure Python) ---
3442
def CubeMap(arr, multi_marker=True):
3543
"""
@@ -192,4 +200,61 @@ def CubeMap_joblib(arr, multi_marker=True, n_jobs=-1):
192200
shm.close()
193201
shm.unlink()
194202

195-
return result_cubemap
203+
return result_cubemap
204+
205+
# --- PyTorch GPU Implementation ---
206+
def CubeMap_torch(arr, multi_marker=True):
207+
"""
208+
Generates a cubical complex map using PyTorch for GPU acceleration.
209+
"""
210+
if not torch_available:
211+
raise ImportError("PyTorch is not installed. Cannot use 'torch' backend.")
212+
213+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
214+
print(f"Using PyTorch on device: {device}")
215+
216+
arr_tensor = torch.tensor(arr, dtype=torch.int32, device=device)
217+
nx, ny, nz = arr_tensor.shape
218+
Nx, Ny, Nz = 2 * nx, 2 * ny, 2 * nz
219+
cubemap = torch.zeros((Nx, Ny, Nz), dtype=torch.int32, device=device)
220+
221+
markers = (1, 1, 1, 1)
222+
if multi_marker:
223+
markers = (1, 2, 3, 4)
224+
225+
# Vertices
226+
coords = torch.nonzero(arr_tensor, as_tuple=False)
227+
if coords.shape[0] > 0:
228+
cubemap[coords[:, 0] * 2, coords[:, 1] * 2, coords[:, 2] * 2] = markers[0]
229+
230+
# Edges
231+
mask = cubemap == 0
232+
edge_mask = (
233+
(cubemap.roll(1, 0) == markers[0]) & (cubemap.roll(-1, 0) == markers[0]) |
234+
(cubemap.roll(1, 1) == markers[0]) & (cubemap.roll(-1, 1) == markers[0]) |
235+
(cubemap.roll(1, 2) == markers[0]) & (cubemap.roll(-1, 2) == markers[0])
236+
)
237+
cubemap[mask & edge_mask] = markers[1]
238+
239+
# Faces
240+
mask = cubemap == 0
241+
face_mask = (
242+
(cubemap.roll(1, 0) == markers[1]) & (cubemap.roll(-1, 0) == markers[1]) &
243+
(cubemap.roll(1, 1) == markers[1]) & (cubemap.roll(-1, 1) == markers[1]) |
244+
(cubemap.roll(1, 1) == markers[1]) & (cubemap.roll(-1, 1) == markers[1]) &
245+
(cubemap.roll(1, 2) == markers[1]) & (cubemap.roll(-1, 2) == markers[1]) |
246+
(cubemap.roll(1, 2) == markers[1]) & (cubemap.roll(-1, 2) == markers[1]) &
247+
(cubemap.roll(1, 0) == markers[1]) & (cubemap.roll(-1, 0) == markers[1])
248+
)
249+
cubemap[mask & face_mask] = markers[2]
250+
251+
# Cubes
252+
mask = cubemap == 0
253+
cube_mask = (
254+
(cubemap.roll(1, 0) == markers[2]) & (cubemap.roll(-1, 0) == markers[2]) &
255+
(cubemap.roll(1, 1) == markers[2]) & (cubemap.roll(-1, 1) == markers[2]) &
256+
(cubemap.roll(1, 2) == markers[2]) & (cubemap.roll(-1, 2) == markers[2])
257+
)
258+
cubemap[mask & cube_mask] = markers[3]
259+
260+
return cubemap.cpu().numpy()

src/tools21cm/topology.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def EulerCharacteristic(data, thres=0.5, neighbors=6, speed_up='cython', verbose
4242
C = VB.CubeMap_cython(A)
4343
elif speed_up.lower() == 'numba' and VB.numba_available:
4444
C = VB.CubeMap_numba(A)
45+
elif speed_up.lower() == 'torch' and VB.torch_available:
46+
C = VB.CubeMap_torch(A)
4547
else:
4648
if speed_up.lower() not in ['numpy', 'python']:
4749
print(f"Warning: '{speed_up}' backend not found. Falling back to pure Python.", file=sys.stderr)

0 commit comments

Comments
 (0)