1+ import argparse
2+ from collections import deque
3+
4+ import numpy as np_cpu
5+
16import pykokkos as pk
27
38if pk .get_default_space () in pk .DeviceExecutionSpace :
49 import cupy as np
510else :
611 import numpy as np
712
8- import argparse
13+
14+ def _view_to_numpy_host (x ):
15+ """Host NumPy array for comparisons (handles Cupy-backed views)."""
16+ if hasattr (x , "get" ):
17+ return np_cpu .asarray (x .get ())
18+ return np_cpu .asarray (x )
19+
20+
21+ def reference_grid_bfs_distances (N : int , M : int , mat ) -> np_cpu .ndarray :
22+ """
23+ Shortest hop count (4-neighbor grid) from every cell to any cell with mat==0.
24+ Same graph as the PyKokkos workunits: vertices are grid cells, edges to N/E/S/W
25+ neighbors. Multi-source BFS using a queue, following the standard pattern in
26+ https://www.geeksforgeeks.org/python/python-program-for-breadth-first-search-or-bfs-for-a-graph/
27+ """
28+ mat_h = _view_to_numpy_host (mat )
29+ dist = np_cpu .full (N * M , - 1 , dtype = np_cpu .int32 )
30+ q = deque ()
31+ for r in range (N ):
32+ for c in range (M ):
33+ if mat_h [r , c ] == 0 :
34+ idx = r * M + c
35+ dist [idx ] = 0
36+ q .append ((r , c ))
37+ while q :
38+ r , c = q .popleft ()
39+ d = int (dist [r * M + c ])
40+ for dr , dc in ((- 1 , 0 ), (1 , 0 ), (0 , - 1 ), (0 , 1 )):
41+ nr , nc = r + dr , c + dc
42+ if 0 <= nr < N and 0 <= nc < M :
43+ ni = nr * M + nc
44+ if dist [ni ] == - 1 :
45+ dist [ni ] = d + 1
46+ q .append ((nr , nc ))
47+ return dist .astype (np_cpu .float64 )
48+
49+
50+ def assert_bfs_matches_pykokkos (N : int , M : int , mat , val , max_arr ) -> None :
51+ ref = reference_grid_bfs_distances (N , M , mat )
52+ val_h = _view_to_numpy_host (val )
53+ max_h = float (_view_to_numpy_host (max_arr )[0 ])
54+ if not np_cpu .allclose (val_h , ref , rtol = 0.0 , atol = 1e-9 ):
55+ bad = np_cpu .where (np_cpu .abs (val_h - ref ) > 1e-9 )[0 ][:16 ]
56+ raise AssertionError (
57+ f"distance mismatch at linear indices (first few): { bad .tolist ()} "
58+ )
59+ ref_max = float (np_cpu .max (ref ))
60+ if not np_cpu .isclose (max_h , ref_max , rtol = 0.0 , atol = 1e-9 ):
61+ raise AssertionError (f"max distance mismatch: got { max_h } , expected { ref_max } " )
62+ print ("BFS correctness check: OK (matches queue-based NumPy reference)" )
963
1064
1165def main (N : int , M : int ):
@@ -29,6 +83,8 @@ def main(N: int, M: int):
2983 pk .parallel_for (N , extend2D , N = N , max_arr = max_arr , max_arr2D = max_arr2D )
3084 pk .parallel_for (N , reduce1D , N = N , max_arr = max_arr , max_arr2D = max_arr2D )
3185
86+ assert_bfs_matches_pykokkos (N , M , mat , val , max_arr )
87+
3288 print (f"\n distance of every cell:\n " )
3389 for i in range (element ):
3490 print (f"val ({ val [i ]} ) " , end = "" )
@@ -73,21 +129,21 @@ def check_vis(
73129 if min_val > val [i - M ]:
74130 min_val = val [i - M ]
75131
76- # check the neighbor on the next row
132+ # check the neighbor on the next row
77133 if i // M < (N - 1 ):
78134 if visited [i + M ] == 1 :
79135 flag = 1
80136 if min_val > val [i + M ]:
81137 min_val = val [i + M ]
82138
83- # check the neighbor on the left
139+ # check the neighbor on the left
84140 if i % M > 0 :
85141 if visited [i - 1 ] == 1 :
86142 flag = 1
87143 if min_val > val [i - 1 ]:
88144 min_val = val [i - 1 ]
89145
90- # check the neighbor on the right
146+ # check the neighbor on the right
91147 if i % M < (M - 1 ):
92148 if visited [i + 1 ] == 1 :
93149 flag = 1
@@ -96,10 +152,10 @@ def check_vis(
96152
97153 # if there is at least one neighbor visited, the value of
98154 # the current index can be updated and should be marked as visited
99- if flag == 1 :
100- if val [i ] > min_val :
101- val [i ] = min_val + 1
102- visited [i ] = 1
155+ if flag == 1 :
156+ if val [i ] > min_val :
157+ val [i ] = min_val + 1
158+ visited [i ] = 1
103159
104160 ################################
105161 # findmax will find the maximum value of cell in each row
0 commit comments