From e705ebe4588bbe2b80fa370babff92394f96d0e4 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 7 Apr 2025 10:28:06 -0400
Subject: [PATCH 1/6] Add UDF for turning an array of partition IDs into their URLs

---
 rerun_py/src/catalog/dataset.rs | 53 +++++++++++++++++++++++++++++++--
 1 file changed, 50 insertions(+), 3 deletions(-)

diff --git a/rerun_py/src/catalog/dataset.rs b/rerun_py/src/catalog/dataset.rs
index 2a83134e82dd..8abcd53c7105 100644
--- a/rerun_py/src/catalog/dataset.rs
+++ b/rerun_py/src/catalog/dataset.rs
@@ -1,8 +1,10 @@
 use std::sync::Arc;

-use arrow::array::{RecordBatch, StringArray};
-use arrow::datatypes::{Field, Schema as ArrowSchema};
-use arrow::pyarrow::PyArrowType;
+use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArray};
+use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
+use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow};
+use datafusion::common::exec_err;
+use pyo3::Bound;
 use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python};
 use tokio_stream::StreamExt as _;

@@ -96,6 +98,51 @@ impl PyDataset {
         .to_string()
     }

+    fn partition_url_udf(
+        self_: PyRef<'_, Self>,
+        partition_id_expr: &Bound<'_, PyAny>
+    ) -> PyResult<Py<PyAny>> {
+        let super_ = self_.as_super();
+        let connection = super_.client.borrow(self_.py()).connection().clone();
+
+        let mut url = re_uri::DatasetDataEndpoint {
+            origin: connection.origin().clone(),
+            dataset_id: super_.details.id.id,
+            partition_id: "default".to_string(), // to be replaced during loop
+
+            //TODO(ab): add support for these two
+            time_range: None,
+            fragment: Default::default(),
+        };
+
+        let array_data = ArrayData::from_pyarrow_bound(partition_id_expr)?;
+
+        match array_data.data_type() {
+            DataType::Utf8 => {
+                let str_array = StringArray::from(array_data);
+                let str_iter = str_array.iter().map(|maybe_id| maybe_id.map(|id| {
+                    url.partition_id = id.to_string();
+                    url.to_string()
+                }));
+                let output_array: ArrayRef = Arc::new(StringArray::from_iter(str_iter));
+                output_array.to_data().to_pyarrow(super_.py())
+            }
+            DataType::Utf8View => {
+                let str_array = StringViewArray::from(array_data);
+                let str_iter = str_array.iter().map(|maybe_id| maybe_id.map(|id| {
+                    url.partition_id = id.to_string();
+                    url.to_string()
+                }));
+                let output_array: ArrayRef = Arc::new(StringViewArray::from_iter(str_iter));
+                output_array.to_data().to_pyarrow(super_.py())
+            }
+            _ => {
+                return exec_err!("Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}", array_data.data_type()).map_err(to_py_err)
+            }
+        }
+
+    }
+
     /// Register a RRD URI to the dataset.
fn register(self_: PyRef<'_, Self>, recording_uri: String) -> PyResult<()> { let super_ = self_.as_super(); From 268aadb18a0b64468e44fcd08577c44cf1530acd Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 7 Apr 2025 11:40:01 -0400 Subject: [PATCH 2/6] Fix a few clippy warnings --- rerun_py/src/catalog/dataset.rs | 43 ++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/rerun_py/src/catalog/dataset.rs b/rerun_py/src/catalog/dataset.rs index 8abcd53c7105..b196a13815ff 100644 --- a/rerun_py/src/catalog/dataset.rs +++ b/rerun_py/src/catalog/dataset.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; -use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow}; +use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _}; use datafusion::common::exec_err; use pyo3::Bound; use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python}; @@ -100,15 +100,15 @@ impl PyDataset { fn partition_url_udf( self_: PyRef<'_, Self>, - partition_id_expr: &Bound<'_, PyAny> + partition_id_expr: &Bound<'_, PyAny>, ) -> PyResult> { let super_ = self_.as_super(); let connection = super_.client.borrow(self_.py()).connection().clone(); - let mut url = re_uri::DatasetDataEndpoint { + let mut url = re_uri::DatasetDataUri { origin: connection.origin().clone(), dataset_id: super_.details.id.id, - partition_id: "default".to_string(), // to be replaced during loop + partition_id: "default".to_owned(), // to be replaced during loop //TODO(ab): add support for these two time_range: None, @@ -117,30 +117,35 @@ impl PyDataset { let array_data = ArrayData::from_pyarrow_bound(partition_id_expr)?; - match array_data.data_type() { + match array_data.data_type() { DataType::Utf8 => { let str_array = StringArray::from(array_data); - let str_iter = str_array.iter().map(|maybe_id| maybe_id.map(|id| { - url.partition_id = id.to_string(); - url.to_string() - })); - let output_array: ArrayRef = Arc::new(StringArray::from_iter(str_iter)); + let str_iter = str_array.iter().map(|maybe_id| { + maybe_id.map(|id| { + url.partition_id = id.to_owned(); + url.to_string() + }) + }); + let output_array: ArrayRef = Arc::new(str_iter.collect::()); output_array.to_data().to_pyarrow(super_.py()) } DataType::Utf8View => { let str_array = StringViewArray::from(array_data); - let str_iter = str_array.iter().map(|maybe_id| maybe_id.map(|id| { - url.partition_id = id.to_string(); - url.to_string() - })); - let output_array: ArrayRef = Arc::new(StringViewArray::from_iter(str_iter)); + let str_iter = str_array.iter().map(|maybe_id| { + maybe_id.map(|id| { + url.partition_id = id.to_owned(); + url.to_string() + }) + }); + let output_array: ArrayRef = Arc::new(str_iter.collect::()); output_array.to_data().to_pyarrow(super_.py()) } - _ => { - return exec_err!("Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}", array_data.data_type()).map_err(to_py_err) - } + _ => exec_err!( + "Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}", + array_data.data_type() + ) + .map_err(to_py_err), } - } /// Register a RRD URI to the dataset. 
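At this point in the series, `partition_url_udf` is still a plain method that maps a pyarrow string array of partition IDs to an array of URL strings, so any DataFusion wiring has to happen by hand on the Python side. Below is a minimal, hypothetical sketch of that wiring, mirroring the `df.udf(...)` comment that shows up later in the series; how the `dataset` object and the DataFrame `df` are obtained is assumed here, not shown in these patches.

# Python usage sketch -- illustrative only, not part of the patches.
import pyarrow as pa
from datafusion import udf, col

# `dataset` is assumed to be a catalog Dataset entry obtained elsewhere, and
# `df` an existing DataFusion DataFrame with a utf8 partition-id column named
# "partition_id" (both names are assumptions).
partition_url = udf(
    dataset.partition_url_udf,  # the method added above: utf8 array in, utf8 array of URLs out
    [pa.utf8()],                # input types
    pa.utf8(),                  # return type
    "stable",                   # volatility
)

df.select(partition_url(col("partition_id")))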
From 4b68d6b77c1cd358a7f0712bcf9c62973fa0a54e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 7 Apr 2025 12:47:22 -0400 Subject: [PATCH 3/6] Switching over to creating a getter to return a scalar udf --- rerun_py/src/catalog/dataset.rs | 114 ++++++++++++++++++++------------ 1 file changed, 73 insertions(+), 41 deletions(-) diff --git a/rerun_py/src/catalog/dataset.rs b/rerun_py/src/catalog/dataset.rs index b196a13815ff..a77c5361a0ff 100644 --- a/rerun_py/src/catalog/dataset.rs +++ b/rerun_py/src/catalog/dataset.rs @@ -4,8 +4,10 @@ use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArra use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _}; use datafusion::common::exec_err; -use pyo3::Bound; +use pyo3::types::{PyAnyMethods, PyString, PyTuple}; +use pyo3::{Bound, IntoPyObject}; use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python}; +use re_tuid::Tuid; use tokio_stream::StreamExt as _; use re_chunk_store::{ChunkStore, ChunkStoreHandle}; @@ -26,6 +28,7 @@ use re_protos::manifest_registry::v1alpha1::{ }; use re_sdk::{ComponentDescriptor, ComponentName}; +use crate::catalog::ConnectionHandle; use crate::catalog::{ dataframe_query::PyDataframeQueryView, to_py_err, PyEntry, VectorDistanceMetricLike, VectorLike, }; @@ -98,54 +101,83 @@ impl PyDataset { .to_string() } + #[getter] fn partition_url_udf( - self_: PyRef<'_, Self>, - partition_id_expr: &Bound<'_, PyAny>, + self_: PyRef<'_, Self> ) -> PyResult> { + let super_ = self_.as_super(); let connection = super_.client.borrow(self_.py()).connection().clone(); + let py = self_.py(); - let mut url = re_uri::DatasetDataUri { - origin: connection.origin().clone(), - dataset_id: super_.details.id.id, - partition_id: "default".to_owned(), // to be replaced during loop - - //TODO(ab): add support for these two - time_range: None, - fragment: Default::default(), - }; + #[pyclass] + struct PartitionUrlInner { + pub connection: ConnectionHandle, + pub dataset_id: Tuid, + } - let array_data = ArrayData::from_pyarrow_bound(partition_id_expr)?; - - match array_data.data_type() { - DataType::Utf8 => { - let str_array = StringArray::from(array_data); - let str_iter = str_array.iter().map(|maybe_id| { - maybe_id.map(|id| { - url.partition_id = id.to_owned(); - url.to_string() - }) - }); - let output_array: ArrayRef = Arc::new(str_iter.collect::()); - output_array.to_data().to_pyarrow(super_.py()) + #[pymethods] + impl PartitionUrlInner { + pub fn __call__(&self, py: Python<'_>, partition_id_expr: &Bound<'_, PyAny>) -> PyResult> { + + let mut url = re_uri::DatasetDataUri { + origin: self.connection.origin().clone(), + dataset_id: self.dataset_id, + partition_id: "default".to_owned(), // to be replaced during loop + + //TODO(ab): add support for these two + time_range: None, + fragment: Default::default(), + }; + + let array_data = ArrayData::from_pyarrow_bound(partition_id_expr)?; + + match array_data.data_type() { + DataType::Utf8 => { + let str_array = StringArray::from(array_data); + let str_iter = str_array.iter().map(|maybe_id| { + maybe_id.map(|id| { + url.partition_id = id.to_owned(); + url.to_string() + }) + }); + let output_array: ArrayRef = Arc::new(str_iter.collect::()); + output_array.to_data().to_pyarrow(py) + } + DataType::Utf8View => { + let str_array = StringViewArray::from(array_data); + let str_iter = str_array.iter().map(|maybe_id| { + maybe_id.map(|id| { + url.partition_id = id.to_owned(); + 
+                                url.to_string()
+                            })
+                        });
+                        let output_array: ArrayRef = Arc::new(str_iter.collect::<StringViewArray>());
+                        output_array.to_data().to_pyarrow(py)
+                    }
+                    _ => exec_err!(
+                        "Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}",
+                        array_data.data_type()
+                    )
+                    .map_err(to_py_err),
+                }
+            }
+        }
+
+        let udf_factory = py.import("datafusion").and_then(|datafusion| datafusion.getattr("udf"))?;
+        let pa_utf8 = py.import("pyarrow").and_then(|pa| pa.getattr("utf8")?.call0())?;
+
+        let inner = PartitionUrlInner {
+            connection,
+            dataset_id: super_.details.id.id,
+        };
+        let bound_inner = inner.into_pyobject(py)?;
+        let py_stable = PyString::new(py, "stable");
+
+        // df.udf(dataset.partition_url_udf, pa.utf8(), pa.utf8(), 'stable')
+        let args = PyTuple::new(py, vec![bound_inner.as_any(), pa_utf8.as_any(), pa_utf8.as_any(), py_stable.as_any()])?;
+
+        Ok(udf_factory.call1(args)?.unbind())
+    }

     /// Register a RRD URI to the dataset.

From b9711be396e6be41bc825e6ac3be75924da72fa7 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 7 Apr 2025 15:26:04 -0400
Subject: [PATCH 4/6] Switch over to use a closure so we can extract out the common parts into a util

---
 rerun_py/src/catalog/dataset.rs | 137 ++++++++++++++++----------------
 1 file changed, 69 insertions(+), 68 deletions(-)

diff --git a/rerun_py/src/catalog/dataset.rs b/rerun_py/src/catalog/dataset.rs
index a77c5361a0ff..0314396ddc32 100644
--- a/rerun_py/src/catalog/dataset.rs
+++ b/rerun_py/src/catalog/dataset.rs
@@ -4,10 +4,9 @@ use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArra
 use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
 use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _};
 use datafusion::common::exec_err;
-use pyo3::types::{PyAnyMethods, PyString, PyTuple};
-use pyo3::{Bound, IntoPyObject};
+use pyo3::types::{PyAnyMethods, PyCFunction, PyDict, PyString, PyTuple, PyTupleMethods};
 use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python};
-use re_tuid::Tuid;
+use pyo3::{Bound, IntoPyObject};
 use tokio_stream::StreamExt as _;

 use re_chunk_store::{ChunkStore, ChunkStoreHandle};
@@ -28,7 +27,6 @@ use re_protos::manifest_registry::v1alpha1::{
 };
 use re_sdk::{ComponentDescriptor, ComponentName};

-use crate::catalog::ConnectionHandle;
 use crate::catalog::{
     dataframe_query::PyDataframeQueryView, to_py_err, PyEntry, VectorDistanceMetricLike, VectorLike,
 };
@@ -102,80 +100,83 @@ impl PyDataset {
     }

     #[getter]
-    fn partition_url_udf(
-        self_: PyRef<'_, Self>
-    ) -> PyResult<Py<PyAny>> {
-
+    fn partition_url_udf(self_: PyRef<'_, Self>) -> PyResult<Py<PyAny>> {
         let super_ = self_.as_super();
         let connection = super_.client.borrow(self_.py()).connection().clone();
-        let py = self_.py();

-        #[pyclass]
-        struct PartitionUrlInner {
-            pub connection: ConnectionHandle,
-            pub dataset_id: Tuid,
-        }
+        let origin = connection.origin().clone();
+        let dataset_id = super_.details.id.id;
+
+        let partition_url_inner = move |args: &Bound<'_, PyTuple>,
+                                        _kwargs: Option<&Bound<'_, PyDict>>|
+              -> PyResult<Py<PyAny>> {
+            let py = args.py();
+            let partition_id_expr = args.get_borrowed_item(0)?;
+            let mut url = re_uri::DatasetDataUri {
+                origin: origin.clone(),
+                dataset_id,
+                partition_id: "default".to_owned(), // to be replaced during loop
+
+                //TODO(ab): add support for these two
+                time_range: None,
+                fragment: Default::default(),
+            };
-            let array_data = ArrayData::from_pyarrow_bound(partition_id_expr)?;
-
-            match array_data.data_type() {
-                DataType::Utf8 => {
-                    let str_array = StringArray::from(array_data);
-                    let str_iter = str_array.iter().map(|maybe_id| {
-                        maybe_id.map(|id| {
-                            url.partition_id = id.to_owned();
-                            url.to_string()
-                        })
-                    });
-                    let output_array: ArrayRef = Arc::new(str_iter.collect::<StringArray>());
-                    output_array.to_data().to_pyarrow(py)
+            let array_data = ArrayData::from_pyarrow_bound(partition_id_expr.as_ref())?;
+
+            match array_data.data_type() {
+                DataType::Utf8 => {
+                    let str_array = StringArray::from(array_data);
+                    let str_iter = str_array.iter().map(|maybe_id| {
+                        maybe_id.map(|id| {
+                            url.partition_id = id.to_owned();
+                            url.to_string()
+                        })
+                    });
+                    let output_array: ArrayRef = Arc::new(str_iter.collect::<StringArray>());
+                    output_array.to_data().to_pyarrow(py)
                 }
+                DataType::Utf8View => {
+                    let str_array = StringViewArray::from(array_data);
+                    let str_iter = str_array.iter().map(|maybe_id| {
+                        maybe_id.map(|id| {
+                            url.partition_id = id.to_owned();
+                            url.to_string()
+                        })
+                    });
+                    let output_array: ArrayRef = Arc::new(str_iter.collect::<StringViewArray>());
+                    output_array.to_data().to_pyarrow(py)
+                }
+                _ => exec_err!(
+                    "Incorrect data type for partition_url_udf. Expected utf8 or utf8view.
Received {}", + array_data.data_type() + ) + .map_err(to_py_err), } - } + }; - let udf_factory = py.import("datafusion").and_then(|datafusion| datafusion.getattr("udf"))?; - let pa_utf8 = py.import("pyarrow").and_then(|pa| pa.getattr("utf8")?.call0())?; + let py = self_.py(); - let inner = PartitionUrlInner { - connection, - dataset_id: super_.details.id.id, - }; + let udf_factory = py + .import("datafusion") + .and_then(|datafusion| datafusion.getattr("udf"))?; + let pa_utf8 = py + .import("pyarrow") + .and_then(|pa| pa.getattr("utf8")?.call0())?; + + let inner = PyCFunction::new_closure(py, None, None, partition_url_inner)?; let bound_inner = inner.into_pyobject(py)?; let py_stable = PyString::new(py, "stable"); - - // df.udf(dataset.partition_url_udf, pa.utf8(), pa.utf8(), 'stable') - let args = PyTuple::new(py, vec![bound_inner.as_any(), pa_utf8.as_any(), pa_utf8.as_any(), py_stable.as_any()])?; + + let args = PyTuple::new( + py, + vec![ + bound_inner.as_any(), + pa_utf8.as_any(), + pa_utf8.as_any(), + py_stable.as_any(), + ], + )?; Ok(udf_factory.call1(args)?.unbind()) } From 2c9614854d4a0a3bc17e0f39f57776cdc36b6f34 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 7 Apr 2025 16:14:03 -0400 Subject: [PATCH 5/6] Move datafusion udf creation into helper function so we can pass a closure and argument types and have the helper function take care of the rest --- rerun_py/src/catalog/dataset.rs | 38 ++++------- rerun_py/src/datafusion_utils.rs | 107 +++++++++++++++++++++++++++++++ rerun_py/src/lib.rs | 1 + 3 files changed, 119 insertions(+), 27 deletions(-) create mode 100644 rerun_py/src/datafusion_utils.rs diff --git a/rerun_py/src/catalog/dataset.rs b/rerun_py/src/catalog/dataset.rs index 0314396ddc32..0383f7e6baa4 100644 --- a/rerun_py/src/catalog/dataset.rs +++ b/rerun_py/src/catalog/dataset.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _}; -use datafusion::common::exec_err; -use pyo3::types::{PyAnyMethods, PyCFunction, PyDict, PyString, PyTuple, PyTupleMethods}; +use datafusion::{common::exec_err, logical_expr_common::signature::Volatility}; +use pyo3::types::{PyDict, PyTuple, PyTupleMethods}; +use pyo3::Bound; use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python}; -use pyo3::{Bound, IntoPyObject}; use tokio_stream::StreamExt as _; use re_chunk_store::{ChunkStore, ChunkStoreHandle}; @@ -33,6 +33,7 @@ use crate::catalog::{ use crate::dataframe::{ PyComponentColumnSelector, PyDataFusionTable, PyIndexColumnSelector, PyRecording, }; +use crate::datafusion_utils::create_datafusion_scalar_udf; use crate::utils::wait_for_future; /// A dataset entry in the catalog. 
@@ -155,30 +156,13 @@ impl PyDataset { } }; - let py = self_.py(); - - let udf_factory = py - .import("datafusion") - .and_then(|datafusion| datafusion.getattr("udf"))?; - let pa_utf8 = py - .import("pyarrow") - .and_then(|pa| pa.getattr("utf8")?.call0())?; - - let inner = PyCFunction::new_closure(py, None, None, partition_url_inner)?; - let bound_inner = inner.into_pyobject(py)?; - let py_stable = PyString::new(py, "stable"); - - let args = PyTuple::new( - py, - vec![ - bound_inner.as_any(), - pa_utf8.as_any(), - pa_utf8.as_any(), - py_stable.as_any(), - ], - )?; - - Ok(udf_factory.call1(args)?.unbind()) + create_datafusion_scalar_udf( + self_.py(), + partition_url_inner, + &[&DataType::Utf8], + &DataType::Utf8, + Volatility::Stable, + ) } /// Register a RRD URI to the dataset. diff --git a/rerun_py/src/datafusion_utils.rs b/rerun_py/src/datafusion_utils.rs new file mode 100644 index 000000000000..d5ef338b3a62 --- /dev/null +++ b/rerun_py/src/datafusion_utils.rs @@ -0,0 +1,107 @@ +use arrow::datatypes::DataType; +use datafusion::logical_expr::Volatility; +use pyo3::exceptions::PyRuntimeError; +use pyo3::prelude::PyModule; +use pyo3::types::{PyAnyMethods, PyCFunction, PyDict, PyList, PyString, PyTuple}; +use pyo3::{Bound, IntoPyObject as _, Py, PyAny, PyResult, Python}; + +/// This is a helper function to initialize the required pyarrow data +/// types for passing into datafusion.udf() +fn data_type_to_pyarrow_obj<'py>( + pa: &Bound<'py, PyModule>, + data_type: &DataType, +) -> PyResult> { + match data_type { + DataType::Null => pa.getattr("utf8")?.call0(), + DataType::Boolean => pa.getattr("bool_")?.call0(), + DataType::Int8 => pa.getattr("int8")?.call0(), + DataType::Int16 => pa.getattr("int16")?.call0(), + DataType::Int32 => pa.getattr("int32")?.call0(), + DataType::Int64 => pa.getattr("int64")?.call0(), + DataType::UInt8 => pa.getattr("uint8")?.call0(), + DataType::UInt16 => pa.getattr("uint16")?.call0(), + DataType::UInt32 => pa.getattr("uint32")?.call0(), + DataType::UInt64 => pa.getattr("uint64")?.call0(), + DataType::Float16 => pa.getattr("float16")?.call0(), + DataType::Float32 => pa.getattr("float32")?.call0(), + DataType::Float64 => pa.getattr("float64")?.call0(), + DataType::Date32 => pa.getattr("date32")?.call0(), + DataType::Date64 => pa.getattr("date64")?.call0(), + DataType::Binary => pa.getattr("binary")?.call0(), + DataType::LargeBinary => pa.getattr("large_binary")?.call0(), + DataType::BinaryView => pa.getattr("binary_view")?.call0(), + DataType::Utf8 => pa.getattr("string")?.call0(), + DataType::LargeUtf8 => pa.getattr("large_string")?.call0(), + DataType::Utf8View => pa.getattr("string_view")?.call0(), + + DataType::FixedSizeBinary(_) + | DataType::Timestamp(_, _) + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::List(_) + | DataType::ListView(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::LargeListView(_) + | DataType::Struct(_) + | DataType::Union(_, _) + | DataType::Dictionary(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Map(_, _) + | DataType::RunEndEncoded(_, _) => { + Err(PyRuntimeError::new_err("Data type is not supported")) + } + } +} + +/// This helper function will take a closure and turn it into a datafusion scalar UDF. +/// It calls the python datafusion.udf() function. These may get removed once +/// https://github.com/apache/datafusion/issues/14562 and the associated support +/// in datafusion-python are completed. 
+pub fn create_datafusion_scalar_udf<F>(
+    py: Python<'_>,
+    closure: F,
+    arg_types: &[&DataType],
+    return_type: &DataType,
+    volatility: Volatility,
+) -> PyResult<Py<PyAny>>
+where
+    F: Fn(&Bound<'_, PyTuple>, Option<&Bound<'_, PyDict>>) -> PyResult<Py<PyAny>> + Send + 'static,
+{
+    let udf_factory = py
+        .import("datafusion")
+        .and_then(|datafusion| datafusion.getattr("udf"))?;
+    let pyarrow_module = py.import("pyarrow")?;
+    let arg_types = arg_types
+        .iter()
+        .map(|arg_type| data_type_to_pyarrow_obj(&pyarrow_module, *arg_type))
+        .collect::<PyResult<Vec<_>>>()?;
+
+    let arg_types = PyList::new(py, arg_types)?;
+    let return_type = data_type_to_pyarrow_obj(&pyarrow_module, return_type)?;
+
+    let inner = PyCFunction::new_closure(py, None, None, closure)?;
+    let bound_inner = inner.into_pyobject(py)?;
+
+    let volatility = match volatility {
+        Volatility::Immutable => "immutable",
+        Volatility::Stable => "stable",
+        Volatility::Volatile => "volatile",
+    };
+    let py_stable = PyString::new(py, volatility);
+
+    let args = PyTuple::new(
+        py,
+        vec![
+            bound_inner.as_any(),
+            arg_types.as_any(),
+            return_type.as_any(),
+            py_stable.as_any(),
+        ],
+    )?;
+
+    Ok(udf_factory.call1(args)?.unbind())
+}
diff --git a/rerun_py/src/lib.rs b/rerun_py/src/lib.rs
index 733ed2284fc6..89ee9e696edb 100644
--- a/rerun_py/src/lib.rs
+++ b/rerun_py/src/lib.rs
@@ -16,6 +16,7 @@ static GLOBAL: AccountingAllocator =
 mod arrow;
 mod catalog;
 mod dataframe;
+mod datafusion_utils;
 mod python_bridge;
 mod utils;
 mod video;

From f367af1930c58c5cd525c57234b1e914032fd915 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Wed, 9 Apr 2025 08:36:34 -0400
Subject: [PATCH 6/6] clippy warnings

---
 rerun_py/src/catalog/dataset.rs  | 18 +++++++-----------
 rerun_py/src/datafusion_utils.rs | 14 +++++++-------
 2 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/rerun_py/src/catalog/dataset.rs b/rerun_py/src/catalog/dataset.rs
index 0383f7e6baa4..45f9ac982fd0 100644
--- a/rerun_py/src/catalog/dataset.rs
+++ b/rerun_py/src/catalog/dataset.rs
@@ -4,7 +4,7 @@ use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArra
 use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
 use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _};
 use datafusion::{common::exec_err, logical_expr_common::signature::Volatility};
-use pyo3::types::{PyDict, PyTuple, PyTupleMethods};
+use pyo3::types::{PyDict, PyTuple, PyTupleMethods as _};
 use pyo3::Bound;
 use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python};
 use tokio_stream::StreamExt as _;

@@ -62,7 +62,7 @@ impl PyDataset {
         Ok(schema.into())
     }

-    /// Return the partition table as a Datafusion table provider.
+    /// Return the partition table as a DataFusion table provider.
     fn partition_table(self_: PyRef<'_, Self>) -> PyResult<PyDataFusionTable> {
         let super_ = self_.as_super();
         let connection = super_.client.borrow(self_.py()).connection().clone();
@@ -209,9 +209,7 @@ impl PyDataset {
         while let Some(chunk) = chunk_stream.next().await {
             let chunk = chunk.map_err(to_py_err)?;

-            store
-                .insert_chunk(&std::sync::Arc::new(chunk))
-                .map_err(to_py_err)?;
+            store.insert_chunk(&Arc::new(chunk)).map_err(to_py_err)?;
         }

         Ok(store)
@@ -398,7 +396,7 @@ impl PyDataset {
         let component_descriptor = ComponentDescriptor::new(column_selector.component_name.clone());
         let schema = arrow::datatypes::Schema::new_with_metadata(
-            vec![Field::new("items", arrow::datatypes::DataType::Utf8, false)],
+            vec![Field::new("items", DataType::Utf8, false)],
             Default::default(),
         );
@@ -415,11 +413,9 @@ impl PyDataset {
                 component: Some(component_descriptor.into()),
             }),
             properties: Some(IndexQueryProperties {
-                props: Some(
-                    re_protos::manifest_registry::v1alpha1::index_query_properties::Props::Inverted(
-                        InvertedIndexQuery {},
-                    ),
-                ),
+                props: Some(index_query_properties::Props::Inverted(
+                    InvertedIndexQuery {},
+                )),
             }),
             query: Some(
                 query
diff --git a/rerun_py/src/datafusion_utils.rs b/rerun_py/src/datafusion_utils.rs
index d5ef338b3a62..2ce877e7b6d8 100644
--- a/rerun_py/src/datafusion_utils.rs
+++ b/rerun_py/src/datafusion_utils.rs
@@ -2,11 +2,11 @@ use arrow::datatypes::DataType;
 use datafusion::logical_expr::Volatility;
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::PyModule;
-use pyo3::types::{PyAnyMethods, PyCFunction, PyDict, PyList, PyString, PyTuple};
+use pyo3::types::{PyAnyMethods as _, PyCFunction, PyDict, PyList, PyString, PyTuple};
 use pyo3::{Bound, IntoPyObject as _, Py, PyAny, PyResult, Python};

 /// This is a helper function to initialize the required pyarrow data
-/// types for passing into datafusion.udf()
+/// types for passing into `datafusion.udf()`
 fn data_type_to_pyarrow_obj<'py>(
     pa: &Bound<'py, PyModule>,
     data_type: &DataType,
 ) -> PyResult<Bound<'py, PyAny>> {
@@ -57,10 +57,10 @@ fn data_type_to_pyarrow_obj<'py>(
     }
 }

-/// This helper function will take a closure and turn it into a datafusion scalar UDF.
-/// It calls the python datafusion.udf() function. These may get removed once
-/// https://github.com/apache/datafusion/issues/14562 and the associated support
-/// in datafusion-python are completed.
+/// This helper function will take a closure and turn it into a `DataFusion` scalar UDF.
+/// It calls the python `datafusion.udf()` function. These may get removed once
+/// <https://github.com/apache/datafusion/issues/14562> and the associated support
+/// in `datafusion-python` are completed.
 pub fn create_datafusion_scalar_udf<F>(
     py: Python<'_>,
     closure: F,
     arg_types: &[&DataType],
     return_type: &DataType,
     volatility: Volatility,
 ) -> PyResult<Py<PyAny>>
 where
     F: Fn(&Bound<'_, PyTuple>, Option<&Bound<'_, PyDict>>) -> PyResult<Py<PyAny>> + Send + 'static,
 {
     let udf_factory = py
         .import("datafusion")
         .and_then(|datafusion| datafusion.getattr("udf"))?;
     let pyarrow_module = py.import("pyarrow")?;
     let arg_types = arg_types
         .iter()
-        .map(|arg_type| data_type_to_pyarrow_obj(&pyarrow_module, *arg_type))
+        .map(|arg_type| data_type_to_pyarrow_obj(&pyarrow_module, arg_type))
         .collect::<PyResult<Vec<_>>>()?;

     let arg_types = PyList::new(py, arg_types)?;
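With the last patches applied, `partition_url_udf` is a getter that already returns a DataFusion scalar UDF built through `create_datafusion_scalar_udf`, so the manual `datafusion.udf(...)` wrapping shown earlier is no longer needed on the Python side. A rough usage sketch follows; the column name and the way a DataFrame over the partition table is obtained are assumptions, not anything defined in these patches.

# Python usage sketch -- illustrative only, not part of the patches.
from datafusion import col

url_udf = dataset.partition_url_udf              # property returning a ready-made scalar UDF
urls = df.select(url_udf(col("partition_id")))   # partition IDs -> dataset data URLs
urls.show()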