2525from nx_cugraph .utils import (
2626 _dtype_param ,
2727 _get_float_dtype ,
28+ _seed_to_int ,
2829 networkx_algorithm ,
2930)
3031
@@ -92,7 +93,7 @@ def forceatlas2_layout(
9293 # NOTE currently only x & y (dim=2) coordinated are supported by PLC
9394 # greater dimensions should be supported in the future to align with nx
9495 start_pos_arr = G ._dict_to_nodearray (
95- pos , default = [np .nan ] * 2 , dtype = np .dtype (np .float32 , dim )
96+ pos , default = [np .nan ] * dim , dtype = np .dtype (( np .float32 , 2 ) )
9697 )
9798
9899 # find, if there exists, the missing position values
@@ -103,11 +104,13 @@ def forceatlas2_layout(
103104 if num_missing :
104105 xy_min = cp .nanmin (start_pos_arr , axis = 0 )
105106 xy_max = cp .nanmax (start_pos_arr , axis = 0 )
107+ # random state from seed to fill missing coords is different from random
108+ # state used for PLC
106109 seed = create_random_state (seed )
107110
108111 # fill missing gaps with valid random coords
109112 start_pos_arr [missing_vals ] = xy_min + cp .asarray (
110- seed .rand (num_missing , 2 ), dtype = np .float32
113+ seed .rand (num_missing , dim ), dtype = np .float32
111114 ) * (xy_max - xy_min )
112115
113116 x_start = start_pos_arr [:, 0 ]
@@ -116,6 +119,8 @@ def forceatlas2_layout(
116119 x_start = None
117120 y_start = None
118121
122+ seed = _seed_to_int (seed )
123+
119124 vertices , x_axis , y_axis = plc .force_atlas2 (
120125 plc .ResourceHandle (),
121126 random_state = seed ,
@@ -126,10 +131,15 @@ def forceatlas2_layout(
126131 outbound_attraction_distribution = outbound_attraction_distribution ,
127132 lin_log_mode = linlog ,
128133 prevent_overlapping = dissuade_hubs , # this might not be the right usage
134+ edge_weight_influence = 1 , # default
129135 jitter_tolerance = jitter_tolerance ,
136+ barnes_hut_optimize = False , # default
137+ barnes_hut_theta = 0 , # default ?
130138 scaling_ratio = scaling_ratio ,
131139 strong_gravity_mode = strong_gravity ,
132140 gravity = gravity ,
141+ verbose = False , # default
142+ do_expensive_check = False , # default
133143 )
134144
135145 pos_arr = cp .column_stack ((x_axis , y_axis ))
0 commit comments