|
21 | 21 | import pylibcugraph as plc |
22 | 22 | from networkx.utils import create_random_state |
23 | 23 |
|
24 | | -import nx_cugraph as nxcg |
25 | 24 | from nx_cugraph.convert import _to_graph |
26 | 25 | from nx_cugraph.utils import ( |
27 | 26 | _dtype_param, |
28 | 27 | _get_float_dtype, |
29 | 28 | _seed_to_int, |
| 29 | + _update_cpu_gpu_graphs, |
30 | 30 | networkx_algorithm, |
31 | 31 | ) |
32 | 32 |
|
@@ -75,6 +75,7 @@ def forceatlas2_layout( |
75 | 75 | if len(G) == 0: |
76 | 76 | return {} |
77 | 77 |
|
| 78 | + # Mutate original graph if store_pos_as is given. |
78 | 79 | G_orig = G |
79 | 80 |
|
80 | 81 | if dim != 2: |
@@ -149,29 +150,24 @@ def forceatlas2_layout( |
149 | 150 | pos = G._nodearrays_to_dict( |
150 | 151 | node_ids=vertices, values=pos_arr, values_as_arrays=True |
151 | 152 | ) |
152 | | - |
| 153 | + store_pos_as = "pos" |
153 | 154 | if store_pos_as is not None: |
154 | | - if isinstance(G_orig, nxcg.Graph): |
155 | | - # Could be on GPU, CPU, or both. Update both GPU and CPU (if present) |
156 | | - if G_orig._is_on_gpu: |
157 | | - cuda_graph = G_orig._cudagraph |
| 155 | + |
| 156 | + def update_cpu(graph): |
| 157 | + nx.set_node_attributes(graph, pos, store_pos_as) |
| 158 | + |
| 159 | + update_pos_array = True |
| 160 | + |
| 161 | + def update_gpu(cuda_graph): |
| 162 | + # Ensure vertices are in order with their positions. |
| 163 | + # Use nonlocal variable to do this only once to ensure idempotency. |
| 164 | + nonlocal update_pos_array |
| 165 | + if update_pos_array: |
158 | 166 | pos_arr[vertices] = pos_arr |
159 | | - cuda_graph.node_values[store_pos_as] = pos_arr |
160 | | - else: |
161 | | - cuda_graph = None |
162 | | - if G_orig._is_on_cpu: |
163 | | - # This clears the cache (including on GPU) |
164 | | - nx.set_node_attributes(G_orig, pos, store_pos_as) |
165 | | - if cuda_graph is not None: |
166 | | - # Add back to GPU |
167 | | - G_orig._set_cudagraph(cuda_graph, clear_cpu=False) |
168 | | - elif isinstance(G_orig, nxcg.CudaGraph): |
169 | | - # ensure vertices are in order with their positions |
170 | | - pos_arr[vertices] = pos_arr |
171 | | - G_orig.node_values[store_pos_as] = pos_arr |
172 | | - else: |
173 | | - # Default: networkx graph |
174 | | - nx.set_node_attributes(G_orig, pos, store_pos_as) |
| 167 | + update_pos_array = False |
| 168 | + cuda_graph.node_values[store_pos_as] = pos_arr |
| 169 | + |
| 170 | + _update_cpu_gpu_graphs(G_orig, update_cpu=update_cpu, update_gpu=update_gpu) |
175 | 171 |
|
176 | 172 | return pos |
177 | 173 |
|
|
0 commit comments