diff --git a/Cargo.toml b/Cargo.toml
index 5eef5e9..590a290 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,7 +27,7 @@
 arrow-ipc = "51.0.0"
 serde_json = "1.0.115"
 
-parking_lot = "0.12.1"
+parking_lot = { version="0.12.1" , features = ["send_guard"]}
 
 prost = "0.12.0"
 prost-types = "0.12.0"
diff --git a/examples/databricks.rs b/examples/databricks.rs
index 372bed5..2343171 100644
--- a/examples/databricks.rs
+++ b/examples/databricks.rs
@@ -11,9 +11,11 @@ use spark_connect_rs::{SparkSession, SparkSessionBuilder};
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let spark: SparkSession = SparkSessionBuilder::remote("sc://:443/;token=;x-databricks-cluster-id=")
+    let spark: Arc<SparkSession> = Arc::new(
+        SparkSessionBuilder::remote("sc://:443/;token=;x-databricks-cluster-id=")
         .build()
-        .await?;
+        .await?
+    );
 
     spark
         .range(None, 10, 1, Some(1))
diff --git a/examples/delta.rs b/examples/delta.rs
index 425befa..a9ca902 100644
--- a/examples/delta.rs
+++ b/examples/delta.rs
@@ -6,13 +6,15 @@
 // The remote spark session must have the spark package `io.delta:delta-spark_2.12:{DELTA_VERSION}` enabled.
 // Where the `DELTA_VERSION` is the specified Delta Lake version.
 
+use std::sync::Arc;
+
 use spark_connect_rs::{SparkSession, SparkSessionBuilder};
 
 use spark_connect_rs::dataframe::SaveMode;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let spark: SparkSession = SparkSessionBuilder::default().build().await?;
+    let spark: Arc<SparkSession> = Arc::new(SparkSessionBuilder::default().build().await?);
 
     let paths = ["/opt/spark/examples/src/main/resources/people.csv"];
diff --git a/examples/reader.rs b/examples/reader.rs
index 550469a..118a55d 100644
--- a/examples/reader.rs
+++ b/examples/reader.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use spark_connect_rs::{SparkSession, SparkSessionBuilder};
 
 use spark_connect_rs::functions as F;
@@ -7,7 +9,7 @@ use spark_connect_rs::functions as F;
 // printing the results as "show(...)"
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let spark: SparkSession = SparkSessionBuilder::default().build().await?;
+    let spark: Arc<SparkSession> = Arc::new(SparkSessionBuilder::default().build().await?);
 
     let path = ["/opt/spark/examples/src/main/resources/people.csv"];
diff --git a/examples/readstream.rs b/examples/readstream.rs
index 025098a..7020ec8 100644
--- a/examples/readstream.rs
+++ b/examples/readstream.rs
@@ -3,15 +3,17 @@ use spark_connect_rs;
 use spark_connect_rs::streaming::{OutputMode, Trigger};
 use spark_connect_rs::{SparkSession, SparkSessionBuilder};
 
+use std::sync::Arc;
 use std::{thread, time};
 
 // This example demonstrates creating a Spark Stream and monitoring the progress
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let spark: SparkSession =
+    let spark: Arc<SparkSession> = Arc::new(
         SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs")
             .build()
-            .await?;
+            .await?,
+    );
 
     let df = spark
         .readStream()
diff --git a/examples/sql.rs b/examples/sql.rs
index 1284a5c..38aa4bf 100644
--- a/examples/sql.rs
+++ b/examples/sql.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use spark_connect_rs;
 
 use spark_connect_rs::{SparkSession, SparkSessionBuilder};
@@ -7,10 +9,11 @@ use spark_connect_rs::{SparkSession, SparkSessionBuilder};
 // Displaying the results as "show(...)"
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let spark: SparkSession =
+    let spark: Arc<SparkSession> = Arc::new(
         SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs")
             .build()
-            .await?;
+            .await?,
+    );
 
     let df = spark
         .sql("SELECT * FROM json.`/opt/spark/examples/src/main/resources/employees.json`")
diff --git a/examples/writer.rs b/examples/writer.rs
index 42e9708..e0312da 100644
--- a/examples/writer.rs
+++ b/examples/writer.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use spark_connect_rs;
 
 use spark_connect_rs::{SparkSession, SparkSessionBuilder};
@@ -11,7 +13,7 @@ use spark_connect_rs::dataframe::SaveMode;
 // then reading the csv file back
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let spark: SparkSession = SparkSessionBuilder::default().build().await?;
+    let spark: Arc<SparkSession> = Arc::new(SparkSessionBuilder::default().build().await?);
 
     let df = spark
         .clone()
diff --git a/src/catalog.rs b/src/catalog.rs
index 14a5e9e..7d57dd7 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -1,5 +1,7 @@
 //! Spark Catalog representation through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
 
+use std::sync::Arc;
+
 use arrow::array::RecordBatch;
 
 use crate::errors::SparkError;
@@ -9,11 +11,11 @@
 use crate::spark;
 
 #[derive(Debug, Clone)]
 pub struct Catalog {
-    spark_session: SparkSession,
+    spark_session: Arc<SparkSession>,
 }
 
 impl Catalog {
-    pub fn new(spark_session: SparkSession) -> Self {
+    pub fn new(spark_session: Arc<SparkSession>) -> Self {
         Self { spark_session }
     }
diff --git a/src/client/mod.rs b/src/client/mod.rs
index 8a8b3f6..5dd2f18 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -127,15 +127,12 @@ impl ChannelBuilder {
                 channel_builder.use_ssl = true
             }
         };
-
         channel_builder.headers = Some(metadata_builder(&headers));
-
         Ok(channel_builder)
     }
 
     async fn create_client(&self) -> Result {
         let endpoint = format!("https://{}:{}", self.host, self.port);
-
         let channel = Endpoint::from_shared(endpoint)?.connect().await?;
 
         let service_client = SparkConnectServiceClient::with_interceptor(
@@ -413,7 +410,6 @@ where
         self.handle_analyze(resp)
     }
 
-
     fn handle_response(&mut self, resp: spark::ExecutePlanResponse) -> Result<(), SparkError> {
         self.validate_session(&resp.session_id)?;
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 7ea899f..c23001f 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -1,5 +1,7 @@
 //! DataFrame representation for Spark Connection
 
+use std::sync::Arc;
+
 use crate::column::Column;
 use crate::errors::SparkError;
 use crate::expressions::{ToExpr, ToFilterExpr, ToVecExpr};
@@ -66,7 +68,7 @@ use arrow::util::pretty;
 #[derive(Clone, Debug)]
 pub struct DataFrame {
     /// Global [SparkSession] connecting to the remote cluster
-    pub spark_session: SparkSession,
+    pub spark_session: Arc<SparkSession>,
 
     /// Logical Plan representing the unresolved Relation
     /// which will be submitted to the remote cluster
@@ -75,7 +77,7 @@ pub struct DataFrame {
 impl DataFrame {
     /// create default DataFrame based on a spark session and initial logical plan
-    pub fn new(spark_session: SparkSession, logical_plan: LogicalPlanBuilder) -> DataFrame {
+    pub fn new(spark_session: Arc<SparkSession>, logical_plan: LogicalPlanBuilder) -> DataFrame {
         DataFrame {
             spark_session,
             logical_plan,
         }
     }
@@ -658,15 +660,11 @@ impl DataFrame {
             spark::analyze_plan_request::Analyze::Schema(spark::analyze_plan_request::Schema {
                 plan: Some(plan),
             });
-
-        let data_type = self
-            .spark_session
-            .client()
-            .analyze(schema)
-            .await?
-            .schema()?;
-
-        Ok(data_type)
+        let session = self.spark_session.clone();
+        let mut client = session.client();
+        let data_type = client.analyze(schema).await?;
+        let schema = data_type.schema()?;
+        Ok(schema.clone())
     }
 
     /// Projects a set of expressions and returns a new [DataFrame]
@@ -757,7 +755,7 @@
     }
 
     #[allow(non_snake_case)]
-    pub fn sparkSession(self) -> SparkSession {
+    pub fn sparkSession(self) -> Arc<SparkSession> {
         self.spark_session
     }
diff --git a/src/errors.rs b/src/errors.rs
index 9a25c9e..f92cab8 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -79,6 +79,9 @@ impl Display for SparkError {
     }
 }
 
+unsafe impl Send for SparkError {}
+unsafe impl Sync for SparkError {}
+
 impl Error for SparkError {
     fn source(&self) -> Option<&(dyn Error + 'static)> {
         if let Self::ExternalError(e) = self {
diff --git a/src/readwriter.rs b/src/readwriter.rs
index 615f34d..ea70c64 100644
--- a/src/readwriter.rs
+++ b/src/readwriter.rs
@@ -1,6 +1,7 @@
 //! DataFrameReader & DataFrameWriter representations
 
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use crate::errors::SparkError;
 use crate::plan::LogicalPlanBuilder;
@@ -14,14 +15,14 @@ use spark::write_operation::SaveMode;
 /// from a specific file format.
 #[derive(Clone, Debug)]
 pub struct DataFrameReader {
-    spark_session: SparkSession,
+    spark_session: Arc<SparkSession>,
     format: Option<String>,
     read_options: HashMap<String, String>,
 }
 
 impl DataFrameReader {
     /// Create a new DataFrameReader with a [SparkSession]
-    pub fn new(spark_session: SparkSession) -> Self {
+    pub fn new(spark_session: Arc<SparkSession>) -> Self {
         Self {
             spark_session,
             format: None,
diff --git a/src/session.rs b/src/session.rs
index 05cf905..4f1a5c1 100644
--- a/src/session.rs
+++ b/src/session.rs
@@ -1,6 +1,7 @@
 //! Spark Session containing the remote gRPC client
 
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use crate::catalog::Catalog;
 pub use crate::client::SparkSessionBuilder;
@@ -39,7 +40,7 @@ impl SparkSession {
     /// `end` (exclusive) with a step value `step`, and control the number
     /// of partitions with `num_partitions`
     pub fn range(
-        self,
+        self: Arc<SparkSession>,
         start: Option<i64>,
         end: i64,
         step: i64,
@@ -55,29 +56,57 @@ impl SparkSession {
         DataFrame::new(self, LogicalPlanBuilder::from(range_relation))
     }
 
+    pub fn setCatalog(self: Arc<SparkSession>, catalog: &str) -> DataFrame {
+        let catalog_relation = spark::relation::RelType::Catalog(spark::Catalog {
+            cat_type: Some(spark::catalog::CatType::SetCurrentCatalog(
+                spark::SetCurrentCatalog {
+                    catalog_name: catalog.to_string(),
+                },
+            )),
+        });
+
+        let logical_plan = LogicalPlanBuilder::from(catalog_relation);
+
+        DataFrame::new(self, logical_plan)
+    }
+
+    pub fn setDatabase(self: Arc<SparkSession>, database: &str) -> DataFrame {
+        let catalog_relation = spark::relation::RelType::Catalog(spark::Catalog {
+            cat_type: Some(spark::catalog::CatType::SetCurrentDatabase(
+                spark::SetCurrentDatabase {
+                    db_name: database.to_string(),
+                },
+            )),
+        });
+
+        let logical_plan = LogicalPlanBuilder::from(catalog_relation);
+
+        DataFrame::new(self, logical_plan)
+    }
+
     /// Returns a [DataFrameReader] that can be used to read datra in as a [DataFrame]
-    pub fn read(self) -> DataFrameReader {
+    pub fn read(self: Arc<SparkSession>) -> DataFrameReader {
         DataFrameReader::new(self)
     }
 
     /// Returns a [DataFrameReader] that can be used to read datra in as a [DataFrame]
     #[allow(non_snake_case)]
-    pub fn readStream(self) -> DataStreamReader {
+    pub fn readStream(self: Arc<SparkSession>) -> DataStreamReader {
         DataStreamReader::new(self)
     }
 
-    pub fn table(self, name: &str) -> Result<DataFrame, SparkError> {
+    pub fn table(self: Arc<SparkSession>, name: &str) -> Result<DataFrame, SparkError> {
         DataFrameReader::new(self).table(name, None)
     }
 
     /// Interface through which the user may create, drop, alter or query underlying databases,
     /// tables, functions, etc.
-    pub fn catalog(self) -> Catalog {
+    pub fn catalog(self: Arc<SparkSession>) -> Catalog {
         Catalog::new(self)
     }
 
     /// Returns a [DataFrame] representing the result of the given query
-    pub async fn sql(self, sql_query: &str) -> Result<DataFrame, SparkError> {
+    pub async fn sql(self: Arc<SparkSession>, sql_query: &str) -> Result<DataFrame, SparkError> {
         let sql_cmd = spark::command::CommandType::SqlCommand(spark::SqlCommand {
             sql: sql_query.to_string(),
             args: HashMap::default(),
@@ -100,7 +129,7 @@ impl SparkSession {
     }
 
     #[allow(non_snake_case)]
-    pub fn createDataFrame(self, data: &RecordBatch) -> Result<DataFrame, SparkError> {
+    pub fn createDataFrame(self: Arc<SparkSession>, data: &RecordBatch) -> Result<DataFrame, SparkError> {
         let logical_plan = LogicalPlanBuilder::local_relation(data)?;
         Ok(DataFrame::new(self, logical_plan))
     }
@@ -111,7 +140,9 @@ impl SparkSession {
     }
 
     /// Spark Connection gRPC client interface
-    pub fn client(self) -> SparkConnectClient<InterceptedService<Channel, MetadataInterceptor>> {
-        self.client
+    pub fn client(
+        self: Arc<SparkSession>,
+    ) -> SparkConnectClient<InterceptedService<Channel, MetadataInterceptor>> {
+        self.client.clone()
     }
 }
diff --git a/src/streaming/mod.rs b/src/streaming/mod.rs
index 6d66031..3e79dad 100644
--- a/src/streaming/mod.rs
+++ b/src/streaming/mod.rs
@@ -1,6 +1,7 @@
 //! Streaming implementation for the Spark Connect Client
 
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use crate::plan::LogicalPlanBuilder;
 use crate::session::SparkSession;
@@ -13,14 +14,14 @@ use crate::errors::SparkError;
 /// DataStreamReader represents the entrypoint to create a streaming DataFrame
 #[derive(Clone, Debug)]
 pub struct DataStreamReader {
-    spark_session: SparkSession,
+    spark_session: Arc<SparkSession>,
     format: Option<String>,
     schema: Option<String>,
     read_options: HashMap<String, String>,
 }
 
 impl DataStreamReader {
-    pub fn new(spark_session: SparkSession) -> Self {
+    pub fn new(spark_session: Arc<SparkSession>) -> Self {
         Self {
             spark_session,
             format: None,
@@ -238,7 +239,7 @@ impl DataStreamWriter {
             .write_stream_operation_start_result;
 
         Ok(StreamingQuery::new(
-            self.dataframe.spark_session,
+            self.dataframe.spark_session.clone(),
             res.unwrap(),
         ))
     }
@@ -270,7 +271,7 @@ impl DataStreamWriter {
 /// This object is used to control and monitor the active stream
 #[derive(Clone, Debug)]
 pub struct StreamingQuery {
-    spark_session: SparkSession,
+    spark_session: Arc<SparkSession>,
     query_instance: spark::StreamingQueryInstanceId,
     query_id: String,
     run_id: String,
@@ -279,7 +280,7 @@ pub struct StreamingQuery {
 impl StreamingQuery {
     pub fn new(
-        spark_session: SparkSession,
+        spark_session: Arc<SparkSession>,
         write_stream: spark::WriteStreamOperationStartResult,
     ) -> Self {
         let query_instance = write_stream.query_id.unwrap();
@@ -475,15 +476,17 @@ mod tests {
     use crate::errors::SparkError;
     use crate::SparkSessionBuilder;
 
-    async fn setup() -> SparkSession {
+    async fn setup() -> Arc<SparkSession> {
        println!("SparkSession Setup");
 
         let connection = "sc://127.0.0.1:15002/;user_id=rust_stream";
 
-        SparkSessionBuilder::remote(connection)
-            .build()
-            .await
-            .unwrap()
+        Arc::new(
+            SparkSessionBuilder::remote(connection)
+                .build()
+                .await
+                .unwrap(),
+        )
     }
 
     #[tokio::test]
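
Taken together, these changes move the public API from consuming `SparkSession` by value to taking `self: Arc<SparkSession>`, so a single session (and its underlying gRPC client) can back many `DataFrame`s. Below is a minimal sketch of the resulting calling pattern; the connection string, SQL query, and `range` arguments are lifted from the examples in this diff, the sketch assumes only a reachable Spark Connect server, and it is illustrative rather than part of the patch.

```rust
use std::sync::Arc;

use spark_connect_rs::{SparkSession, SparkSessionBuilder};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build the session once and wrap it in an Arc, as the updated examples do.
    let spark: Arc<SparkSession> = Arc::new(
        SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs")
            .build()
            .await?,
    );

    // Each entry point now takes `self: Arc<SparkSession>`, so clone the Arc
    // (a cheap reference-count bump) whenever another DataFrame is needed,
    // instead of re-building or cloning the session itself.
    let employees = spark
        .clone()
        .sql("SELECT * FROM json.`/opt/spark/examples/src/main/resources/employees.json`")
        .await?;

    let numbers = spark.clone().range(None, 10, 1, Some(1));

    // Both DataFrames hold their own Arc to the same shared session.
    let _ = (employees, numbers);

    Ok(())
}
```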