@@ -4,8 +4,10 @@ use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArra
4
4
use arrow:: datatypes:: { DataType , Field , Schema as ArrowSchema } ;
5
5
use arrow:: pyarrow:: { FromPyArrow as _, PyArrowType , ToPyArrow as _} ;
6
6
use datafusion:: common:: exec_err;
7
- use pyo3:: Bound ;
7
+ use pyo3:: types:: { PyAnyMethods , PyString , PyTuple } ;
8
+ use pyo3:: { Bound , IntoPyObject } ;
8
9
use pyo3:: { exceptions:: PyRuntimeError , pyclass, pymethods, Py , PyAny , PyRef , PyResult , Python } ;
10
+ use re_tuid:: Tuid ;
9
11
use tokio_stream:: StreamExt as _;
10
12
11
13
use re_chunk_store:: { ChunkStore , ChunkStoreHandle } ;
@@ -26,6 +28,7 @@ use re_protos::manifest_registry::v1alpha1::{
26
28
} ;
27
29
use re_sdk:: { ComponentDescriptor , ComponentName } ;
28
30
31
+ use crate :: catalog:: ConnectionHandle ;
29
32
use crate :: catalog:: {
30
33
dataframe_query:: PyDataframeQueryView , to_py_err, PyEntry , VectorDistanceMetricLike , VectorLike ,
31
34
} ;
@@ -98,54 +101,83 @@ impl PyDataset {
98
101
. to_string ( )
99
102
}
100
103
104
+ #[ getter]
101
105
fn partition_url_udf (
102
- self_ : PyRef < ' _ , Self > ,
103
- partition_id_expr : & Bound < ' _ , PyAny > ,
106
+ self_ : PyRef < ' _ , Self >
104
107
) -> PyResult < Py < PyAny > > {
108
+
105
109
let super_ = self_. as_super ( ) ;
106
110
let connection = super_. client . borrow ( self_. py ( ) ) . connection ( ) . clone ( ) ;
111
+ let py = self_. py ( ) ;
107
112
108
- let mut url = re_uri:: DatasetDataUri {
109
- origin : connection. origin ( ) . clone ( ) ,
110
- dataset_id : super_. details . id . id ,
111
- partition_id : "default" . to_owned ( ) , // to be replaced during loop
112
-
113
- //TODO(ab): add support for these two
114
- time_range : None ,
115
- fragment : Default :: default ( ) ,
116
- } ;
113
+ #[ pyclass]
114
+ struct PartitionUrlInner {
115
+ pub connection : ConnectionHandle ,
116
+ pub dataset_id : Tuid ,
117
+ }
117
118
118
- let array_data = ArrayData :: from_pyarrow_bound ( partition_id_expr) ?;
119
-
120
- match array_data. data_type ( ) {
121
- DataType :: Utf8 => {
122
- let str_array = StringArray :: from ( array_data) ;
123
- let str_iter = str_array. iter ( ) . map ( |maybe_id| {
124
- maybe_id. map ( |id| {
125
- url. partition_id = id. to_owned ( ) ;
126
- url. to_string ( )
127
- } )
128
- } ) ;
129
- let output_array: ArrayRef = Arc :: new ( str_iter. collect :: < StringArray > ( ) ) ;
130
- output_array. to_data ( ) . to_pyarrow ( super_. py ( ) )
119
+ #[ pymethods]
120
+ impl PartitionUrlInner {
121
+ pub fn __call__ ( & self , py : Python < ' _ > , partition_id_expr : & Bound < ' _ , PyAny > ) -> PyResult < Py < PyAny > > {
122
+
123
+ let mut url = re_uri:: DatasetDataUri {
124
+ origin : self . connection . origin ( ) . clone ( ) ,
125
+ dataset_id : self . dataset_id ,
126
+ partition_id : "default" . to_owned ( ) , // to be replaced during loop
127
+
128
+ //TODO(ab): add support for these two
129
+ time_range : None ,
130
+ fragment : Default :: default ( ) ,
131
+ } ;
132
+
133
+ let array_data = ArrayData :: from_pyarrow_bound ( partition_id_expr) ?;
134
+
135
+ match array_data. data_type ( ) {
136
+ DataType :: Utf8 => {
137
+ let str_array = StringArray :: from ( array_data) ;
138
+ let str_iter = str_array. iter ( ) . map ( |maybe_id| {
139
+ maybe_id. map ( |id| {
140
+ url. partition_id = id. to_owned ( ) ;
141
+ url. to_string ( )
142
+ } )
143
+ } ) ;
144
+ let output_array: ArrayRef = Arc :: new ( str_iter. collect :: < StringArray > ( ) ) ;
145
+ output_array. to_data ( ) . to_pyarrow ( py)
146
+ }
147
+ DataType :: Utf8View => {
148
+ let str_array = StringViewArray :: from ( array_data) ;
149
+ let str_iter = str_array. iter ( ) . map ( |maybe_id| {
150
+ maybe_id. map ( |id| {
151
+ url. partition_id = id. to_owned ( ) ;
152
+ url. to_string ( )
153
+ } )
154
+ } ) ;
155
+ let output_array: ArrayRef = Arc :: new ( str_iter. collect :: < StringViewArray > ( ) ) ;
156
+ output_array. to_data ( ) . to_pyarrow ( py)
157
+ }
158
+ _ => exec_err ! (
159
+ "Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}" ,
160
+ array_data. data_type( )
161
+ )
162
+ . map_err ( to_py_err) ,
163
+ }
131
164
}
132
- DataType :: Utf8View => {
133
- let str_array = StringViewArray :: from ( array_data) ;
134
- let str_iter = str_array. iter ( ) . map ( |maybe_id| {
135
- maybe_id. map ( |id| {
136
- url. partition_id = id. to_owned ( ) ;
137
- url. to_string ( )
138
- } )
139
- } ) ;
140
- let output_array: ArrayRef = Arc :: new ( str_iter. collect :: < StringViewArray > ( ) ) ;
141
- output_array. to_data ( ) . to_pyarrow ( super_. py ( ) )
142
- }
143
- _ => exec_err ! (
144
- "Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}" ,
145
- array_data. data_type( )
146
- )
147
- . map_err ( to_py_err) ,
148
165
}
166
+
167
+ let udf_factory = py. import ( "datafusion" ) . and_then ( |datafusion| datafusion. getattr ( "udf" ) ) ?;
168
+ let pa_utf8 = py. import ( "pyarrow" ) . and_then ( |pa| pa. getattr ( "utf8" ) ?. call0 ( ) ) ?;
169
+
170
+ let inner = PartitionUrlInner {
171
+ connection,
172
+ dataset_id : super_. details . id . id ,
173
+ } ;
174
+ let bound_inner = inner. into_pyobject ( py) ?;
175
+ let py_stable = PyString :: new ( py, "stable" ) ;
176
+
177
+ // df.udf(dataset.partition_url_udf, pa.utf8(), pa.utf8(), 'stable')
178
+ let args = PyTuple :: new ( py, vec ! [ bound_inner. as_any( ) , pa_utf8. as_any( ) , pa_utf8. as_any( ) , py_stable. as_any( ) ] ) ?;
179
+
180
+ Ok ( udf_factory. call1 ( args) ?. unbind ( ) )
149
181
}
150
182
151
183
/// Register a RRD URI to the dataset.
0 commit comments