diff --git a/examples/parquet.rs b/examples/parquet.rs new file mode 100644 index 0000000..9d16ef5 --- /dev/null +++ b/examples/parquet.rs @@ -0,0 +1,33 @@ +//! Example of using the datafusion-tpch extension to generate TPCH tables +//! and writing them to disk via `COPY`. + +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_tpch::{register_tpch_udtf, register_tpch_udtfs}; + +#[tokio::main] +async fn main() -> datafusion::error::Result<()> { + let ctx = SessionContext::new_with_config(SessionConfig::new().with_information_schema(true)); + register_tpch_udtf(&ctx); + + let sql_df = ctx.sql("SELECT * FROM tpch(1.0);").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SHOW TABLES;").await?; + sql_df.show().await?; + + let sql_df = ctx + .sql("COPY nation TO './tpch_nation.parquet' STORED AS PARQUET") + .await?; + sql_df.show().await?; + + register_tpch_udtfs(&ctx)?; + + let sql_df = ctx + .sql( + "COPY (SELECT * FROM tpch_lineitem(1.0)) TO './tpch_lineitem_sf_10.parquet' STORED AS PARQUET", + ) + .await?; + sql_df.show().await?; + + Ok(()) +} diff --git a/examples/tpchgen.rs b/examples/tpchgen.rs new file mode 100644 index 0000000..7f39b24 --- /dev/null +++ b/examples/tpchgen.rs @@ -0,0 +1,39 @@ +//! Example of using the datafusion-tpch extension to generate TPCH datasets +//! on the fly in datafusion. 
+ +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_tpch::register_tpch_udtf; + +#[tokio::main] +async fn main() -> datafusion::error::Result<()> { + let ctx = SessionContext::new_with_config(SessionConfig::new().with_information_schema(true)); + register_tpch_udtf(&ctx); + + let sql_df = ctx.sql("SELECT * FROM tpch(1.0);").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SHOW TABLES;").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SELECT * FROM nation LIMIT 5;").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SELECT * FROM partsupp LIMIT 5;").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SELECT * FROM region LIMIT 5;").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SELECT * FROM customer LIMIT 5;").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SELECT * FROM orders LIMIT 5;").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SELECT * FROM lineitem LIMIT 5;").await?; + sql_df.show().await?; + + let sql_df = ctx.sql("SELECT * FROM part LIMIT 5;").await?; + sql_df.show().await?; + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 7d3a961..85e3e5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,12 @@ use datafusion::arrow::compute::concat_batches; +use datafusion::arrow::datatypes::Schema; use datafusion::catalog::{TableFunctionImpl, TableProvider}; use datafusion::common::{Result, ScalarValue, plan_err}; use datafusion::datasource::memory::MemTable; use datafusion::prelude::SessionContext; +use datafusion::sql::TableReference; use datafusion_expr::Expr; +use std::fmt::Debug; use std::sync::Arc; use tpchgen_arrow::RecordBatchIterator; @@ -11,11 +14,7 @@ use tpchgen_arrow::RecordBatchIterator; /// as the data source. macro_rules! 
define_tpch_udtf_provider { ($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $GENERATOR:ty, $ARROW_GENERATOR:ty) => { - #[doc = concat!( - "A table function that generates the `", - stringify!($TABLE_FUNCTION_SQL_NAME), - "` table using the `tpchgen` library." - )] + #[doc = concat!("A table function that generates the `", stringify!($TABLE_FUNCTION_SQL_NAME), "` table using the `tpchgen` library.")] /// /// The expected arguments are a float literal for the scale factor, /// an i64 literal for the part, and an i64 literal for the number of parts. @@ -59,6 +58,19 @@ macro_rules! define_tpch_udtf_provider { pub fn name() -> &'static str { stringify!($TABLE_FUNCTION_SQL_NAME) } + + /// Returns the name of the table generated by the table function, i.e. + /// the SQL function name with its `tpch_` prefix removed. + pub fn table_name() -> &'static str { + stringify!($TABLE_FUNCTION_SQL_NAME) + .strip_prefix("tpch_") + .unwrap_or_else(|| { + panic!( + "Table function name {} does not start with tpch_", + stringify!($TABLE_FUNCTION_SQL_NAME) + ) + }) + } } impl TableFunctionImpl for $TABLE_FUNCTION_NAME { @@ -194,6 +206,122 @@ pub fn register_tpch_udtfs(ctx: &SessionContext) -> Result<()> { Ok(()) } +/// Table function provider for TPCH tables. +struct TpchTables { + ctx: SessionContext, +} + +impl TpchTables { + const TPCH_TABLE_NAMES: &[&str] = &[ + "nation", "customer", "orders", "lineitem", "part", "partsupp", "supplier", "region", + ]; + /// Creates a new TPCH table provider. + pub fn new(ctx: SessionContext) -> Self { + Self { ctx } + } + + /// Build and register a TPCH table by its table function provider. + fn build_and_register_tpch_table<P: TableFunctionImpl>( + &self, + provider: P, + table_name: &str, + scale_factor: f64, + ) -> Result<()> { + let table = provider + .call(&[Expr::Literal(ScalarValue::Float64(Some(scale_factor)))])?; + self.ctx + .register_table(TableReference::bare(table_name), table)?; + + Ok(()) + } + + /// Build and register all TPCH tables in the session context. 
+ fn build_and_register_all_tables(&self, scale_factor: f64) -> Result<()> { + for &suffix in Self::TPCH_TABLE_NAMES { + match suffix { + "nation" => { + self.build_and_register_tpch_table(TpchNation {}, suffix, scale_factor)? + } + "customer" => { + self.build_and_register_tpch_table(TpchCustomer {}, suffix, scale_factor)? + } + "orders" => { + self.build_and_register_tpch_table(TpchOrders {}, suffix, scale_factor)? + } + "lineitem" => { + self.build_and_register_tpch_table(TpchLineitem {}, suffix, scale_factor)? + } + "part" => self.build_and_register_tpch_table(TpchPart {}, suffix, scale_factor)?, + "partsupp" => { + self.build_and_register_tpch_table(TpchPartsupp {}, suffix, scale_factor)? + } + "supplier" => { + self.build_and_register_tpch_table(TpchSupplier {}, suffix, scale_factor)? + } + "region" => { + self.build_and_register_tpch_table(TpchRegion {}, suffix, scale_factor)? + } + _ => unreachable!("Unknown TPCH table suffix: {}", suffix), // Should not happen + } + } + Ok(()) + } +} + +// Implement the `Debug` trait for `TpchTables` by hand; it cannot be derived +// because the `SessionContext` field does not implement it. +impl Debug for TpchTables { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TpchTableProvider") + } +} + +impl TableFunctionImpl for TpchTables { + /// The `call` method is the entry point for the UDTF and is called when the UDTF is + /// invoked in a SQL query. + /// + /// The UDTF requires one argument, the scale factor, as a float literal. Only the + /// first argument is read; any additional arguments (such as a path on disk) are + /// currently ignored. + fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> { + let scale_factor = match args.first() { + Some(Expr::Literal(ScalarValue::Float64(Some(value)))) => *value, + _ => { + return plan_err!( + "First argument must be a float literal that specifies the scale factor." 
+ ); + } + }; + + // Register the TPCH tables in the session context. + self.build_and_register_all_tables(scale_factor)?; + + // Return a single-column table (`table_name`) whose rows are the names of + // the individual TPCH tables. + let schema = Schema::new(vec![datafusion::arrow::datatypes::Field::new( + "table_name", + datafusion::arrow::datatypes::DataType::Utf8, + false, + )]); + let batch = datafusion::arrow::record_batch::RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(datafusion::arrow::array::StringArray::from(vec![ + "nation", "customer", "orders", "lineitem", "part", "partsupp", "supplier", + "region", + ]))], + )?; + let mem_table = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; + + Ok(Arc::new(mem_table)) + } +} + +/// Register the `tpch` table function. +pub fn register_tpch_udtf(ctx: &SessionContext) { + let tpch_udtf = TpchTables::new(ctx.clone()); + ctx.register_udtf("tpch", Arc::new(tpch_udtf)); +} + #[cfg(test)] mod tests { use super::*; @@ -203,12 +331,15 @@ mod tests { async fn test_register_all_tpch_functions() -> Result<()> { let ctx = SessionContext::new(); + let tpch_tbl_fn = TpchTables::new(ctx.clone()); + ctx.register_udtf("tpch", Arc::new(tpch_tbl_fn)); + // Register all the UDTFs. register_tpch_udtfs(&ctx)?; // Test all the UDTFs, the constants were computed using the tpchgen library // and the expected values are the number of rows and columns for each table. - let test_cases = vec![ + let expected_tables = vec![ (TpchNation::name(), 25, 4), (TpchCustomer::name(), 150000, 8), (TpchOrders::name(), 1500000, 9), @@ -219,7 +350,7 @@ mod tests { (TpchRegion::name(), 5, 3), ]; - for (function, expected_rows, expected_columns) in test_cases { + for (function, expected_rows, expected_columns) in expected_tables { let df = ctx .sql(&format!("SELECT * FROM {}(1.0)", function)) .await? 
@@ -261,7 +392,7 @@ mod tests { // Test all the UDTFs, the constants were computed using the tpchgen library // and the expected values are the number of rows and columns for each table. - let test_cases = vec![ + let expected_tables = vec![ (TpchNation::name(), 25, 4), (TpchCustomer::name(), 150000, 8), (TpchOrders::name(), 1500000, 9), @@ -272,7 +403,7 @@ mod tests { (TpchRegion::name(), 5, 3), ]; - for (function, expected_rows, expected_columns) in test_cases { + for (function, expected_rows, expected_columns) in expected_tables { let df = ctx .sql(&format!("SELECT * FROM {}(1.0)", function)) .await? @@ -297,4 +428,58 @@ mod tests { } Ok(()) } + + #[tokio::test] + async fn test_register_tpch_provider() -> Result<()> { + let ctx = SessionContext::new(); + + register_tpch_udtf(&ctx); + + // Calling `tpch` registers every TPCH table and returns one row per table name. + let df = ctx + .sql("SELECT * FROM tpch(1.0, '')") + .await? + .collect() + .await?; + + assert_eq!(df.len(), 1); + assert_eq!(df[0].num_rows(), 8); + assert_eq!(df[0].num_columns(), 1); + + let expected_tables = vec![ + (TpchNation::table_name(), 25, 4), + (TpchCustomer::table_name(), 150000, 8), + (TpchOrders::table_name(), 1500000, 9), + (TpchLineitem::table_name(), 6001215, 16), + (TpchPart::table_name(), 200000, 9), + (TpchPartsupp::table_name(), 800000, 5), + (TpchSupplier::table_name(), 10000, 7), + (TpchRegion::table_name(), 5, 3), + ]; + + for (function, expected_rows, expected_columns) in expected_tables { + let df = ctx + .sql(&format!("SELECT * FROM {}", function)) + .await? + .collect() + .await?; + + assert_eq!(df.len(), 1); + assert_eq!( + df[0].num_rows(), + expected_rows, + "{}: {}", + function, + expected_rows + ); + assert_eq!( + df[0].num_columns(), + expected_columns, + "{}: {}", + function, + expected_columns + ); + } + Ok(()) + } }