Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions examples/parquet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//! Example of using the datafusion-tpch extension to generate TPCH tables
//! and writing them to disk via `COPY`.

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_tpch::{register_tpch_udtf, register_tpch_udtfs};

/// Generates the TPCH tables through the `tpch` table function and writes two
/// of them to Parquet files in the current directory via `COPY`.
#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // Enable the information schema (used by `SHOW TABLES` below).
    let ctx = SessionContext::new_with_config(SessionConfig::new().with_information_schema(true));
    register_tpch_udtf(&ctx);

    // The SQL strings are constants, so they are passed directly instead of
    // going through an argument-less `format!`.
    let sql_df = ctx.sql("SELECT * FROM tpch(1.0);").await?;
    sql_df.show().await?;

    let sql_df = ctx.sql("SHOW TABLES;").await?;
    sql_df.show().await?;

    // Write the `nation` table to disk.
    let sql_df = ctx
        .sql("COPY nation TO './tpch_nation.parquet' STORED AS PARQUET")
        .await?;
    sql_df.show().await?;

    // Register the per-table functions (e.g. `tpch_lineitem`) so a single
    // table can be generated directly inside the `COPY` statement.
    register_tpch_udtfs(&ctx)?;

    // NOTE(review): the file name says `sf_10` while the scale factor passed
    // is 1.0 — confirm which one is intended.
    let sql_df = ctx
        .sql(
            "COPY (SELECT * FROM tpch_lineitem(1.0)) TO './tpch_lineitem_sf_10.parquet' STORED AS PARQUET",
        )
        .await?;
    sql_df.show().await?;

    Ok(())
}
39 changes: 39 additions & 0 deletions examples/tpchgen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! Example of using the datafusion-tpch extension to generate TPCH datasets
//! on the fly in datafusion.

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_tpch::register_tpch_udtf;

/// Generates the TPCH tables through the `tpch` table function and prints the
/// first five rows of several of them.
#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // Enable the information schema (used by `SHOW TABLES` below).
    let ctx = SessionContext::new_with_config(SessionConfig::new().with_information_schema(true));
    register_tpch_udtf(&ctx);

    // Constant SQL is passed directly; `format!` is only used where a value
    // is actually interpolated.
    let sql_df = ctx.sql("SELECT * FROM tpch(1.0);").await?;
    sql_df.show().await?;

    let sql_df = ctx.sql("SHOW TABLES;").await?;
    sql_df.show().await?;

    // Preview the generated tables, in the same order as before.
    let preview_tables = [
        "nation", "partsupp", "region", "customer", "orders", "lineitem", "part",
    ];
    for table in preview_tables {
        let sql_df = ctx
            .sql(&format!("SELECT * FROM {} LIMIT 5;", table))
            .await?;
        sql_df.show().await?;
    }

    Ok(())
}
203 changes: 194 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
use datafusion::arrow::compute::concat_batches;
use datafusion::arrow::datatypes::Schema;
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::{Result, ScalarValue, plan_err};
use datafusion::datasource::memory::MemTable;
use datafusion::prelude::SessionContext;
use datafusion::sql::TableReference;
use datafusion_expr::Expr;
use std::fmt::Debug;
use std::sync::Arc;
use tpchgen_arrow::RecordBatchIterator;

/// Defines a table function provider and its implementation using [`tpchgen`]
/// as the data source.
macro_rules! define_tpch_udtf_provider {
($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $GENERATOR:ty, $ARROW_GENERATOR:ty) => {
#[doc = concat!(
"A table function that generates the `",
stringify!($TABLE_FUNCTION_SQL_NAME),
"` table using the `tpchgen` library."
)]
#[doc = concat!("A table function that generates the `",stringify!($TABLE_FUNCTION_SQL_NAME),"` table using the `tpchgen` library.")]
///
/// The expected arguments are a float literal for the scale factor,
/// an i64 literal for the part, and an i64 literal for the number of parts.
Expand Down Expand Up @@ -59,6 +58,19 @@ macro_rules! define_tpch_udtf_provider {
pub fn name() -> &'static str {
stringify!($TABLE_FUNCTION_SQL_NAME)
}

/// Returns the name of the table generated by this table function: the
/// SQL function name with its leading `tpch_` removed.
///
/// # Panics
///
/// Panics if the SQL function name does not start with `tpch_`.
pub fn table_name() -> &'static str {
    const SQL_NAME: &str = stringify!($TABLE_FUNCTION_SQL_NAME);
    match SQL_NAME.strip_prefix("tpch_") {
        Some(table) => table,
        None => panic!(
            "Table function name {} does not start with tpch_",
            stringify!($TABLE_FUNCTION_SQL_NAME)
        ),
    }
}
}

