Skip to content

Commit b9186b9

Browse files
committed
Refactor updating graphs: nx.utils._update_cpu_gpu_graphs
1 parent 4cb9b2e commit b9186b9

File tree

3 files changed

+71
-23
lines changed

3 files changed

+71
-23
lines changed

nx_cugraph/classes/graph.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ def clear(self) -> None:
8686
self._graph._reify_networkx()
8787
super().clear()
8888

89+
def _clear_no_reify_networkx(self):
90+
super().clear()
91+
8992

9093
class Graph(nx.Graph):
9194
# Tell networkx to dispatch calls with this object to nx-cugraph

nx_cugraph/drawing/layout.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
import pylibcugraph as plc
2222
from networkx.utils import create_random_state
2323

24-
import nx_cugraph as nxcg
2524
from nx_cugraph.convert import _to_graph
2625
from nx_cugraph.utils import (
2726
_dtype_param,
2827
_get_float_dtype,
2928
_seed_to_int,
29+
_update_cpu_gpu_graphs,
3030
networkx_algorithm,
3131
)
3232

@@ -75,6 +75,7 @@ def forceatlas2_layout(
7575
if len(G) == 0:
7676
return {}
7777

78+
# Mutate original graph if store_pos_as is given.
7879
G_orig = G
7980

8081
if dim != 2:
@@ -149,29 +150,24 @@ def forceatlas2_layout(
149150
pos = G._nodearrays_to_dict(
150151
node_ids=vertices, values=pos_arr, values_as_arrays=True
151152
)
152-
153+
store_pos_as = "pos"
153154
if store_pos_as is not None:
154-
if isinstance(G_orig, nxcg.Graph):
155-
# Could be on GPU, CPU, or both. Update both GPU and CPU (if present)
156-
if G_orig._is_on_gpu:
157-
cuda_graph = G_orig._cudagraph
155+
156+
def update_cpu(graph):
157+
nx.set_node_attributes(graph, pos, store_pos_as)
158+
159+
update_pos_array = True
160+
161+
def update_gpu(cuda_graph):
162+
# Ensure vertices are in order with their positions.
163+
# Use nonlocal variable to do this only once to ensure idempotency.
164+
nonlocal update_pos_array
165+
if update_pos_array:
158166
pos_arr[vertices] = pos_arr
159-
cuda_graph.node_values[store_pos_as] = pos_arr
160-
else:
161-
cuda_graph = None
162-
if G_orig._is_on_cpu:
163-
# This clears the cache (including on GPU)
164-
nx.set_node_attributes(G_orig, pos, store_pos_as)
165-
if cuda_graph is not None:
166-
# Add back to GPU
167-
G_orig._set_cudagraph(cuda_graph, clear_cpu=False)
168-
elif isinstance(G_orig, nxcg.CudaGraph):
169-
# ensure vertices are in order with their positions
170-
pos_arr[vertices] = pos_arr
171-
G_orig.node_values[store_pos_as] = pos_arr
172-
else:
173-
# Default: networkx graph
174-
nx.set_node_attributes(G_orig, pos, store_pos_as)
167+
update_pos_array = False
168+
cuda_graph.node_values[store_pos_as] = pos_arr
169+
170+
_update_cpu_gpu_graphs(G_orig, update_cpu=update_cpu, update_gpu=update_gpu)
175171

176172
return pos
177173

nx_cugraph/utils/misc.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121
import cupy as cp
2222
import numpy as np
2323

24+
import nx_cugraph as nxcg
25+
2426
if TYPE_CHECKING:
25-
import nx_cugraph as nxcg
27+
from collections.abc import Callable
28+
29+
import networkx as nx
2630

2731
from ..typing import Dtype, EdgeKey
2832

@@ -46,6 +50,7 @@ def pairwise(it):
4650
"_get_float_dtype",
4751
"_dtype_param",
4852
"_cp_iscopied_asarray",
53+
"_update_cpu_gpu_graphs",
4954
]
5055

5156
# This may switch to np.uint32 at some point
@@ -240,6 +245,50 @@ def _cp_iscopied_asarray(a, *args, orig_object=None, **kwargs):
240245
return True, arr
241246

242247

248+
def _update_cpu_gpu_graphs(
249+
G,
250+
*,
251+
update_cpu: Callable[[nx.Graph | nxcg.Graph], None],
252+
update_gpu: Callable[[nxcg.CudaGraph], None],
253+
) -> None:
254+
"""Update graph in-place whether it's on CPU or GPU (or both).
255+
256+
This works with nx.Graph, nxcg.Graph, and nxcg.CudaGraph objects.
257+
nxcg.Graph instances will update both CPU and GPU data structures
258+
if applicable.
259+
260+
Parameters
261+
----------
262+
update_cpu : func
263+
Function to modify a networkx-compatible graph in-place.
264+
265+
update_gpu: func
266+
Function to modify a CudaGraph graph in-place. This also
267+
updates the CudaGraph of an nxcg.Graph if it is on the GPU.
268+
"""
269+
if isinstance(G, nxcg.Graph):
270+
# Could be on GPU, CPU, or both. Update both GPU and CPU (if present)
271+
if G._is_on_gpu:
272+
cuda_graph = G._cudagraph
273+
update_gpu(cuda_graph)
274+
# Clear anything else in the cache; cache is invalidated.
275+
# We will re-add the cuda_graph below.
276+
G.__networkx_cache__._clear_no_reify_networkx()
277+
else:
278+
cuda_graph = None
279+
if G._is_on_cpu:
280+
# This clears the cache (including on GPU)
281+
update_cpu(G)
282+
if cuda_graph is not None:
283+
# Add back to GPU
284+
G._set_cudagraph(cuda_graph, clear_cpu=False)
285+
elif isinstance(G, nxcg.CudaGraph):
286+
update_gpu(G)
287+
else:
288+
# Default: networkx graph
289+
update_cpu(G)
290+
291+
243292
class _And_NotImplementedError(NotImplementedError):
244293
"""Additionally make an exception a ``NotImplementedError``.
245294

0 commit comments

Comments
 (0)