Skip to content

Commit 008f9f8

Browse files
committed
Use public pylibcudf APIs in graph_primtypes
1 parent f7deb65 commit 008f9f8

File tree

2 files changed

+80
-29
lines changed

2 files changed

+80
-29
lines changed

python/cugraph/cugraph/structure/graph_primtypes.pxd

+13-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -20,6 +20,7 @@ from libcpp cimport bool
2020
from libcpp.memory cimport unique_ptr
2121
from libcpp.utility cimport pair
2222
from libcpp.vector cimport vector
23+
from pylibcudf cimport DataType
2324
from pylibraft.common.handle cimport *
2425
from rmm.librmm.device_buffer cimport device_buffer
2526

@@ -154,8 +155,17 @@ ctypedef GraphCOOView[int,int,double] GraphCOOViewDouble
154155
ctypedef GraphCSRView[int,int,float] GraphCSRViewFloat
155156
ctypedef GraphCSRView[int,int,double] GraphCSRViewDouble
156157

157-
cdef move_device_buffer_to_column(unique_ptr[device_buffer] device_buffer_unique_ptr, dtype)
158-
cdef move_device_buffer_to_series(unique_ptr[device_buffer] device_buffer_unique_ptr, dtype, series_name)
158+
cdef move_device_buffer_to_column(
159+
unique_ptr[device_buffer] device_buffer_unique_ptr,
160+
DataType dtype,
161+
size_t itemsize,
162+
)
163+
cdef move_device_buffer_to_series(
164+
unique_ptr[device_buffer] device_buffer_unique_ptr,
165+
DataType dtype,
166+
size_t itemsize,
167+
series_name
168+
)
159169
cdef coo_to_df(GraphCOOPtrType graph)
160170
cdef csr_to_series(GraphCSRPtrType graph)
161171
cdef GraphCOOViewFloat get_coo_float_graph_view(input_graph, bool weighted=*)

python/cugraph/cugraph/structure/graph_primtypes.pyx

+67-26
Original file line numberDiff line numberDiff line change
@@ -16,61 +16,88 @@
1616
# cython: embedsignature = True
1717
# cython: language_level = 3
1818

19-
import numpy as np
2019
from libc.stdint cimport uintptr_t
2120
from libcpp.utility cimport move
2221

23-
from rmm.pylibrmm.device_buffer cimport DeviceBuffer
24-
from cudf.core.buffer import as_buffer
22+
from pylibcudf cimport Column, DataType, type_id
2523
import cudf
2624

2725

2826
cdef move_device_buffer_to_column(
29-
unique_ptr[device_buffer] device_buffer_unique_ptr, dtype):
27+
unique_ptr[device_buffer] device_buffer_unique_ptr,
28+
DataType dtype,
29+
size_t itemsize,
30+
):
3031
"""
3132
Transfers ownership of device_buffer_unique_ptr to a cuDF buffer which is
3233
used to construct a cudf column object, which is then returned. If the
3334
intermediate buffer is empty, the device_buffer_unique_ptr is still
3435
transfered but None is returned.
3536
"""
36-
buff = DeviceBuffer.c_from_unique_ptr(move(device_buffer_unique_ptr))
37-
buff = as_buffer(buff)
38-
if buff.nbytes != 0:
39-
column = cudf.core.column.build_column(buff, dtype=cudf.dtype(dtype))
40-
return column
37+
cdef size_t buff_size = device_buffer_unique_ptr.get().size()
38+
cdef size_t col_size = buff_size // itemsize
39+
cdef Column result_column = Column.from_rmm_buffer(
40+
move(device_buffer_unique_ptr),
41+
dtype,
42+
col_size,
43+
[],
44+
)
45+
if buff_size != 0:
46+
return result_column
4147
return None
4248

4349