impl TableFunctionImpl for $TABLE_FUNCTION_NAME {
Expand Down Expand Up @@ -194,6 +206,122 @@ pub fn register_tpch_udtfs(ctx: &SessionContext) -> Result<()> {
Ok(())
}

/// Table function provider for TPCH tables.
///
/// Backs the `tpch` table function: when called it builds every TPCH table
/// for the requested scale factor and registers each one in `ctx`.
struct TpchTables {
/// Session context in which the generated tables are registered.
ctx: SessionContext,
}

impl TpchTables {
    /// Names of the tables that the `tpch` table function builds and registers.
    const TPCH_TABLE_NAMES: &'static [&'static str] = &[
        "nation", "customer", "orders", "lineitem", "part", "partsupp", "supplier", "region",
    ];

    /// Creates a new TPCH table provider that registers its tables in `ctx`.
    pub fn new(ctx: SessionContext) -> Self {
        Self { ctx }
    }

    /// Builds a TPCH table with its table function provider and registers it
    /// in the session context under `table_name`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the provider or by the registration.
    fn build_and_register_tpch_table<P: TableFunctionImpl>(
        &self,
        provider: P,
        table_name: &str,
        scale_factor: f64,
    ) -> Result<()> {
        // A borrowed one-element array is enough here; no `Vec` allocation needed.
        let table = provider.call(&[Expr::Literal(ScalarValue::Float64(Some(scale_factor)))])?;

        // NOTE(review): registration presumably fails when `table_name` is
        // already taken, so invoking `tpch` twice in one session would error —
        // confirm whether replacing the existing table is desired instead.
        self.ctx
            .register_table(TableReference::bare(table_name), table)?;

        Ok(())
    }

    /// Builds and registers every table listed in [`Self::TPCH_TABLE_NAMES`]
    /// in the session context, all at the same `scale_factor`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced while building or
    /// registering a table.
    fn build_and_register_all_tables(&self, scale_factor: f64) -> Result<()> {
        for &name in Self::TPCH_TABLE_NAMES {
            // Each table has its own provider type, so dispatch on the name.
            match name {
                "nation" => self.build_and_register_tpch_table(TpchNation {}, name, scale_factor)?,
                "customer" => {
                    self.build_and_register_tpch_table(TpchCustomer {}, name, scale_factor)?
                }
                "orders" => self.build_and_register_tpch_table(TpchOrders {}, name, scale_factor)?,
                "lineitem" => {
                    self.build_and_register_tpch_table(TpchLineitem {}, name, scale_factor)?
                }
                "part" => self.build_and_register_tpch_table(TpchPart {}, name, scale_factor)?,
                "partsupp" => {
                    self.build_and_register_tpch_table(TpchPartsupp {}, name, scale_factor)?
                }
                "supplier" => {
                    self.build_and_register_tpch_table(TpchSupplier {}, name, scale_factor)?
                }
                "region" => self.build_and_register_tpch_table(TpchRegion {}, name, scale_factor)?,
                // Every entry of `TPCH_TABLE_NAMES` is handled above.
                _ => unreachable!("Unknown TPCH table suffix: {}", name),
            }
        }
        Ok(())
    }
}

// Manual `Debug` implementation for `TpchTables`. It is written by hand
// because, per the original author, `SessionContext` does not implement the
// trait, so it cannot be derived. The output is a fixed label and does not
// include the wrapped context.
impl Debug for TpchTables {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("TpchTableProvider")
    }
}

impl TableFunctionImpl for TpchTables {
    /// Entry point of the `tpch` UDTF, called when it is invoked in a SQL query.
    ///
    /// The first argument must be a `Float64` literal giving the scale factor.
    /// Any further arguments are accepted but currently ignored.
    ///
    /// As a side effect, every TPCH table is built at that scale factor and
    /// registered in the session context. The returned table has a single
    /// non-nullable `table_name` column listing the names of those tables.
    ///
    /// # Errors
    ///
    /// Returns a planning error when the first argument is missing or is not
    /// a float literal, and propagates any error raised while building or
    /// registering the tables.
    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        let Some(Expr::Literal(ScalarValue::Float64(Some(scale_factor)))) = args.first() else {
            return plan_err!(
                "First argument must be a float literal that specifies the scale factor."
            );
        };

        // Register the TPCH tables in the session context.
        self.build_and_register_all_tables(*scale_factor)?;

