Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 245 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
use datafusion::arrow::compute::concat_batches;
use datafusion::arrow::datatypes::Schema;
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::{Result, ScalarValue, plan_err};
use datafusion::datasource::memory::MemTable;
use datafusion::prelude::SessionContext;
use datafusion::sql::TableReference;
use datafusion_expr::Expr;
use std::fmt::Debug;
use std::sync::Arc;
use tpchgen_arrow::RecordBatchIterator;

/// Defines a table function provider and its implementation using [`tpchgen`]
/// as the data source.
macro_rules! define_tpch_udtf_provider {
($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $GENERATOR:ty, $ARROW_GENERATOR:ty) => {
#[doc = concat!(
"A table function that generates the `",
stringify!($TABLE_FUNCTION_SQL_NAME),
"` table using the `tpchgen` library."
)]
#[doc = concat!("A table function that generates the `",stringify!($TABLE_FUNCTION_SQL_NAME),"` table using the `tpchgen` library.")]
///
/// The expected arguments are a float literal for the scale factor,
/// an i64 literal for the part, and an i64 literal for the number of parts.
Expand Down Expand Up @@ -59,6 +58,19 @@ macro_rules! define_tpch_udtf_provider {
pub fn name() -> &'static str {
stringify!($TABLE_FUNCTION_SQL_NAME)
}

/// Returns the name of the table generated by the table function
/// when used in a SQL query.
pub fn table_name() -> &'static str {
stringify!($TABLE_FUNCTION_SQL_NAME)
.strip_prefix("tpch_")
.unwrap_or_else(|| {
panic!(
"Table function name {} does not start with tpch_",
stringify!($TABLE_FUNCTION_SQL_NAME)
)
})
}
}

