7 changes: 5 additions & 2 deletions python/delta-kernel-rust-sharing-wrapper/Cargo.toml
@@ -10,12 +10,15 @@ name = "delta_kernel_rust_sharing_wrapper"
crate-type = ["cdylib"]

[dependencies]
-arrow = { version = "54.0.0", features = ["pyarrow"] }
+arrow = { version = "54.0.0", features = ["pyarrow", "ffi"] }
delta_kernel = { version = "0.6.1", features = ["cloud", "default-engine"]}
openssl = { version = "0.10", features = ["vendored"] }
+polars = { version = "0.46.0", features = ["lazy"] }
+polars-arrow = "0.46.0"
+pyo3-polars = { version = "0.20.0", features = ["dtype-decimal", "lazy"] }
url = "2"

[dependencies.pyo3]
version = "0.23.3"
# "abi3-py38" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.8
features = ["abi3-py38"]
features = ["abi3-py38", "rust_decimal"]
93 changes: 86 additions & 7 deletions python/delta-kernel-rust-sharing-wrapper/src/lib.rs
@@ -1,8 +1,10 @@
use std::mem::transmute;
use std::sync::Arc;

use arrow::compute::filter_record_batch;
use arrow::datatypes::SchemaRef as ArrowSchemaRef;
use arrow::error::ArrowError;
use arrow::ffi::to_ffi;
use arrow::pyarrow::PyArrowType;
use arrow::record_batch::{RecordBatch, RecordBatchIterator, RecordBatchReader};

@@ -17,28 +19,49 @@ use delta_kernel::Error as KernelError;
use delta_kernel::{engine::arrow_data::ArrowEngineData, schema::StructType};
use delta_kernel::{DeltaResult, Engine};

use polars::error::PolarsError;
use polars::prelude::{concat, DataFrame, IntoLazy, Series, UnionArgs};

use polars_arrow::ffi::{import_array_from_c, import_field_from_c};

use pyo3_polars::{PyDataFrame, PyLazyFrame};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use url::Url;

use std::collections::HashMap;

-struct PyKernelError(KernelError);
+enum PyRustError {
+    PyKernelError(KernelError),
+    PyPolarsError(PolarsError),
+}

-impl From<PyKernelError> for PyErr {
-    fn from(error: PyKernelError) -> Self {
-        PyValueError::new_err(format!("Kernel error: {}", error.0))
+impl From<PyRustError> for PyErr {
+    fn from(error: PyRustError) -> Self {
+        let msg = match error {
+            PyRustError::PyKernelError(e) => format!("Kernel error: {}", e),
+            PyRustError::PyPolarsError(e) => format!("Polars error: {}", e),
+        };
+        PyValueError::new_err(msg)
    }
}

-impl From<KernelError> for PyKernelError {
+impl From<KernelError> for PyRustError {
    fn from(delta_kernel_error: KernelError) -> Self {
-        Self(delta_kernel_error)
+        Self::PyKernelError(delta_kernel_error)
    }
}

+impl From<PolarsError> for PyRustError {
+    fn from(polars_error: PolarsError) -> Self {
+        Self::PyPolarsError(polars_error)
+    }
+}

-type DeltaPyResult<T> = std::result::Result<T, PyKernelError>;
+type DeltaPyResult<T> = std::result::Result<T, PyRustError>;

#[pyclass]
struct Table(delta_kernel::Table);
@@ -117,6 +140,44 @@ fn try_create_record_batch_iter(
RecordBatchIterator::new(record_batches, result_schema)
}

unsafe fn record_batch_to_dataframe(batch: &RecordBatch) -> Result<DataFrame, PolarsError> {
let mut columns = Vec::with_capacity(batch.num_columns());

// Arrow stores data by column, so the zero-copy hand-off across the FFI boundary is done column by column
for (i, col) in batch.columns().iter().enumerate() {
// Convert to ArrayData (arrow-rs)
let array = col.to_data();

// Convert to ffi with arrow-rs
let (out_array, out_schema) = to_ffi(&array).unwrap();

// Import field from ffi with polars
let field = unsafe {
import_field_from_c(transmute::<
&arrow::ffi::FFI_ArrowSchema,
&polars_arrow::ffi::ArrowSchema,
>(&out_schema))
}?;

// Import data from ffi with polars
let data = unsafe {
import_array_from_c(
transmute::<arrow::ffi::FFI_ArrowArray, polars_arrow::ffi::ArrowArray>(
out_array,
),
field.dtype().clone(),
)
}?;

// Create Polars series from arrow column
columns.push(Series::from_arrow(
batch.schema().field(i).name().into(),
data,
)?);
}
Ok(DataFrame::from_iter(columns))
}

#[pyclass]
struct Scan(delta_kernel::scan::Scan);

@@ -131,6 +192,24 @@ impl Scan {
let record_batch_iter = try_create_record_batch_iter(results, result_schema);
Ok(PyArrowType(Box::new(record_batch_iter)))
}

fn execute_polars(
&self,
engine_interface: &PythonInterface,
) -> DeltaPyResult<PyDataFrame> {
let result_schema: ArrowSchemaRef = try_get_schema(self.0.schema())?;
let results = self.0.execute(engine_interface.0.clone())?;
let record_batch_iter = try_create_record_batch_iter(results, result_schema);
let mut dfs = Vec::new();
for rb in record_batch_iter {
unsafe {
let df = record_batch_to_dataframe(&rb.map_err(KernelError::Arrow)?)?;
dfs.push(df.lazy())
};
};
let dfs_concat = concat(dfs, UnionArgs::default());
Ok(PyDataFrame(dfs_concat?.collect()?))
}
}

#[pyclass]
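A minimal usage sketch for the new execute_polars binding added above, assuming the wheel built from this crate is importable as delta_kernel_rust_sharing_wrapper and that a Delta table exists at the hypothetical path below:

import delta_kernel_rust_sharing_wrapper as kernel_wrapper

table_path = "file:///tmp/example_delta_table"  # hypothetical local Delta table
interface = kernel_wrapper.PythonInterface(table_path)
table = kernel_wrapper.Table(table_path)
snapshot = table.snapshot(interface)
scan = kernel_wrapper.ScanBuilder(snapshot).build()

# execute_polars scans the table and returns a single polars DataFrame
df = scan.execute_polars(interface)
print(df.shape)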
3 changes: 2 additions & 1 deletion python/delta_sharing/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
#

-from delta_sharing.delta_sharing import SharingClient, load_as_pandas, load_as_spark
+from delta_sharing.delta_sharing import SharingClient, load_as_pandas, load_as_polars, load_as_spark
from delta_sharing.delta_sharing import get_table_metadata, get_table_protocol, get_table_version
from delta_sharing.delta_sharing import load_table_changes_as_pandas, load_table_changes_as_spark
from delta_sharing.protocol import Share, Schema, Table
@@ -30,6 +30,7 @@
"get_table_protocol",
"get_table_version",
"load_as_pandas",
"load_as_polars",
"load_as_spark",
"load_table_changes_as_pandas",
"load_table_changes_as_spark",
23 changes: 23 additions & 0 deletions python/delta_sharing/delta_sharing.py
@@ -18,6 +18,7 @@
from pathlib import Path

import pandas as pd
import polars as pl

from delta_sharing.protocol import CdfOptions, Protocol, Metadata

@@ -147,6 +148,28 @@ def load_as_pandas(
).to_pandas()


def load_as_polars(
url: str,
limit: Optional[int] = None,
version: Optional[int] = None,
timestamp: Optional[str] = None,
jsonPredicateHints: Optional[str] = None,
use_delta_format: Optional[bool] = None,
convert_in_batches: bool = False,
) -> pl.DataFrame:
profile_json, share, schema, table = _parse_url(url)
profile = DeltaSharingProfile.read_from_file(profile_json)
return DeltaSharingReader(
table=Table(name=table, share=share, schema=schema),
rest_client=DataSharingRestClient(profile),
jsonPredicateHints=jsonPredicateHints,
limit=limit,
version=version,
timestamp=timestamp,
use_delta_format=use_delta_format,
).to_polars()


def load_as_spark(
url: str, version: Optional[int] = None, timestamp: Optional[str] = None
) -> "PySparkDataFrame": # noqa: F821
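A short usage sketch for the new load_as_polars entry point; the profile file and share/schema/table names below are hypothetical:

import delta_sharing

# <profile-file>#<share>.<schema>.<table>
table_url = "config.share#my_share.my_schema.my_table"
df = delta_sharing.load_as_polars(table_url, limit=1000)
print(df.head())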
133 changes: 133 additions & 0 deletions python/delta_sharing/reader.py
@@ -23,6 +23,7 @@
import fsspec
import os
import pandas as pd
import polars as pl
import pyarrow as pa
import tempfile
from pyarrow.dataset import dataset
@@ -150,6 +151,46 @@ def __to_pandas_kernel(self):

return result

def __to_polars_kernel(self):
self._rest_client.set_delta_format_header()
response = self._rest_client.list_files_in_table(
self._table,
predicateHints=self._predicateHints,
jsonPredicateHints=self._jsonPredicateHints,
limitHint=self._limit,
version=self._version,
timestamp=self._timestamp,
)

lines = response.lines
# Create a temporary directory using the tempfile module
temp_dir = tempfile.TemporaryDirectory()
table_path = self.__write_temp_delta_log_snapshot(temp_dir.name, lines)
num_files = len(lines)

# Invoke delta-kernel-rust to return the polars dataframe
interface = delta_kernel_rust_sharing_wrapper.PythonInterface(table_path)
table = delta_kernel_rust_sharing_wrapper.Table(table_path)
snapshot = table.snapshot(interface)
scan = delta_kernel_rust_sharing_wrapper.ScanBuilder(snapshot).build()

# The table is empty so use the schema to return an empty table with correct col names
if num_files == 0:
schema = scan.execute(interface).schema
return pl.DataFrame(schema=schema.names)

result = scan.execute_polars(interface)

# Apply any residual limit that was not handled by server-side pushdown
if self._limit:
result = result.head(self._limit)

# Delete the temp folder explicitly and remove the delta format header
temp_dir.cleanup()
self._rest_client.remove_delta_format_header()

return result

def to_pandas(self) -> pd.DataFrame:
response_format = ""
# If client does not specify which format to use, autoresolve it.
@@ -215,6 +256,52 @@ def to_pandas(self) -> pd.DataFrame:

return merged[[col_map[field["name"].lower()] for field in schema_json["fields"]]]

def to_polars(self) -> pl.DataFrame:
response_format = ""
# If client does not specify which format to use, autoresolve it.
# Otherwise use the specified format.
if self._use_delta_format is None:
response_format = self._rest_client.autoresolve_query_format(self._table)
elif self._use_delta_format:
response_format = DataSharingRestClient.DELTA_FORMAT

# If the response format is delta, use delta kernel rust
if response_format == DataSharingRestClient.DELTA_FORMAT:
return self.__to_polars_kernel()

# Otherwise use the standard approach
response = self._rest_client.list_files_in_table(
self._table,
predicateHints=self._predicateHints,
jsonPredicateHints=self._jsonPredicateHints,
limitHint=self._limit,
version=self._version,
timestamp=self._timestamp,
)

schema_json = loads(response.metadata.schema_string)

if len(response.add_files) == 0 or self._limit == 0:
return pl.from_pandas(get_empty_table(schema_json))

converters = to_converters(schema_json)

pdfs = [
DeltaSharingReader._to_polars(file, converters, False)
for file in response.add_files
]

merged = pl.concat(pdfs, how='diagonal_relaxed')

if self._limit:
merged = merged.head(self._limit)

col_map = {}
for col in merged.collect_schema().names():
col_map[col.lower()] = col

return merged.select([col_map[field["name"].lower()] for field in schema_json["fields"]]).collect()

def __write_temp_delta_log_snapshot(self, temp_dir: str, lines: List[str]) -> str:
delta_log_dir_name = temp_dir
table_path = "file:///" + delta_log_dir_name
@@ -509,6 +596,52 @@ def _to_pandas(
pdf[DeltaSharingReader._commit_timestamp_col_name()] = action.timestamp
return pdf

@staticmethod
def _to_polars(
action: FileAction,
converters: Dict[str, Callable[[str], Any]],
for_cdf: bool,
) -> pl.LazyFrame:
url = urlparse(action.url)
if "storage.googleapis.com" in (url.netloc.lower()):
# Apply the yarl patch for GCS pre-signed urls
import delta_sharing._yarl_patch # noqa: F401

pdf = pl.scan_parquet(source=action.url)

lowered_cols = set()
for col in pdf.collect_schema().names():
lowered_cols.add(col.lower())

for col, converter in converters.items():
lowered = col.lower()
if lowered not in lowered_cols:
if col in action.partition_values:
if converter is not None:
pdf = pdf.with_columns(converter(action.partition_values[col]))
else:
raise ValueError("Cannot partition on binary or complex columns")
else:
pdf = pdf.with_columns(pl.lit(None).alias(col))

if for_cdf:
columns = []
# Add the change type col name to non cdc actions.
if not isinstance(action, AddCdcFile):
columns.append(pl.lit(action.get_change_type_col_value()).alias(DeltaSharingReader._change_type_col_name()))

# If available, add timestamp and version columns from the action.
# All rows of the dataframe will get the same value.
if action.version is not None:
assert DeltaSharingReader._commit_version_col_name() not in pdf.columns
columns.append(pl.lit(action.version).alias(DeltaSharingReader._commit_version_col_name()))

if action.timestamp is not None:
assert DeltaSharingReader._commit_timestamp_col_name() not in pdf.columns
columns.append(pl.lit(action.timestamp).alias(DeltaSharingReader._commit_timestamp_col_name()))

pdf = pdf.with_columns(columns)
return pdf
# The names of special delta columns for cdf.

@staticmethod
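A small sketch of the merge behaviour to_polars relies on: lazy frames with differing columns are combined with a diagonal-relaxed concat, the residual limit is applied, and the plan is collected at the end (the frames below are made-up stand-ins for the per-file scans):

import polars as pl

lf1 = pl.LazyFrame({"id": [1, 2], "a": ["x", "y"]})
lf2 = pl.LazyFrame({"id": [3], "b": [0.5]})

# Columns missing from one frame are filled with nulls; dtypes are relaxed as needed
merged = pl.concat([lf1, lf2], how="diagonal_relaxed").head(2)
print(merged.collect())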