@@ -30,6 +30,14 @@ def jit(func, *args, **kwargs):
3030 joblib_available = False
3131 Parallel , delayed , shared_memory = None , None , None
3232
33+ # --- Optional PyTorch (GPU) Support ---
34+ try :
35+ import torch
36+ torch_available = True
37+ except ImportError :
38+ torch_available = False
39+ torch = None
40+
3341# --- Core Algorithm (Pure Python) ---
3442def CubeMap (arr , multi_marker = True ):
3543 """
@@ -192,4 +200,61 @@ def CubeMap_joblib(arr, multi_marker=True, n_jobs=-1):
192200 shm .close ()
193201 shm .unlink ()
194202
195- return result_cubemap
203+ return result_cubemap
204+
205+ # --- PyTorch GPU Implementation ---
206+ def CubeMap_torch (arr , multi_marker = True ):
207+ """
208+ Generates a cubical complex map using PyTorch for GPU acceleration.
209+ """
210+ if not torch_available :
211+ raise ImportError ("PyTorch is not installed. Cannot use 'torch' backend." )
212+
213+ device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
214+ print (f"Using PyTorch on device: { device } " )
215+
216+ arr_tensor = torch .tensor (arr , dtype = torch .int32 , device = device )
217+ nx , ny , nz = arr_tensor .shape
218+ Nx , Ny , Nz = 2 * nx , 2 * ny , 2 * nz
219+ cubemap = torch .zeros ((Nx , Ny , Nz ), dtype = torch .int32 , device = device )
220+
221+ markers = (1 , 1 , 1 , 1 )
222+ if multi_marker :
223+ markers = (1 , 2 , 3 , 4 )
224+
225+ # Vertices
226+ coords = torch .nonzero (arr_tensor , as_tuple = False )
227+ if coords .shape [0 ] > 0 :
228+ cubemap [coords [:, 0 ] * 2 , coords [:, 1 ] * 2 , coords [:, 2 ] * 2 ] = markers [0 ]
229+
230+ # Edges
231+ mask = cubemap == 0
232+ edge_mask = (
233+ (cubemap .roll (1 , 0 ) == markers [0 ]) & (cubemap .roll (- 1 , 0 ) == markers [0 ]) |
234+ (cubemap .roll (1 , 1 ) == markers [0 ]) & (cubemap .roll (- 1 , 1 ) == markers [0 ]) |
235+ (cubemap .roll (1 , 2 ) == markers [0 ]) & (cubemap .roll (- 1 , 2 ) == markers [0 ])
236+ )
237+ cubemap [mask & edge_mask ] = markers [1 ]
238+
239+ # Faces
240+ mask = cubemap == 0
241+ face_mask = (
242+ (cubemap .roll (1 , 0 ) == markers [1 ]) & (cubemap .roll (- 1 , 0 ) == markers [1 ]) &
243+ (cubemap .roll (1 , 1 ) == markers [1 ]) & (cubemap .roll (- 1 , 1 ) == markers [1 ]) |
244+ (cubemap .roll (1 , 1 ) == markers [1 ]) & (cubemap .roll (- 1 , 1 ) == markers [1 ]) &
245+ (cubemap .roll (1 , 2 ) == markers [1 ]) & (cubemap .roll (- 1 , 2 ) == markers [1 ]) |
246+ (cubemap .roll (1 , 2 ) == markers [1 ]) & (cubemap .roll (- 1 , 2 ) == markers [1 ]) &
247+ (cubemap .roll (1 , 0 ) == markers [1 ]) & (cubemap .roll (- 1 , 0 ) == markers [1 ])
248+ )
249+ cubemap [mask & face_mask ] = markers [2 ]
250+
251+ # Cubes
252+ mask = cubemap == 0
253+ cube_mask = (
254+ (cubemap .roll (1 , 0 ) == markers [2 ]) & (cubemap .roll (- 1 , 0 ) == markers [2 ]) &
255+ (cubemap .roll (1 , 1 ) == markers [2 ]) & (cubemap .roll (- 1 , 1 ) == markers [2 ]) &
256+ (cubemap .roll (1 , 2 ) == markers [2 ]) & (cubemap .roll (- 1 , 2 ) == markers [2 ])
257+ )
258+ cubemap [mask & cube_mask ] = markers [3 ]
259+
260+ return cubemap .cpu ().numpy ()
0 commit comments