Skip to content

Commit 20c82c7

Browse files
committed
add values_as_arrays arg
1 parent 623c4a6 commit 20c82c7

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

nx_cugraph/classes/graph.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,17 +1171,27 @@ def _nodearray_to_set(self, node_ids: cp.ndarray[IndexValue]) -> set[NodeKey]:
11711171
def _nodearray_to_dict(
11721172
self,
11731173
values: cp.ndarray[NodeValue],
1174+
values_as_arrays: bool = False,
11741175
) -> dict[NodeKey, NodeValue]:
1175-
# values_as_arrays: bool | None = None,
1176-
it = enumerate(values.tolist())
1176+
if values_as_arrays:
1177+
it = enumerate(cp.asnumpy(values))
1178+
else:
1179+
it = enumerate(values.tolist())
11771180
if (id_to_key := self.id_to_key) is not None:
11781181
return {id_to_key[key]: val for key, val in it}
11791182
return dict(it)
11801183

11811184
def _nodearrays_to_dict(
1182-
self, node_ids: cp.ndarray[IndexValue], values: any_ndarray[NodeValue]
1185+
self,
1186+
node_ids: cp.ndarray[IndexValue],
1187+
values: any_ndarray[NodeValue],
1188+
values_as_arrays: bool = False,
11831189
) -> dict[NodeKey, NodeValue]:
1184-
it = zip(node_ids.tolist(), values.tolist())
1190+
if values_as_arrays:
1191+
vals = cp.asnumpy(values)
1192+
else:
1193+
vals = values.tolist()
1194+
it = zip(node_ids.tolist(), vals)
11851195
if (id_to_key := self.id_to_key) is not None:
11861196
return {id_to_key[key]: val for key, val in it}
11871197
return dict(it)

0 commit comments

Comments
 (0)