diff --git a/rerun_py/src/catalog/dataset.rs b/rerun_py/src/catalog/dataset.rs
index 2a83134e82dd..45f9ac982fd0 100644
--- a/rerun_py/src/catalog/dataset.rs
+++ b/rerun_py/src/catalog/dataset.rs
@@ -1,8 +1,11 @@
 use std::sync::Arc;
 
-use arrow::array::{RecordBatch, StringArray};
-use arrow::datatypes::{Field, Schema as ArrowSchema};
-use arrow::pyarrow::PyArrowType;
+use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArray};
+use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
+use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _};
+use datafusion::{common::exec_err, logical_expr_common::signature::Volatility};
+use pyo3::types::{PyDict, PyTuple, PyTupleMethods as _};
+use pyo3::Bound;
 use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python};
 use tokio_stream::StreamExt as _;
 
@@ -30,6 +33,7 @@ use crate::catalog::{
 use crate::dataframe::{
     PyComponentColumnSelector, PyDataFusionTable, PyIndexColumnSelector, PyRecording,
 };
+use crate::datafusion_utils::create_datafusion_scalar_udf;
 use crate::utils::wait_for_future;
 
 /// A dataset entry in the catalog.
@@ -58,7 +62,7 @@ impl PyDataset {
         Ok(schema.into())
     }
 
-    /// Return the partition table as a Datafusion table provider.
+    /// Return the partition table as a DataFusion table provider.
     fn partition_table(self_: PyRef<'_, Self>) -> PyResult<PyDataFusionTable> {
         let super_ = self_.as_super();
         let connection = super_.client.borrow(self_.py()).connection().clone();
@@ -96,6 +100,71 @@ impl PyDataset {
             .to_string()
     }
 
+    #[getter]
+    fn partition_url_udf(self_: PyRef<'_, Self>) -> PyResult<Py<PyAny>> {
+        let super_ = self_.as_super();
+        let connection = super_.client.borrow(self_.py()).connection().clone();
+
+        let origin = connection.origin().clone();
+        let dataset_id = super_.details.id.id;
+
+        let partition_url_inner = move |args: &Bound<'_, PyTuple>,
+                                        _kwargs: Option<&Bound<'_, PyDict>>|
+              -> PyResult<Py<PyAny>> {
+            let py = args.py();
+            let partition_id_expr = args.get_borrowed_item(0)?;
+            let mut url = re_uri::DatasetDataUri {
+                origin: origin.clone(),
+                dataset_id,
+                partition_id: "default".to_owned(), // to be replaced during loop
+
+                //TODO(ab): add support for these two
+                time_range: None,
+                fragment: Default::default(),
+            };
+
+            let array_data = ArrayData::from_pyarrow_bound(partition_id_expr.as_ref())?;
+
+            match array_data.data_type() {
+                DataType::Utf8 => {
+                    let str_array = StringArray::from(array_data);
+                    let str_iter = str_array.iter().map(|maybe_id| {
+                        maybe_id.map(|id| {
+                            url.partition_id = id.to_owned();
+                            url.to_string()
+                        })
+                    });
+                    let output_array: ArrayRef = Arc::new(str_iter.collect::<StringArray>());
+                    output_array.to_data().to_pyarrow(py)
+                }
+                DataType::Utf8View => {
+                    let str_array = StringViewArray::from(array_data);
+                    let str_iter = str_array.iter().map(|maybe_id| {
+                        maybe_id.map(|id| {
+                            url.partition_id = id.to_owned();
+                            url.to_string()
+                        })
+                    });
+                    let output_array: ArrayRef = Arc::new(str_iter.collect::<StringArray>());
+                    output_array.to_data().to_pyarrow(py)
+                }
+                _ => exec_err!(
+                    "Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}",
+                    array_data.data_type()
+                )
+                .map_err(to_py_err),
+            }
+        };
+
+        create_datafusion_scalar_udf(
+            self_.py(),
+            partition_url_inner,
+            &[&DataType::Utf8],
+            &DataType::Utf8,
+            Volatility::Stable,
+        )
+    }
+
     /// Register a RRD URI to the dataset.
     fn register(self_: PyRef<'_, Self>, recording_uri: String) -> PyResult<()> {
         let super_ = self_.as_super();
@@ -140,9 +209,7 @@ impl PyDataset {
         while let Some(chunk) = chunk_stream.next().await {
             let chunk = chunk.map_err(to_py_err)?;
 
-            store
-                .insert_chunk(&std::sync::Arc::new(chunk))
-                .map_err(to_py_err)?;
+            store.insert_chunk(&Arc::new(chunk)).map_err(to_py_err)?;
         }
 
         Ok(store)
@@ -329,7 +396,7 @@ impl PyDataset {
         let component_descriptor =
             ComponentDescriptor::new(column_selector.component_name.clone());
 
         let schema = arrow::datatypes::Schema::new_with_metadata(
-            vec![Field::new("items", arrow::datatypes::DataType::Utf8, false)],
+            vec![Field::new("items", DataType::Utf8, false)],
             Default::default(),
         );
@@ -346,11 +413,9 @@ impl PyDataset {
                 component: Some(component_descriptor.into()),
             }),
             properties: Some(IndexQueryProperties {
-                props: Some(
-                    re_protos::manifest_registry::v1alpha1::index_query_properties::Props::Inverted(
-                        InvertedIndexQuery {},
-                    ),
-                ),
+                props: Some(index_query_properties::Props::Inverted(
+                    InvertedIndexQuery {},
+                )),
             }),
             query: Some(
                 query
diff --git a/rerun_py/src/datafusion_utils.rs b/rerun_py/src/datafusion_utils.rs
new file mode 100644
index 000000000000..2ce877e7b6d8
--- /dev/null
+++ b/rerun_py/src/datafusion_utils.rs
@@ -0,0 +1,107 @@
+use arrow::datatypes::DataType;
+use datafusion::logical_expr::Volatility;
+use pyo3::exceptions::PyRuntimeError;
+use pyo3::prelude::PyModule;
+use pyo3::types::{PyAnyMethods as _, PyCFunction, PyDict, PyList, PyString, PyTuple};
+use pyo3::{Bound, IntoPyObject as _, Py, PyAny, PyResult, Python};
+
+/// This is a helper function to initialize the required pyarrow data
+/// types for passing into `datafusion.udf()`.
+fn data_type_to_pyarrow_obj<'py>(
+    pa: &Bound<'py, PyModule>,
+    data_type: &DataType,
+) -> PyResult<Bound<'py, PyAny>> {
+    match data_type {
+        DataType::Null => pa.getattr("utf8")?.call0(),
+        DataType::Boolean => pa.getattr("bool_")?.call0(),
+        DataType::Int8 => pa.getattr("int8")?.call0(),
+        DataType::Int16 => pa.getattr("int16")?.call0(),
+        DataType::Int32 => pa.getattr("int32")?.call0(),
+        DataType::Int64 => pa.getattr("int64")?.call0(),
+        DataType::UInt8 => pa.getattr("uint8")?.call0(),
+        DataType::UInt16 => pa.getattr("uint16")?.call0(),
+        DataType::UInt32 => pa.getattr("uint32")?.call0(),
+        DataType::UInt64 => pa.getattr("uint64")?.call0(),
+        DataType::Float16 => pa.getattr("float16")?.call0(),
+        DataType::Float32 => pa.getattr("float32")?.call0(),
+        DataType::Float64 => pa.getattr("float64")?.call0(),
+        DataType::Date32 => pa.getattr("date32")?.call0(),
+        DataType::Date64 => pa.getattr("date64")?.call0(),
+        DataType::Binary => pa.getattr("binary")?.call0(),
+        DataType::LargeBinary => pa.getattr("large_binary")?.call0(),
+        DataType::BinaryView => pa.getattr("binary_view")?.call0(),
+        DataType::Utf8 => pa.getattr("string")?.call0(),
+        DataType::LargeUtf8 => pa.getattr("large_string")?.call0(),
+        DataType::Utf8View => pa.getattr("string_view")?.call0(),
+
+        DataType::FixedSizeBinary(_)
+        | DataType::Timestamp(_, _)
+        | DataType::Time32(_)
+        | DataType::Time64(_)
+        | DataType::Duration(_)
+        | DataType::Interval(_)
+        | DataType::List(_)
+        | DataType::ListView(_)
+        | DataType::FixedSizeList(_, _)
+        | DataType::LargeList(_)
+        | DataType::LargeListView(_)
+        | DataType::Struct(_)
+        | DataType::Union(_, _)
+        | DataType::Dictionary(_, _)
+        | DataType::Decimal128(_, _)
+        | DataType::Decimal256(_, _)
+        | DataType::Map(_, _)
+        | DataType::RunEndEncoded(_, _) => {
+            Err(PyRuntimeError::new_err("Data type is not supported"))
+        }
+    }
+}
+
+/// This helper function will take a closure and turn it into a `DataFusion` scalar UDF.
+/// It calls the Python `datafusion.udf()` function. These helpers may be removed once
+/// the associated support in `datafusion-python` is completed.
+pub fn create_datafusion_scalar_udf<F>(
+    py: Python<'_>,
+    closure: F,
+    arg_types: &[&DataType],
+    return_type: &DataType,
+    volatility: Volatility,
+) -> PyResult<Py<PyAny>>
+where
+    F: Fn(&Bound<'_, PyTuple>, Option<&Bound<'_, PyDict>>) -> PyResult<Py<PyAny>> + Send + 'static,
+{
+    let udf_factory = py
+        .import("datafusion")
+        .and_then(|datafusion| datafusion.getattr("udf"))?;
+    let pyarrow_module = py.import("pyarrow")?;
+    let arg_types = arg_types
+        .iter()
+        .map(|arg_type| data_type_to_pyarrow_obj(&pyarrow_module, arg_type))
+        .collect::<PyResult<Vec<_>>>()?;
+
+    let arg_types = PyList::new(py, arg_types)?;
+    let return_type = data_type_to_pyarrow_obj(&pyarrow_module, return_type)?;
+
+    let inner = PyCFunction::new_closure(py, None, None, closure)?;
+    let bound_inner = inner.into_pyobject(py)?;
+
+    let volatility = match volatility {
+        Volatility::Immutable => "immutable",
+        Volatility::Stable => "stable",
+        Volatility::Volatile => "volatile",
+    };
+    let py_stable = PyString::new(py, volatility);
+
+    let args = PyTuple::new(
+        py,
+        vec![
+            bound_inner.as_any(),
+            arg_types.as_any(),
+            return_type.as_any(),
+            py_stable.as_any(),
+        ],
+    )?;
+
+    Ok(udf_factory.call1(args)?.unbind())
+}
diff --git a/rerun_py/src/lib.rs b/rerun_py/src/lib.rs
index 733ed2284fc6..89ee9e696edb 100644
--- a/rerun_py/src/lib.rs
+++ b/rerun_py/src/lib.rs
@@ -16,6 +16,7 @@ static GLOBAL: AccountingAllocator<mimalloc::MiMalloc> =
 mod arrow;
 mod catalog;
 mod dataframe;
+mod datafusion_utils;
 mod python_bridge;
 mod utils;
 mod video;
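
Reviewer note, not part of the diff: the shape of the closure expected by `create_datafusion_scalar_udf` is easiest to see with a second example. The sketch below builds a hypothetical upper-casing UDF with the same helper; `make_upper_udf`, its closure body, and the choice of `Volatility::Immutable` are illustrative assumptions, not code from this PR.

    // Hypothetical usage sketch for create_datafusion_scalar_udf (not from this PR).
    use std::sync::Arc;

    use arrow::array::{ArrayData, ArrayRef, StringArray};
    use arrow::datatypes::DataType;
    use arrow::pyarrow::{FromPyArrow as _, ToPyArrow as _};
    use datafusion::logical_expr::Volatility;
    use pyo3::types::{PyDict, PyTuple, PyTupleMethods as _};
    use pyo3::{Bound, Py, PyAny, PyResult, Python};

    use crate::datafusion_utils::create_datafusion_scalar_udf;

    fn make_upper_udf(py: Python<'_>) -> PyResult<Py<PyAny>> {
        // DataFusion invokes the Python-side UDF once per batch; each positional
        // argument arrives in `args` as a pyarrow array.
        let upper_inner = move |args: &Bound<'_, PyTuple>,
                                _kwargs: Option<&Bound<'_, PyDict>>|
              -> PyResult<Py<PyAny>> {
            let py = args.py();
            let input = args.get_borrowed_item(0)?;
            let array_data = ArrayData::from_pyarrow_bound(input.as_ref())?;
            // Assumes the input really is utf8; production code should match on
            // the data type as partition_url_udf does, since utf8view may arrive.
            let strings = StringArray::from(array_data);

            // Upper-case every non-null element, preserving nulls.
            let output: ArrayRef = Arc::new(
                strings
                    .iter()
                    .map(|maybe_str| maybe_str.map(str::to_uppercase))
                    .collect::<StringArray>(),
            );
            output.to_data().to_pyarrow(py)
        };

        create_datafusion_scalar_udf(
            py,
            upper_inner,
            &[&DataType::Utf8],
            &DataType::Utf8,
            Volatility::Immutable,
        )
    }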
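
Design note: routing UDF construction through the Python-level `datafusion.udf()` factory, with the Rust closure wrapped in a `PyCFunction`, means `rerun_py` never links against `datafusion-python` itself; only pyarrow arrays cross the boundary. The trade-off is a Python round-trip per record batch, which is presumably why the doc comment flags these helpers as removable once native support lands in `datafusion-python`.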