4450
cdef move_device_buffer_to_series(
45-
unique_ptr[device_buffer] device_buffer_unique_ptr, dtype, series_name):
51+
unique_ptr[device_buffer] device_buffer_unique_ptr,
52+
DataType dtype,
53+
size_t itemsize,
54+
series_name
55+
):
4656
"""
4757
Transfers ownership of device_buffer_unique_ptr to a cuDF buffer which is
4858
used to construct a cudf.Series object with name series_name, which is then
4959
returned. If the intermediate buffer is empty, the device_buffer_unique_ptr
5060
is still transfered but None is returned.
5161
"""
52-
column = move_device_buffer_to_column(move(device_buffer_unique_ptr), dtype)
62+
column = move_device_buffer_to_column(
63+
move(device_buffer_unique_ptr),
64+
dtype,
65+
itemsize,
66+
)
5367
if column is not None:
54-
series = cudf.Series._from_data({series_name: column})
55-
return series
68+
return cudf.Series.from_pylibcudf(column, metadata={"name": series_name})
5669
return None
5770

5871

5972
cdef coo_to_df(GraphCOOPtrType graph):
6073
# FIXME: this function assumes columns named "src" and "dst" and can only
6174
# be used for SG graphs due to that assumption.
6275
contents = move(graph.get()[0].release())
63-
src = move_device_buffer_to_column(move(contents.src_indices), "int32")
64-
dst = move_device_buffer_to_column(move(contents.dst_indices), "int32")
76+
src = move_device_buffer_to_column(
77+
move(contents.src_indices),
78+
DataType(type_id.INT32),
79+
4,
80+
)
81+
dst = move_device_buffer_to_column(
82+
move(contents.dst_indices),
83+
DataType(type_id.INT32),
84+
4,
85+
)
6586

6687
if GraphCOOPtrType is GraphCOOPtrFloat:
67-
weight_type = "float32"
88+
weight_type = DataType(type_id.FLOAT32)
89+
itemsize = 4
6890
elif GraphCOOPtrType is GraphCOOPtrDouble:
69-
weight_type = "float64"
91+
weight_type = DataType(type_id.FLOAT64)
92+
itemsize = 8
7093
else:
7194
raise TypeError("Invalid GraphCOOPtrType")
7295

73-
wgt = move_device_buffer_to_column(move(contents.edge_data), weight_type)
96+
wgt = move_device_buffer_to_column(
97+
move(contents.edge_data),
98+
weight_type,
99+
itemsize
100+
)
74101

75102
df = cudf.DataFrame()
76103
df['src'] = src
@@ -83,20 +110,34 @@ cdef coo_to_df(GraphCOOPtrType graph):
83110

84111
cdef csr_to_series(GraphCSRPtrType graph):
85112
contents = move(graph.get()[0].release())
86-
csr_offsets = move_device_buffer_to_series(move(contents.offsets),
87-
"int32", "csr_offsets")
88-
csr_indices = move_device_buffer_to_series(move(contents.indices),
89-
"int32", "csr_indices")
113+
csr_offsets = move_device_buffer_to_series(
114+
move(contents.offsets),
115+
DataType(type_id.INT32),
116+
4,
117+
"csr_offsets"
118+
)
119+
csr_indices = move_device_buffer_to_series(
120+
move(contents.indices),
121+
DataType(type_id.INT32),
122+
4,
123+
"csr_indices"
124+
)
90125

91126
if GraphCSRPtrType is GraphCSRPtrFloat:
92-
weight_type = "float32"
127+
weight_type = DataType(type_id.FLOAT32)
128+
itemsize = 4
93129
elif GraphCSRPtrType is GraphCSRPtrDouble:
94-
weight_type = "float64"
130+
weight_type = DataType(type_id.FLOAT64)
131+
itemsize = 8
95132
else:
96133
raise TypeError("Invalid GraphCSRPtrType")
97134

98-
csr_weights = move_device_buffer_to_series(move(contents.edge_data),
99-
weight_type, "csr_weights")
135+
csr_weights = move_device_buffer_to_series(
136+
move(contents.edge_data),
137+
weight_type,
138+
itemsize,
139+
"csr_weights"
140+
)
100141

101142
return (csr_offsets, csr_indices, csr_weights)
102143

0 commit comments

Comments
 (0)