impl TableFunctionImpl for $TABLE_FUNCTION_NAME {
Expand Down Expand Up @@ -194,6 +206,173 @@ pub fn register_tpch_udtfs(ctx: &SessionContext) -> Result<()> {
Ok(())
}

/// Table function provider for TPCH tables.
struct TpchTables {
ctx: SessionContext,
}

impl TpchTables {
const TPCH_TABLE_NAMES: &[&str] = &[
"nation", "customer", "orders", "lineitem", "part", "partsupp", "supplier", "region",
];
/// Creates a new TPCH table provider.
pub fn new(ctx: SessionContext) -> Self {
Self { ctx }
}

/// Build and register a TPCH table by it's table function provider.
fn build_and_register_tpch_table<P: TableFunctionImpl>(
&self,
provider: P,
table_name: &str,
scale_factor: f64,
write_to_disk: bool,
_path: &str,
Comment thread
clflushopt marked this conversation as resolved.
Outdated
) -> Result<()> {
// Short path when the table is generated in memory only.
if !write_to_disk {
let table = provider
.call(vec![Expr::Literal(ScalarValue::Float64(Some(scale_factor)))].as_slice())?;
self.ctx
.register_table(TableReference::bare(table_name), table)?;
Comment thread
clflushopt marked this conversation as resolved.
Outdated
return Ok(());
}

Ok(())
}

/// Build and register all TPCH tables in the session context.
fn build_and_register_all_tables(
&self,
scale_factor: f64,
write_to_disk: bool,
path: &str,
) -> Result<()> {
for &suffix in Self::TPCH_TABLE_NAMES {
match suffix {
"nation" => self.build_and_register_tpch_table(
TpchNation {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
"customer" => self.build_and_register_tpch_table(
TpchCustomer {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
"orders" => self.build_and_register_tpch_table(
TpchOrders {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
"lineitem" => self.build_and_register_tpch_table(
TpchLineitem {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
"part" => self.build_and_register_tpch_table(
TpchPart {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
"partsupp" => self.build_and_register_tpch_table(
TpchPartsupp {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
"supplier" => self.build_and_register_tpch_table(
TpchSupplier {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
"region" => self.build_and_register_tpch_table(
TpchRegion {},
suffix,
scale_factor,
write_to_disk,
path,
)?,
_ => unreachable!("Unknown TPCH table suffix: {}", suffix), // Should not happen
}
}
Ok(())
}
}

// Implement the `TableProvider` trait for the `TpchTableProvider`, we need
// to do it manually because the `SessionContext` does not implement it.
impl Debug for TpchTables {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "TpchTableProvider")
}
}

impl TableFunctionImpl for TpchTables {
/// The `call` method is the entry point for the UDTF and is called when the UDTF is
/// invoked in a SQL query.
///
/// It takes a list of arguments, the scale factor, whether to generate the data on
/// disk in parquet format and the path to the output files. If no path is provided,
/// the data is generated in memory and we fallback to the `MemTable` provider.
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let scale_factor = match args.first() {
Some(Expr::Literal(ScalarValue::Float64(Some(value)))) => *value,
_ => return plan_err!("First argument must be a float literal."),
};

let write_to_disk = match args.get(1) {
Some(Expr::Literal(ScalarValue::Boolean(Some(value)))) => *value,
_ => false,
};

let path = match args.get(2) {
Some(Expr::Literal(ScalarValue::Utf8(Some(value)))) => value.clone(),
_ => "".to_string(),
};

// Register the TPCH tables in the session context.
self.build_and_register_all_tables(scale_factor, write_to_disk, &path)?;

// Create a table with the schema |table_name| and the data is just the
// individual table names.
let schema = Schema::new(vec![datafusion::arrow::datatypes::Field::new(
"table_name",
datafusion::arrow::datatypes::DataType::Utf8,
false,
)]);
let batch = datafusion::arrow::record_batch::RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(datafusion::arrow::array::StringArray::from(vec![
"nation", "customer", "orders", "lineitem", "part", "partsupp", "supplier",
"region",
]))],
)?;
let mem_table = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;

Ok(Arc::new(mem_table))
}
}

/// Register the `tpch` table function.
pub fn register_tpch_udtf(ctx: &SessionContext) {
let tpch_udtf = TpchTables::new(ctx.clone());
ctx.register_udtf("tpch", Arc::new(tpch_udtf));
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -203,12 +382,15 @@ mod tests {
async fn test_register_all_tpch_functions() -> Result<()> {
let ctx = SessionContext::new();

let tpch_tbl_fn = TpchTables::new(ctx.clone());
ctx.register_udtf("tcph", Arc::new(tpch_tbl_fn));

// Register all the UDTFs.
register_tpch_udtfs(&ctx)?;

// Test all the UDTFs, the constants were computed using the tpchgen library
// and the expected values are the number of rows and columns for each table.
let test_cases = vec![
let expected_tables = vec![
(TpchNation::name(), 25, 4),
(TpchCustomer::name(), 150000, 8),
(TpchOrders::name(), 1500000, 9),
Expand All @@ -219,7 +401,7 @@ mod tests {
(TpchRegion::name(), 5, 3),
];

for (function, expected_rows, expected_columns) in test_cases {
for (function, expected_rows, expected_columns) in expected_tables {
let df = ctx
.sql(&format!("SELECT * FROM {}(1.0)", function))
.await?
Expand Down Expand Up @@ -261,7 +443,7 @@ mod tests {

// Test all the UDTFs, the constants were computed using the tpchgen library
// and the expected values are the number of rows and columns for each table.
let test_cases = vec![
let expected_tables = vec![
(TpchNation::name(), 25, 4),
(TpchCustomer::name(), 150000, 8),
(TpchOrders::name(), 1500000, 9),
Expand All @@ -272,7 +454,7 @@ mod tests {
(TpchRegion::name(), 5, 3),
];

for (function, expected_rows, expected_columns) in test_cases {
for (function, expected_rows, expected_columns) in expected_tables {
let df = ctx
.sql(&format!("SELECT * FROM {}(1.0)", function))
.await?
Expand All @@ -297,4 +479,58 @@ mod tests {
}
Ok(())
}

#[tokio::test]
async fn test_register_tpch_provider() -> Result<()> {
let ctx = SessionContext::new();

register_tpch_udtf(&ctx);

// Test the TPCH provider.
let df = ctx
.sql("SELECT * FROM tpch(1.0, false, '')")
.await?
.collect()
.await?;

assert_eq!(df.len(), 1);
assert_eq!(df[0].num_rows(), 8);
assert_eq!(df[0].num_columns(), 1);

let expected_tables = vec![
(TpchNation::table_name(), 25, 4),
(TpchCustomer::table_name(), 150000, 8),
(TpchOrders::table_name(), 1500000, 9),
(TpchLineitem::table_name(), 6001215, 16),
(TpchPart::table_name(), 200000, 9),
(TpchPartsupp::table_name(), 800000, 5),
(TpchSupplier::table_name(), 10000, 7),
(TpchRegion::table_name(), 5, 3),
];

for (function, expected_rows, expected_columns) in expected_tables {
let df = ctx
.sql(&format!("SELECT * FROM {}", function))
.await?
.collect()
.await?;

assert_eq!(df.len(), 1);
assert_eq!(
df[0].num_rows(),
expected_rows,
"{}: {}",
function,
expected_rows
);
assert_eq!(
df[0].num_columns(),
expected_columns,
"{}: {}",
function,
expected_columns
);
}
Ok(())
}
}