diff --git a/python/delta-kernel-rust-sharing-wrapper/Cargo.toml b/python/delta-kernel-rust-sharing-wrapper/Cargo.toml
index 4049de564..0232898b9 100644
--- a/python/delta-kernel-rust-sharing-wrapper/Cargo.toml
+++ b/python/delta-kernel-rust-sharing-wrapper/Cargo.toml
@@ -10,12 +10,15 @@ name = "delta_kernel_rust_sharing_wrapper"
 crate-type = ["cdylib"]
 
 [dependencies]
-arrow = { version = "54.0.0", features = ["pyarrow"] }
+arrow = { version = "54.0.0", features = ["pyarrow", "ffi"] }
 delta_kernel = { version = "0.6.1", features = ["cloud", "default-engine"]}
 openssl = { version = "0.10", features = ["vendored"] }
+polars = { version = "0.46.0", features = ["lazy"] }
+polars-arrow = "0.46.0"
+pyo3-polars = { version = "0.20.0", features = ["dtype-decimal", "lazy"] }
 url = "2"
 
 [dependencies.pyo3]
 version = "0.23.3"
 # "abi3-py38" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.8
-features = ["abi3-py38"]
+features = ["abi3-py38", "rust_decimal"]
diff --git a/python/delta-kernel-rust-sharing-wrapper/src/lib.rs b/python/delta-kernel-rust-sharing-wrapper/src/lib.rs
index cfa6c6783..bf0b1ab0a 100644
--- a/python/delta-kernel-rust-sharing-wrapper/src/lib.rs
+++ b/python/delta-kernel-rust-sharing-wrapper/src/lib.rs
@@ -1,8 +1,10 @@
+use std::mem::transmute;
 use std::sync::Arc;
 
 use arrow::compute::filter_record_batch;
 use arrow::datatypes::SchemaRef as ArrowSchemaRef;
 use arrow::error::ArrowError;
+use arrow::ffi::to_ffi;
 use arrow::pyarrow::PyArrowType;
 use arrow::record_batch::{RecordBatch, RecordBatchIterator, RecordBatchReader};
@@ -17,6 +19,13 @@ use delta_kernel::Error as KernelError;
 use delta_kernel::{engine::arrow_data::ArrowEngineData, schema::StructType};
 use delta_kernel::{DeltaResult, Engine};
 
+use polars::error::PolarsError;
+use polars::prelude::{concat, DataFrame, IntoLazy, Series, UnionArgs};
+
+use polars_arrow::ffi::{import_array_from_c, import_field_from_c};
+
+use pyo3_polars::{PyDataFrame, PyLazyFrame};
+
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 
@@ -24,21 +33,35 @@
 use url::Url;
 
 use std::collections::HashMap;
 
-struct PyKernelError(KernelError);
+enum PyRustError {
+    PyKernelError(KernelError),
+    PyPolarsError(PolarsError),
+}
+
-impl From<PyKernelError> for PyErr {
-    fn from(error: PyKernelError) -> Self {
-        PyValueError::new_err(format!("Kernel error: {}", error.0))
+impl From<PyRustError> for PyErr {
+    fn from(error: PyRustError) -> Self {
+        let msg = match error {
+            PyRustError::PyKernelError(e) => format!("Kernel error: {}", e),
+            PyRustError::PyPolarsError(e) => format!("Polars error: {}", e),
+        };
+        PyValueError::new_err(msg)
     }
 }
 
-impl From<KernelError> for PyKernelError {
+impl From<KernelError> for PyRustError {
     fn from(delta_kernel_error: KernelError) -> Self {
-        Self(delta_kernel_error)
+        Self::PyKernelError(delta_kernel_error)
+    }
+}
+
+impl From<PolarsError> for PyRustError {
+    fn from(polars_error: PolarsError) -> Self {
+        Self::PyPolarsError(polars_error)
     }
 }
 
-type DeltaPyResult<T> = std::result::Result<T, PyKernelError>;
+type DeltaPyResult<T> = std::result::Result<T, PyRustError>;
 
 #[pyclass]
 struct Table(delta_kernel::Table);
@@ -117,6 +140,44 @@ fn try_create_record_batch_iter(
     RecordBatchIterator::new(record_batches, result_schema)
 }
 
+unsafe fn record_batch_to_dataframe(batch: &RecordBatch) -> Result<DataFrame, PolarsError> {
+    let mut columns = Vec::with_capacity(batch.num_columns());
+
+    // Arrow stores data by column, so each column is moved zero-copy
+    for (i, col) in batch.columns().iter().enumerate() {
+        // Convert to ArrayData (arrow-rs)
+        let array = col.to_data();
+
+        // Convert to ffi with arrow-rs
+        let (out_array, out_schema) = to_ffi(&array).unwrap();
+
+        // Import field from ffi with polars
+        let field = unsafe {
+            import_field_from_c(transmute::<
+                &arrow::ffi::FFI_ArrowSchema,
+                &polars_arrow::ffi::ArrowSchema,
+            >(&out_schema))
+        }?;
+
+        // Import data from ffi with polars
+        let data = unsafe {
+            import_array_from_c(
+                transmute::<arrow::ffi::FFI_ArrowArray, polars_arrow::ffi::ArrowArray>(
+                    out_array,
+                ),
+                field.dtype().clone(),
+            )
+        }?;
+
+        // Create Polars series from arrow column
+        columns.push(Series::from_arrow(
+            batch.schema().field(i).name().into(),
+            data,
+        )?);
+    }
+    Ok(DataFrame::from_iter(columns))
+}
+
 #[pyclass]
 struct Scan(delta_kernel::scan::Scan);
 
@@ -131,6 +192,24 @@ impl Scan {
         let record_batch_iter = try_create_record_batch_iter(results, result_schema);
         Ok(PyArrowType(Box::new(record_batch_iter)))
     }
+
+    fn execute_polars(
+        &self,
+        engine_interface: &PythonInterface,
+    ) -> DeltaPyResult<PyDataFrame> {
+        let result_schema: ArrowSchemaRef = try_get_schema(self.0.schema())?;
+        let results = self.0.execute(engine_interface.0.clone())?;
+        let record_batch_iter = try_create_record_batch_iter(results, result_schema);
+        let mut dfs = Vec::new();
+        for rb in record_batch_iter {
+            unsafe {
+                let df = record_batch_to_dataframe(&rb.map_err(KernelError::Arrow)?)?;
+                dfs.push(df.lazy())
+            };
+        }
+        let dfs_concat = concat(dfs, UnionArgs::default());
+        Ok(PyDataFrame(dfs_concat?.collect()?))
+    }
 }
 
 #[pyclass]
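Note: record_batch_to_dataframe above hands each Arrow column to Polars through the Arrow C data interface instead of serializing it. As a rough Python-side analogue (not part of this change), the same zero-copy hand-off happens when Polars ingests a pyarrow RecordBatch:

    import polars as pl
    import pyarrow as pa

    # Build a small Arrow batch; pl.from_arrow crosses the Arrow C data
    # interface, so column buffers are shared rather than copied for most types.
    batch = pa.RecordBatch.from_pydict({"id": [1, 2], "name": ["a", "b"]})
    df = pl.from_arrow(batch)
    print(df)
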
diff --git a/python/delta_sharing/__init__.py b/python/delta_sharing/__init__.py
index 2f4011076..a53ded709 100644
--- a/python/delta_sharing/__init__.py
+++ b/python/delta_sharing/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from delta_sharing.delta_sharing import SharingClient, load_as_pandas, load_as_spark
+from delta_sharing.delta_sharing import SharingClient, load_as_pandas, load_as_polars, load_as_spark
 from delta_sharing.delta_sharing import get_table_metadata, get_table_protocol, get_table_version
 from delta_sharing.delta_sharing import load_table_changes_as_pandas, load_table_changes_as_spark
 from delta_sharing.protocol import Share, Schema, Table
@@ -30,6 +30,7 @@
     "get_table_protocol",
     "get_table_version",
     "load_as_pandas",
+    "load_as_polars",
     "load_as_spark",
     "load_table_changes_as_pandas",
     "load_table_changes_as_spark",
diff --git a/python/delta_sharing/delta_sharing.py b/python/delta_sharing/delta_sharing.py
index 966fa8fe5..5a00d9fcb 100644
--- a/python/delta_sharing/delta_sharing.py
+++ b/python/delta_sharing/delta_sharing.py
@@ -18,6 +18,7 @@
 from pathlib import Path
 
 import pandas as pd
+import polars as pl
 
 from delta_sharing.protocol import CdfOptions, Protocol, Metadata
@@ -147,6 +148,28 @@ def load_as_pandas(
     ).to_pandas()
 
 
+def load_as_polars(
+    url: str,
+    limit: Optional[int] = None,
+    version: Optional[int] = None,
+    timestamp: Optional[str] = None,
+    jsonPredicateHints: Optional[str] = None,
+    use_delta_format: Optional[bool] = None,
+    convert_in_batches: bool = False,
+) -> pl.DataFrame:
+    profile_json, share, schema, table = _parse_url(url)
+    profile = DeltaSharingProfile.read_from_file(profile_json)
+    return DeltaSharingReader(
+        table=Table(name=table, share=share, schema=schema),
+        rest_client=DataSharingRestClient(profile),
+        jsonPredicateHints=jsonPredicateHints,
+        limit=limit,
+        version=version,
+        timestamp=timestamp,
+        use_delta_format=use_delta_format,
+    ).to_polars()
+
+
 def load_as_spark(
     url: str, version: Optional[int] = None, timestamp: Optional[str] = None
 ) -> "PySparkDataFrame":  # noqa: F821
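For orientation, a minimal usage sketch of the new entry point; the profile file and share/schema/table coordinates below are hypothetical:

    import delta_sharing

    # Hypothetical profile file and table coordinates
    table_url = "config.share#my_share.my_schema.my_table"
    df = delta_sharing.load_as_polars(table_url, limit=10)
    print(df.head())
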
diff --git a/python/delta_sharing/reader.py b/python/delta_sharing/reader.py
index d021ddeb2..878e3ecb8 100644
--- a/python/delta_sharing/reader.py
+++ b/python/delta_sharing/reader.py
@@ -23,6 +23,7 @@
 import fsspec
 import os
 import pandas as pd
+import polars as pl
 import pyarrow as pa
 import tempfile
 from pyarrow.dataset import dataset
@@ -150,6 +151,46 @@ def __to_pandas_kernel(self):
         return result
 
+    def __to_polars_kernel(self):
+        self._rest_client.set_delta_format_header()
+        response = self._rest_client.list_files_in_table(
+            self._table,
+            predicateHints=self._predicateHints,
+            jsonPredicateHints=self._jsonPredicateHints,
+            limitHint=self._limit,
+            version=self._version,
+            timestamp=self._timestamp,
+        )
+
+        lines = response.lines
+        # Create a temporary directory using the tempfile module
+        temp_dir = tempfile.TemporaryDirectory()
+        table_path = self.__write_temp_delta_log_snapshot(temp_dir.name, lines)
+        num_files = len(lines)
+
+        # Invoke delta-kernel-rust to return the polars dataframe
+        interface = delta_kernel_rust_sharing_wrapper.PythonInterface(table_path)
+        table = delta_kernel_rust_sharing_wrapper.Table(table_path)
+        snapshot = table.snapshot(interface)
+        scan = delta_kernel_rust_sharing_wrapper.ScanBuilder(snapshot).build()
+
+        # The table is empty, so use the schema to return an empty table with correct col names
+        if num_files == 0:
+            schema = scan.execute(interface).schema
+            return pl.DataFrame(schema=schema.names)
+
+        result = scan.execute_polars(interface)
+
+        # Apply any residual limit that was not handled by server-side pushdown
+        if self._limit:
+            result = result.head(self._limit)
+
+        # Delete the temp folder explicitly and remove the delta format from the header
+        temp_dir.cleanup()
+        self._rest_client.remove_delta_format_header()
+
+        return result
+
     def to_pandas(self) -> pd.DataFrame:
         response_format = ""
         # If client does not specify which format to use, autoresolve it.
@@ -215,6 +256,52 @@ def to_pandas(self) -> pd.DataFrame:
         return merged[[col_map[field["name"].lower()] for field in schema_json["fields"]]]
 
+    def to_polars(self) -> pl.DataFrame:
+        response_format = ""
+        # If the client does not specify which format to use, autoresolve it.
+        # Otherwise use the specified format.
+        if self._use_delta_format is None:
+            response_format = self._rest_client.autoresolve_query_format(self._table)
+        elif self._use_delta_format:
+            response_format = DataSharingRestClient.DELTA_FORMAT
+
+        # If the response format is delta, use delta-kernel-rust
+        if response_format == DataSharingRestClient.DELTA_FORMAT:
+            return self.__to_polars_kernel()
+
+        # Otherwise use the standard approach
+        response = self._rest_client.list_files_in_table(
+            self._table,
+            predicateHints=self._predicateHints,
+            jsonPredicateHints=self._jsonPredicateHints,
+            limitHint=self._limit,
+            version=self._version,
+            timestamp=self._timestamp,
+        )
+
+        schema_json = loads(response.metadata.schema_string)
+
+        if len(response.add_files) == 0 or self._limit == 0:
+            return pl.from_pandas(get_empty_table(schema_json))
+
+        converters = to_converters(schema_json)
+
+        pdfs = [
+            DeltaSharingReader._to_polars(file, converters, False)
+            for file in response.add_files
+        ]
+
+        merged = pl.concat(pdfs, how='diagonal_relaxed')
+
+        if self._limit:
+            merged = merged.head(self._limit)
+
+        col_map = {}
+        for col in merged.collect_schema().names():
+            col_map[col.lower()] = col
+
+        return merged.select(
+            [col_map[field["name"].lower()] for field in schema_json["fields"]]
+        ).collect()
+
     def __write_temp_delta_log_snapshot(self, temp_dir: str, lines: List[str]) -> str:
         delta_log_dir_name = temp_dir
         table_path = "file:///" + delta_log_dir_name
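Note: to_polars merges the per-file LazyFrames with how='diagonal_relaxed', which unions the column sets, fills missing values with nulls, and relaxes mismatched but compatible dtypes. A self-contained sketch of that behavior:

    import polars as pl

    lf1 = pl.LazyFrame({"id": [1], "val": [1.0]})
    lf2 = pl.LazyFrame({"id": [2], "extra": ["x"]})
    # Union of columns; rows missing a column get null in that column.
    merged = pl.concat([lf1, lf2], how="diagonal_relaxed")
    print(merged.collect())
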
@@ -509,6 +596,52 @@ def _to_pandas(
             pdf[DeltaSharingReader._commit_timestamp_col_name()] = action.timestamp
         return pdf
 
+    @staticmethod
+    def _to_polars(
+        action: FileAction,
+        converters: Dict[str, Callable[[str], Any]],
+        for_cdf: bool,
+    ) -> pl.LazyFrame:
+        url = urlparse(action.url)
+        if "storage.googleapis.com" in (url.netloc.lower()):
+            # Apply the yarl patch for GCS pre-signed urls
+            import delta_sharing._yarl_patch  # noqa: F401
+
+        pdf = pl.scan_parquet(source=action.url)
+
+        lowered_cols = set()
+        for col in pdf.collect_schema().names():
+            lowered_cols.add(col.lower())
+
+        for col, converter in converters.items():
+            lowered = col.lower()
+            if lowered not in lowered_cols:
+                if col in action.partition_values:
+                    if converter is not None:
+                        pdf = pdf.with_columns(converter(action.partition_values[col]))
+                    else:
+                        raise ValueError("Cannot partition on binary or complex columns")
+                else:
+                    pdf = pdf.with_columns(pl.lit(None).alias(col))
+
+        if for_cdf:
+            columns = []
+            # Add the change type col name to non-cdc actions.
+            if not isinstance(action, AddCdcFile):
+                columns.append(
+                    pl.lit(action.get_change_type_col_value()).alias(
+                        DeltaSharingReader._change_type_col_name()
+                    )
+                )
+
+            # If available, add timestamp and version columns from the action.
+            # All rows of the dataframe will get the same value.
+            if action.version is not None:
+                assert DeltaSharingReader._commit_version_col_name() not in pdf.collect_schema().names()
+                columns.append(
+                    pl.lit(action.version).alias(DeltaSharingReader._commit_version_col_name())
+                )
+
+            if action.timestamp is not None:
+                assert DeltaSharingReader._commit_timestamp_col_name() not in pdf.collect_schema().names()
+                columns.append(
+                    pl.lit(action.timestamp).alias(DeltaSharingReader._commit_timestamp_col_name())
+                )
+
+            pdf = pdf.with_columns(columns)
+        return pdf
 
     # The names of special delta columns for cdf.
     @staticmethod
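Note: _to_polars attaches partition values and cdf metadata as constant expressions on the lazy scan, so every row of a file receives the same value without materializing the data. A sketch of that pattern (the column names here are illustrative):

    import polars as pl

    lf = pl.LazyFrame({"value": [10, 20]})
    # A partition column absent from the parquet file is added as a literal;
    # all rows get the same value, mirroring _to_polars above.
    lf = lf.with_columns(pl.lit("2024-01-01").alias("date"))
    print(lf.collect())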