Skip to content

Commit 8d90245

Browse files
committed
Update
1 parent 88eec8b commit 8d90245

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

nx_cugraph/classes/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1169,8 +1169,10 @@ def _nodearray_to_set(self, node_ids: cp.ndarray[IndexValue]) -> set[NodeKey]:
11691169
return set(self._nodeiter_to_iter(node_ids.tolist()))
11701170

11711171
def _nodearray_to_dict(
1172-
self, values: cp.ndarray[NodeValue]
1172+
self,
1173+
values: cp.ndarray[NodeValue],
11731174
) -> dict[NodeKey, NodeValue]:
1175+
# values_as_arrays: bool | None = None,
11741176
it = enumerate(values.tolist())
11751177
if (id_to_key := self.id_to_key) is not None:
11761178
return {id_to_key[key]: val for key, val in it}

nx_cugraph/drawing/layout.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from nx_cugraph.utils import (
2626
_dtype_param,
2727
_get_float_dtype,
28+
_seed_to_int,
2829
networkx_algorithm,
2930
)
3031

@@ -92,7 +93,7 @@ def forceatlas2_layout(
9293
# NOTE currently only x & y (dim=2) coordinated are supported by PLC
9394
# greater dimensions should be supported in the future to align with nx
9495
start_pos_arr = G._dict_to_nodearray(
95-
pos, default=[np.nan] * 2, dtype=np.dtype(np.float32, dim)
96+
pos, default=[np.nan] * dim, dtype=np.dtype((np.float32, 2))
9697
)
9798

9899
# find, if there exists, the missing position values
@@ -103,11 +104,13 @@ def forceatlas2_layout(
103104
if num_missing:
104105
xy_min = cp.nanmin(start_pos_arr, axis=0)
105106
xy_max = cp.nanmax(start_pos_arr, axis=0)
107+
# random state from seed to fill missing coords is different from random
108+
# state used for PLC
106109
seed = create_random_state(seed)
107110

108111
# fill missing gaps with valid random coords
109112
start_pos_arr[missing_vals] = xy_min + cp.asarray(
110-
seed.rand(num_missing, 2), dtype=np.float32
113+
seed.rand(num_missing, dim), dtype=np.float32
111114
) * (xy_max - xy_min)
112115

113116
x_start = start_pos_arr[:, 0]
@@ -116,6 +119,8 @@ def forceatlas2_layout(
116119
x_start = None
117120
y_start = None
118121

122+
seed = _seed_to_int(seed)
123+
119124
vertices, x_axis, y_axis = plc.force_atlas2(
120125
plc.ResourceHandle(),
121126
random_state=seed,
@@ -126,10 +131,15 @@ def forceatlas2_layout(
126131
outbound_attraction_distribution=outbound_attraction_distribution,
127132
lin_log_mode=linlog,
128133
prevent_overlapping=dissuade_hubs, # this might not be the right usage
134+
edge_weight_influence=1, # default
129135
jitter_tolerance=jitter_tolerance,
136+
barnes_hut_optimize=False, # default
137+
barnes_hut_theta=0, # default ?
130138
scaling_ratio=scaling_ratio,
131139
strong_gravity_mode=strong_gravity,
132140
gravity=gravity,
141+
verbose=False, # default
142+
do_expensive_check=False, # default
133143
)
134144

135145
pos_arr = cp.column_stack((x_axis, y_axis))

nx_cugraph/utils/misc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -20,6 +20,7 @@
2020

2121
import cupy as cp
2222
import numpy as np
23+
from numpy.random import RandomState
2324

2425
if TYPE_CHECKING:
2526
import nx_cugraph as nxcg
@@ -124,7 +125,7 @@ def _seed_to_int(seed: int | Random | None) -> int:
124125
"""Handle any valid seed argument and convert it to an int if necessary."""
125126
if seed is None:
126127
return
127-
if isinstance(seed, Random):
128+
if isinstance(seed, (Random, RandomState)):
128129
return seed.randint(0, sys.maxsize)
129130
return op.index(seed) # Ensure seed is integral
130131

0 commit comments

Comments
 (0)