Skip to content

Commit 6c79362

Browse files
committed
Switching over to creating a getter that returns a scalar UDF
1 parent 3b59efa commit 6c79362

File tree

1 file changed

+73
-41
lines changed

1 file changed

+73
-41
lines changed

rerun_py/src/catalog/dataset.rs

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArra
44
use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
55
use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _};
66
use datafusion::common::exec_err;
7-
use pyo3::Bound;
7+
use pyo3::types::{PyAnyMethods, PyString, PyTuple};
8+
use pyo3::{Bound, IntoPyObject};
89
use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python};
10+
use re_tuid::Tuid;
911
use tokio_stream::StreamExt as _;
1012

1113
use re_chunk_store::{ChunkStore, ChunkStoreHandle};
@@ -26,6 +28,7 @@ use re_protos::manifest_registry::v1alpha1::{
2628
};
2729
use re_sdk::{ComponentDescriptor, ComponentName};
2830

31+
use crate::catalog::ConnectionHandle;
2932
use crate::catalog::{
3033
dataframe_query::PyDataframeQueryView, to_py_err, PyEntry, VectorDistanceMetricLike, VectorLike,
3134
};
@@ -98,54 +101,83 @@ impl PyDataset {
98101
.to_string()
99102
}
100103

104+
#[getter]
101105
fn partition_url_udf(
102-
self_: PyRef<'_, Self>,
103-
partition_id_expr: &Bound<'_, PyAny>,
106+
self_: PyRef<'_, Self>
104107
) -> PyResult<Py<PyAny>> {
108+
105109
let super_ = self_.as_super();
106110
let connection = super_.client.borrow(self_.py()).connection().clone();
111+
let py = self_.py();
107112

108-
let mut url = re_uri::DatasetDataUri {
109-
origin: connection.origin().clone(),
110-
dataset_id: super_.details.id.id,
111-
partition_id: "default".to_owned(), // to be replaced during loop
112-
113-
//TODO(ab): add support for these two
114-
time_range: None,
115-
fragment: Default::default(),
116-
};
113+
#[pyclass]
114+
struct PartitionUrlInner {
115+
pub connection: ConnectionHandle,
116+
pub dataset_id: Tuid,
117+
}
117118

118-
let array_data = ArrayData::from_pyarrow_bound(partition_id_expr)?;
119-
120-
match array_data.data_type() {
121-
DataType::Utf8 => {
122-
let str_array = StringArray::from(array_data);
123-
let str_iter = str_array.iter().map(|maybe_id| {
124-
maybe_id.map(|id| {
125-
url.partition_id = id.to_owned();
126-
url.to_string()
127-
})
128-
});
129-
let output_array: ArrayRef = Arc::new(str_iter.collect::<StringArray>());
130-
output_array.to_data().to_pyarrow(super_.py())
119+
#[pymethods]
120+
impl PartitionUrlInner {
121+
pub fn __call__(&self, py: Python<'_>, partition_id_expr: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
122+
123+
let mut url = re_uri::DatasetDataUri {
124+
origin: self.connection.origin().clone(),
125+
dataset_id: self.dataset_id,
126+
partition_id: "default".to_owned(), // to be replaced during loop
127+
128+
//TODO(ab): add support for these two
129+
time_range: None,
130+
fragment: Default::default(),
131+
};
132+
133+
let array_data = ArrayData::from_pyarrow_bound(partition_id_expr)?;
134+
135+
match array_data.data_type() {
136+
DataType::Utf8 => {
137+
let str_array = StringArray::from(array_data);
138+
let str_iter = str_array.iter().map(|maybe_id| {
139+
maybe_id.map(|id| {
140+
url.partition_id = id.to_owned();
141+
url.to_string()
142+
})
143+
});
144+
let output_array: ArrayRef = Arc::new(str_iter.collect::<StringArray>());
145+
output_array.to_data().to_pyarrow(py)
146+
}
147+
DataType::Utf8View => {
148+
let str_array = StringViewArray::from(array_data);
149+
let str_iter = str_array.iter().map(|maybe_id| {
150+
maybe_id.map(|id| {
151+
url.partition_id = id.to_owned();
152+
url.to_string()
153+
})
154+
});
155+
let output_array: ArrayRef = Arc::new(str_iter.collect::<StringViewArray>());
156+
output_array.to_data().to_pyarrow(py)
157+
}
158+
_ => exec_err!(
159+
"Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}",
160+
array_data.data_type()
161+
)
162+
.map_err(to_py_err),
163+
}
131164
}
132-
DataType::Utf8View => {
133-
let str_array = StringViewArray::from(array_data);
134-
let str_iter = str_array.iter().map(|maybe_id| {
135-
maybe_id.map(|id| {
136-
url.partition_id = id.to_owned();
137-
url.to_string()
138-
})
139-
});
140-
let output_array: ArrayRef = Arc::new(str_iter.collect::<StringViewArray>());
141-
output_array.to_data().to_pyarrow(super_.py())
142-
}
143-
_ => exec_err!(
144-
"Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}",
145-
array_data.data_type()
146-
)
147-
.map_err(to_py_err),
148165
}
166+
167+
let udf_factory = py.import("datafusion").and_then(|datafusion| datafusion.getattr("udf"))?;
168+
let pa_utf8 = py.import("pyarrow").and_then(|pa| pa.getattr("utf8")?.call0())?;
169+
170+
let inner = PartitionUrlInner {
171+
connection,
172+
dataset_id: super_.details.id.id,
173+
};
174+
let bound_inner = inner.into_pyobject(py)?;
175+
let py_stable = PyString::new(py, "stable");
176+
177+
// df.udf(dataset.partition_url_udf, pa.utf8(), pa.utf8(), 'stable')
178+
let args = PyTuple::new(py, vec![bound_inner.as_any(), pa_utf8.as_any(), pa_utf8.as_any(), py_stable.as_any()])?;
179+
180+
Ok(udf_factory.call1(args)?.unbind())
149181
}
150182

151183
/// Register a RRD URI to the dataset.

0 commit comments

Comments
 (0)