        // The result is a one-column table whose rows are the registered table
        // names. The schema is shared through an `Arc` rather than deep-cloned.
        let schema = Arc::new(Schema::new(vec![datafusion::arrow::datatypes::Field::new(
            "table_name",
            datafusion::arrow::datatypes::DataType::Utf8,
            false,
        )]));
        // Reuse the single source of truth for the table names so this listing
        // cannot drift from what was actually registered.
        let batch = datafusion::arrow::record_batch::RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(datafusion::arrow::array::StringArray::from(
                Self::TPCH_TABLE_NAMES.to_vec(),
            ))],
        )?;
        let mem_table = MemTable::try_new(schema, vec![vec![batch]])?;

        Ok(Arc::new(mem_table))
    }
}

/// Registers the `tpch` table function in `ctx`.
///
/// The function is backed by a provider holding a clone of `ctx`, which it
/// uses to register the generated tables when `tpch` is invoked.
pub fn register_tpch_udtf(ctx: &SessionContext) {
    let provider = Arc::new(TpchTables::new(ctx.clone()));
    ctx.register_udtf("tpch", provider);
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -203,12 +331,15 @@ mod tests {
async fn test_register_all_tpch_functions() -> Result<()> {
let ctx = SessionContext::new();

let tpch_tbl_fn = TpchTables::new(ctx.clone());
ctx.register_udtf("tcph", Arc::new(tpch_tbl_fn));

// Register all the UDTFs.
register_tpch_udtfs(&ctx)?;

// Test all the UDTFs, the constants were computed using the tpchgen library
// and the expected values are the number of rows and columns for each table.
let test_cases = vec![
let expected_tables = vec![
(TpchNation::name(), 25, 4),
(TpchCustomer::name(), 150000, 8),
(TpchOrders::name(), 1500000, 9),
Expand All @@ -219,7 +350,7 @@ mod tests {
(TpchRegion::name(), 5, 3),
];

for (function, expected_rows, expected_columns) in test_cases {
for (function, expected_rows, expected_columns) in expected_tables {
let df = ctx
.sql(&format!("SELECT * FROM {}(1.0)", function))
.await?
Expand Down Expand Up @@ -261,7 +392,7 @@ mod tests {

// Test all the UDTFs, the constants were computed using the tpchgen library
// and the expected values are the number of rows and columns for each table.
let test_cases = vec![
let expected_tables = vec![
(TpchNation::name(), 25, 4),
(TpchCustomer::name(), 150000, 8),
(TpchOrders::name(), 1500000, 9),
Expand All @@ -272,7 +403,7 @@ mod tests {
(TpchRegion::name(), 5, 3),
];

for (function, expected_rows, expected_columns) in test_cases {
for (function, expected_rows, expected_columns) in expected_tables {
let df = ctx
.sql(&format!("SELECT * FROM {}(1.0)", function))
.await?
Expand All @@ -297,4 +428,58 @@ mod tests {
}
Ok(())
}

#[tokio::test]
async fn test_register_tpch_provider() -> Result<()> {
    let ctx = SessionContext::new();
    register_tpch_udtf(&ctx);

    // Invoking the provider must return one batch with a single column and
    // one row per TPCH table.
    let listing = ctx.sql("SELECT * FROM tpch(1.0, '')").await?;
    let listing = listing.collect().await?;

    assert_eq!(listing.len(), 1);
    assert_eq!(listing[0].num_rows(), 8);
    assert_eq!(listing[0].num_columns(), 1);

    // (table name, expected row count, expected column count) for each table
    // registered by the call above.
    let expected_tables = [
        (TpchNation::table_name(), 25, 4),
        (TpchCustomer::table_name(), 150000, 8),
        (TpchOrders::table_name(), 1500000, 9),
        (TpchLineitem::table_name(), 6001215, 16),
        (TpchPart::table_name(), 200000, 9),
        (TpchPartsupp::table_name(), 800000, 5),
        (TpchSupplier::table_name(), 10000, 7),
        (TpchRegion::table_name(), 5, 3),
    ];

    for (table, rows, columns) in expected_tables {
        let query = format!("SELECT * FROM {}", table);
        let batches = ctx.sql(&query).await?.collect().await?;

        assert_eq!(batches.len(), 1);

        let batch = &batches[0];
        assert_eq!(batch.num_rows(), rows, "{}: {}", table, rows);
        assert_eq!(batch.num_columns(), columns, "{}: {}", table, columns);
    }

    Ok(())
}
}