
DataSet partition ID to URL user-defined function #9530


Draft
wants to merge 6 commits into base: main
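
This draft adds a `partition_url_udf` getter to the dataset class in `rerun_py`: a DataFusion scalar UDF that maps partition IDs (Utf8 or Utf8View arrays) to fully qualified dataset data URLs, with time range and fragment support still marked TODO. A new `datafusion_utils` module provides the plumbing that turns a Rust closure into a Python `datafusion.udf()` object.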
91 changes: 78 additions & 13 deletions rerun_py/src/catalog/dataset.rs
@@ -1,8 +1,11 @@
use std::sync::Arc;

-use arrow::array::{RecordBatch, StringArray};
-use arrow::datatypes::{Field, Schema as ArrowSchema};
-use arrow::pyarrow::PyArrowType;
+use arrow::array::{ArrayData, ArrayRef, RecordBatch, StringArray, StringViewArray};
+use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
+use arrow::pyarrow::{FromPyArrow as _, PyArrowType, ToPyArrow as _};
+use datafusion::{common::exec_err, logical_expr_common::signature::Volatility};
+use pyo3::types::{PyDict, PyTuple, PyTupleMethods as _};
+use pyo3::Bound;
use pyo3::{exceptions::PyRuntimeError, pyclass, pymethods, Py, PyAny, PyRef, PyResult, Python};
use tokio_stream::StreamExt as _;

@@ -30,6 +33,7 @@ use crate::catalog::{
use crate::dataframe::{
PyComponentColumnSelector, PyDataFusionTable, PyIndexColumnSelector, PyRecording,
};
use crate::datafusion_utils::create_datafusion_scalar_udf;
use crate::utils::wait_for_future;

/// A dataset entry in the catalog.
@@ -58,7 +62,7 @@ impl PyDataset {
Ok(schema.into())
}

-/// Return the partition table as a Datafusion table provider.
+/// Return the partition table as a DataFusion table provider.
fn partition_table(self_: PyRef<'_, Self>) -> PyResult<PyDataFusionTable> {
let super_ = self_.as_super();
let connection = super_.client.borrow(self_.py()).connection().clone();
@@ -96,6 +100,71 @@ impl PyDataset {
.to_string()
}

#[getter]
fn partition_url_udf(self_: PyRef<'_, Self>) -> PyResult<Py<PyAny>> {
let super_ = self_.as_super();
let connection = super_.client.borrow(self_.py()).connection().clone();

let origin = connection.origin().clone();
let dataset_id = super_.details.id.id;

let partition_url_inner = move |args: &Bound<'_, PyTuple>,
_kwargs: Option<&Bound<'_, PyDict>>|
-> PyResult<Py<PyAny>> {
let py = args.py();
let partition_id_expr = args.get_borrowed_item(0)?;
let mut url = re_uri::DatasetDataUri {
origin: origin.clone(),
dataset_id,
partition_id: "default".to_owned(), // to be replaced during loop

//TODO(ab): add support for these two
time_range: None,
fragment: Default::default(),
};

let array_data = ArrayData::from_pyarrow_bound(partition_id_expr.as_ref())?;

match array_data.data_type() {
DataType::Utf8 => {
let str_array = StringArray::from(array_data);
let str_iter = str_array.iter().map(|maybe_id| {
maybe_id.map(|id| {
url.partition_id = id.to_owned();
url.to_string()
})
});
let output_array: ArrayRef = Arc::new(str_iter.collect::<StringArray>());
output_array.to_data().to_pyarrow(py)
}
DataType::Utf8View => {
let str_array = StringViewArray::from(array_data);
let str_iter = str_array.iter().map(|maybe_id| {
maybe_id.map(|id| {
url.partition_id = id.to_owned();
url.to_string()
})
});
let output_array: ArrayRef = Arc::new(str_iter.collect::<StringViewArray>());
output_array.to_data().to_pyarrow(py)
}
_ => exec_err!(
"Incorrect data type for partition_url_udf. Expected utf8 or utf8view. Received {}",
array_data.data_type()
)
.map_err(to_py_err),
}
};

create_datafusion_scalar_udf(
self_.py(),
partition_url_inner,
&[&DataType::Utf8],
&DataType::Utf8,
Volatility::Stable,
)
}

/// Register a RRD URI to the dataset.
fn register(self_: PyRef<'_, Self>, recording_uri: String) -> PyResult<()> {
let super_ = self_.as_super();
@@ -140,9 +209,7 @@ impl PyDataset {

while let Some(chunk) = chunk_stream.next().await {
let chunk = chunk.map_err(to_py_err)?;
-store
-.insert_chunk(&std::sync::Arc::new(chunk))
-.map_err(to_py_err)?;
+store.insert_chunk(&Arc::new(chunk)).map_err(to_py_err)?;
}

Ok(store)
@@ -329,7 +396,7 @@ impl PyDataset {
let component_descriptor = ComponentDescriptor::new(column_selector.component_name.clone());

let schema = arrow::datatypes::Schema::new_with_metadata(
vec![Field::new("items", arrow::datatypes::DataType::Utf8, false)],
vec![Field::new("items", DataType::Utf8, false)],
Default::default(),
);

@@ -346,11 +413,9 @@
component: Some(component_descriptor.into()),
}),
properties: Some(IndexQueryProperties {
-props: Some(
-re_protos::manifest_registry::v1alpha1::index_query_properties::Props::Inverted(
-InvertedIndexQuery {},
-),
-),
+props: Some(index_query_properties::Props::Inverted(
+InvertedIndexQuery {},
+)),
}),
query: Some(
query
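
The heart of the new `partition_url_udf` getter is the per-row mapping from partition ID to URL string. Below is a minimal standalone sketch of that transformation, with plain string formatting standing in for `re_uri::DatasetDataUri` (the URL shape shown is illustrative, not the exact output); nulls are propagated via `Option::map` exactly as in the closure above.

use std::sync::Arc;

use arrow::array::{Array as _, ArrayRef, StringArray};

/// Map each partition ID to a URL string; null IDs yield null URLs.
fn partition_ids_to_urls(ids: &StringArray, base: &str) -> ArrayRef {
    let urls = ids
        .iter() // yields Option<&str>, one entry per row
        .map(|maybe_id| maybe_id.map(|id| format!("{base}/partition/{id}")));
    Arc::new(urls.collect::<StringArray>())
}

fn main() {
    let ids = StringArray::from(vec![Some("pid_0"), None, Some("pid_1")]);
    let urls = partition_ids_to_urls(&ids, "rerun://example.org/dataset/42");
    assert_eq!(urls.len(), 3);
    assert!(urls.is_null(1)); // the missing partition ID stays null
}

The same pattern repeats for `Utf8View` input using `StringViewArray`, which is why the closure matches on the array's `DataType` before iterating.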
107 changes: 107 additions & 0 deletions rerun_py/src/datafusion_utils.rs
@@ -0,0 +1,107 @@
use arrow::datatypes::DataType;
use datafusion::logical_expr::Volatility;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::PyModule;
use pyo3::types::{PyAnyMethods as _, PyCFunction, PyDict, PyList, PyString, PyTuple};
use pyo3::{Bound, IntoPyObject as _, Py, PyAny, PyResult, Python};

/// Helper function that builds the pyarrow data type objects required
/// when calling `datafusion.udf()`.
fn data_type_to_pyarrow_obj<'py>(
pa: &Bound<'py, PyModule>,
data_type: &DataType,
) -> PyResult<Bound<'py, PyAny>> {
match data_type {
DataType::Null => pa.getattr("utf8")?.call0(),
DataType::Boolean => pa.getattr("bool_")?.call0(),
DataType::Int8 => pa.getattr("int8")?.call0(),
DataType::Int16 => pa.getattr("int16")?.call0(),
DataType::Int32 => pa.getattr("int32")?.call0(),
DataType::Int64 => pa.getattr("int64")?.call0(),
DataType::UInt8 => pa.getattr("uint8")?.call0(),
DataType::UInt16 => pa.getattr("uint16")?.call0(),
DataType::UInt32 => pa.getattr("uint32")?.call0(),
DataType::UInt64 => pa.getattr("uint64")?.call0(),
DataType::Float16 => pa.getattr("float16")?.call0(),
DataType::Float32 => pa.getattr("float32")?.call0(),
DataType::Float64 => pa.getattr("float64")?.call0(),
DataType::Date32 => pa.getattr("date32")?.call0(),
DataType::Date64 => pa.getattr("date64")?.call0(),
DataType::Binary => pa.getattr("binary")?.call0(),
DataType::LargeBinary => pa.getattr("large_binary")?.call0(),
DataType::BinaryView => pa.getattr("binary_view")?.call0(),
DataType::Utf8 => pa.getattr("string")?.call0(),
DataType::LargeUtf8 => pa.getattr("large_string")?.call0(),
DataType::Utf8View => pa.getattr("string_view")?.call0(),

DataType::FixedSizeBinary(_)
| DataType::Timestamp(_, _)
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Duration(_)
| DataType::Interval(_)
| DataType::List(_)
| DataType::ListView(_)
| DataType::FixedSizeList(_, _)
| DataType::LargeList(_)
| DataType::LargeListView(_)
| DataType::Struct(_)
| DataType::Union(_, _)
| DataType::Dictionary(_, _)
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
| DataType::Map(_, _)
| DataType::RunEndEncoded(_, _) => {
Err(PyRuntimeError::new_err("Data type is not supported"))
}
}
}

/// This helper function takes a closure and turns it into a `DataFusion` scalar UDF
/// by calling the Python `datafusion.udf()` function. It may be removed once
/// <https://github.com/apache/datafusion/issues/14562> and the associated support
/// in `datafusion-python` are completed.
pub fn create_datafusion_scalar_udf<F>(
py: Python<'_>,
closure: F,
arg_types: &[&DataType],
return_type: &DataType,
volatility: Volatility,
) -> PyResult<Py<PyAny>>
where
F: Fn(&Bound<'_, PyTuple>, Option<&Bound<'_, PyDict>>) -> PyResult<Py<PyAny>> + Send + 'static,
{
let udf_factory = py
.import("datafusion")
.and_then(|datafusion| datafusion.getattr("udf"))?;
let pyarrow_module = py.import("pyarrow")?;
let arg_types = arg_types
.iter()
.map(|arg_type| data_type_to_pyarrow_obj(&pyarrow_module, arg_type))
.collect::<PyResult<Vec<_>>>()?;

let arg_types = PyList::new(py, arg_types)?;
let return_type = data_type_to_pyarrow_obj(&pyarrow_module, return_type)?;

let inner = PyCFunction::new_closure(py, None, None, closure)?;
let bound_inner = inner.into_pyobject(py)?;

let volatility = match volatility {
Volatility::Immutable => "immutable",
Volatility::Stable => "stable",
Volatility::Volatile => "volatile",
};
let py_stable = PyString::new(py, volatility);

let args = PyTuple::new(
py,
vec![
bound_inner.as_any(),
arg_types.as_any(),
return_type.as_any(),
py_stable.as_any(),
],
)?;

Ok(udf_factory.call1(args)?.unbind())
}
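
A hedged usage sketch for this helper: `identity_udf` below is hypothetical (not part of this PR), but it leans only on calls already used in this diff (`PyTupleMethods::get_borrowed_item`, `create_datafusion_scalar_udf`). On the Python side, the returned object is what `datafusion.udf(inner, [pa.string()], pa.string(), "immutable")` would produce.

use arrow::datatypes::DataType;
use datafusion::logical_expr::Volatility;
use pyo3::types::{PyDict, PyTuple, PyTupleMethods as _};
use pyo3::{Bound, Py, PyAny, PyResult, Python};

use crate::datafusion_utils::create_datafusion_scalar_udf;

/// Hypothetical example: expose a pass-through string UDF to DataFusion.
fn identity_udf(py: Python<'_>) -> PyResult<Py<PyAny>> {
    let inner = |args: &Bound<'_, PyTuple>,
                 _kwargs: Option<&Bound<'_, PyDict>>|
     -> PyResult<Py<PyAny>> {
        // Hand the first pyarrow array argument back unchanged.
        Ok(args.get_borrowed_item(0)?.to_owned().unbind())
    };

    create_datafusion_scalar_udf(
        py,
        inner,
        &[&DataType::Utf8],
        &DataType::Utf8,
        Volatility::Immutable,
    )
}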
1 change: 1 addition & 0 deletions rerun_py/src/lib.rs
@@ -16,6 +16,7 @@ static GLOBAL: AccountingAllocator<mimalloc::MiMalloc> =
mod arrow;
mod catalog;
mod dataframe;
mod datafusion_utils;
mod python_bridge;
mod utils;
mod video;