diff --git a/Cargo.toml b/Cargo.toml index 731227a..07b15b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,20 +20,26 @@ sqlite3-parser = { version = "0.8.0", default-features = false, features = [ "YY http = { version = "0.2", optional = true } bytes = { version = "1.4.0", optional = true } anyhow = "1.0.69" -reqwest = { version = "0.11.14", optional = true, default-features = false, features = ["rustls-tls"] } -hrana-client = { version = "0.3", optional = true } -hrana-client-proto = { version = "0.2" } +hyper = { version = "0.14.27", optional = true, default-features = false } +hyper-rustls = { version = "0.24.1", optional = true, features = ["http2"] } +# hrana-client = { version = "0.3", optional = true } +hrana-client = { git = "https://github.com/libsql/hrana-client-rs.git", rev = "1cb37f1", optional = true } +# hrana-client-proto = { version = "0.2" } +hrana-client-proto = { git = "https://github.com/libsql/hrana-client-rs.git", rev = "1cb37f1" } futures-util = { version = "0.3.21", optional = true } serde = "1.0.159" tracing = "0.1.37" futures = "0.3.28" fallible-iterator = "0.2.0" libsql = { version = "0.1.6", optional = true } +tower = { version = "0.4.13", features = ["make"] } +tokio = { version = "1", default-features = false, optional = true } [features] default = ["local_backend", "hrana_backend", "reqwest_backend", "mapping_names_to_values_in_rows"] workers_backend = ["worker", "futures-util"] -reqwest_backend = ["reqwest"] +reqwest_backend = ["hyper_backend"] +hyper_backend = ["hyper", "hyper-rustls", "tokio"] local_backend = ["libsql"] spin_backend = ["spin-sdk", "http", "bytes"] hrana_backend = ["hrana-client"] diff --git a/src/client.rs b/src/client.rs index 76f9743..f846e10 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,9 @@ //! [Client] is the main structure to interact with the database. use anyhow::Result; +use hyper::{client::HttpConnector, Uri}; +use hyper::client::connect::Connection as HyperConnection; +use tokio::io::{AsyncRead, AsyncWrite}; +use tower::{make::MakeConnection, Service}; use crate::{proto, BatchResult, ResultSet, Statement, SyncTransaction, Transaction}; @@ -9,7 +13,7 @@ static TRANSACTION_IDS: std::sync::atomic::AtomicU64 = std::sync::atomic::Atomic /// It's a convenience struct which allows implementing connect() /// with backends being passed as env parameters. #[derive(Debug)] -pub enum Client { +pub enum Client { #[cfg(feature = "local_backend")] Local(crate::local::Client), #[cfg(any( @@ -17,7 +21,7 @@ pub enum Client { feature = "workers_backend", feature = "spin_backend" ))] - Http(crate::http::Client), + Http(crate::http::Client), #[cfg(feature = "hrana_backend")] Hrana(crate::hrana::Client), Default, @@ -29,7 +33,7 @@ pub struct SyncClient { inner: Client, } -unsafe impl Send for Client {} +unsafe impl Send for Client {} impl Client { /// Executes a batch of independent SQL statements. @@ -57,7 +61,7 @@ impl Client { ) -> Result { match self { #[cfg(feature = "local_backend")] - Self::Local(l) => l.raw_batch(stmts), + Self::Local(l) => l.raw_batch(stmts).await, #[cfg(any( feature = "reqwest_backend", feature = "workers_backend", @@ -179,7 +183,7 @@ impl Client { pub async fn execute(&self, stmt: impl Into + Send) -> Result { match self { #[cfg(feature = "local_backend")] - Self::Local(l) => l.execute(stmt), + Self::Local(l) => l.execute(stmt).await, #[cfg(any( feature = "reqwest_backend", feature = "workers_backend", @@ -218,7 +222,7 @@ impl Client { ) -> Result { match self { #[cfg(feature = "local_backend")] - Self::Local(l) => l.execute_in_transaction(tx_id, stmt), + Self::Local(l) => l.execute_in_transaction(tx_id, stmt).await, #[cfg(any( feature = "reqwest_backend", feature = "workers_backend", @@ -235,7 +239,7 @@ impl Client { pub(crate) async fn commit_transaction(&self, tx_id: u64) -> Result<()> { match self { #[cfg(feature = "local_backend")] - Self::Local(l) => l.commit_transaction(tx_id), + Self::Local(l) => l.commit_transaction(tx_id).await, #[cfg(any( feature = "reqwest_backend", feature = "workers_backend", @@ -252,7 +256,7 @@ impl Client { pub(crate) async fn rollback_transaction(&self, tx_id: u64) -> Result<()> { match self { #[cfg(feature = "local_backend")] - Self::Local(l) => l.rollback_transaction(tx_id), + Self::Local(l) => l.rollback_transaction(tx_id).await, #[cfg(any( feature = "reqwest_backend", feature = "workers_backend", @@ -267,38 +271,17 @@ impl Client { } } -impl Client { - /// Creates an in-memory database - /// - /// # Examples - /// - /// ``` - /// # async fn f() { - /// # use libsql_client::Config; - /// let db = libsql_client::Client::in_memory().unwrap(); - /// # } - /// ``` - #[cfg(feature = "local_backend")] - pub fn in_memory() -> anyhow::Result { - Ok(Client::Local(crate::local::Client::in_memory()?)) - } - - /// Establishes a database client based on [Config] struct - /// - /// # Examples - /// - /// ``` - /// # async fn f() { - /// # use libsql_client::Config; - /// let config = Config { - /// url: url::Url::parse("file:////tmp/example.db").unwrap(), - /// auth_token: None - /// }; - /// let db = libsql_client::Client::from_config(config).await.unwrap(); - /// # } - /// ``` - #[allow(unreachable_patterns)] - pub async fn from_config<'a>(mut config: Config) -> anyhow::Result { +impl Client +where + C: Service + Send + Clone + Sync + 'static, + C::Response: HyperConnection + AsyncRead + AsyncWrite + Send + Unpin + 'static, + C::Future: Send + 'static, + C::Error: Into>, +{ + pub async fn from_config_with_connector(mut config: Config, connector: C) -> anyhow::Result> + where + C: MakeConnection, + { config.url = if config.url.scheme() == "libsql" { // We cannot use url::Url::set_scheme() because it prevents changing the scheme to http... // Safe to unwrap, because we know that the scheme is libsql @@ -318,7 +301,7 @@ impl Client { }, #[cfg(feature = "reqwest_backend")] "http" | "https" => { - let inner = crate::http::InnerClient::Reqwest(crate::reqwest::HttpClient::new()); + let inner = crate::http::InnerClient::Reqwest(crate::hyper::HttpClient::with_connector(connector)); Client::Http(crate::http::Client::from_config(inner, config)?) }, #[cfg(feature = "workers_backend")] @@ -335,6 +318,44 @@ impl Client { }) } +} + +impl Client { + /// Creates an in-memory database + /// + /// # Examples + /// + /// ``` + /// # async fn f() { + /// # use libsql_client::Config; + /// let db = libsql_client::Client::in_memory().unwrap(); + /// # } + /// ``` + #[cfg(feature = "local_backend")] + pub fn in_memory() -> anyhow::Result { + Ok(Client::Local(crate::local::Client::in_memory()?)) + } + + /// Establishes a database client based on [Config] struct + /// + /// # Examples + /// + /// ``` + /// # async fn f() { + /// # use libsql_client::Config; + /// let config = Config { + /// url: url::Url::parse("file:////tmp/example.db").unwrap(), + /// auth_token: None + /// }; + /// let db = libsql_client::Client::from_config(config).await.unwrap(); + /// # } + /// ``` + #[allow(unreachable_patterns)] + pub async fn from_config(config: Config) -> anyhow::Result { + let connector = HttpConnector::new(); + Self::from_config_with_connector(config, connector).await + } + /// Establishes a database client based on environment variables /// /// # Env diff --git a/src/hrana.rs b/src/hrana.rs index 11f3169..0aad30f 100644 --- a/src/hrana.rs +++ b/src/hrana.rs @@ -1,5 +1,10 @@ use crate::client::Config; use anyhow::Result; +use hyper::Uri; +use hyper::client::HttpConnector; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tower::Service; use std::collections::HashMap; use std::sync::Arc; use std::sync::RwLock; @@ -8,13 +13,14 @@ use crate::{utils, BatchResult, ResultSet, Statement}; /// Database client. This is the main structure used to /// communicate with the database. -pub struct Client { +pub struct Client { url: String, token: Option, client: hrana_client::Client, client_future: hrana_client::ConnFut, streams_for_transactions: RwLock>>, + connector: C, } impl std::fmt::Debug for Client { @@ -26,18 +32,22 @@ impl std::fmt::Debug for Client { } } -impl Client { - /// Creates a database client with JWT authentication. - /// - /// # Arguments - /// * `url` - URL of the database endpoint - /// * `token` - auth token - pub async fn new(url: impl Into, token: impl Into) -> Result { +impl Client +where + C: Service + Send + Clone + Sync + 'static, + C::Response: hyper::client::connect::Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static, + C::Future: Send + 'static, + C::Error: std::error::Error + Sync + Send + 'static, +{ + /// Same as `new`, but uses `connector` to create connections. + pub async fn new_with_connector(url: impl Into, token: impl Into, connector: C) -> Result + + { let token = token.into(); let token = if token.is_empty() { None } else { Some(token) }; let url = url.into(); - let (client, client_future) = hrana_client::Client::connect(&url, token.clone()).await?; + let (client, client_future) = hrana_client::Client::with_connector(&url, token.clone(), connector.clone()).await?; Ok(Self { url, @@ -45,16 +55,29 @@ impl Client { client, client_future, streams_for_transactions: RwLock::new(HashMap::new()), + connector, }) } pub async fn reconnect(&mut self) -> Result<()> { let (client, client_future) = - hrana_client::Client::connect(&self.url, self.token.clone()).await?; + hrana_client::Client::with_connector(&self.url, self.token.clone(), self.connector.clone()).await?; self.client = client; self.client_future = client_future; Ok(()) } +} + +impl Client { + /// Creates a database client with JWT authentication. + /// + /// # Arguments + /// * `url` - URL of the database endpoint + /// * `token` - auth token + pub async fn new(url: impl Into, token: impl Into) -> Result { + let connector = HttpConnector::new(); + Self::new_with_connector(url, token, connector).await + } /// Creates a database client, given a `Url` /// diff --git a/src/http.rs b/src/http.rs index f984474..5c8821a 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,5 +1,9 @@ use crate::client::Config; use anyhow::Result; +use hyper::Uri; +use hyper::client::connect::Connection as HyperConnection; +use tokio::io::{AsyncRead, AsyncWrite}; +use tower::Service; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -16,17 +20,17 @@ struct Cookie { /// Generic HTTP client. Needs a helper function that actually sends /// the request. #[derive(Clone, Debug)] -pub struct Client { - inner: InnerClient, +pub struct Client { + inner: InnerClient, cookies: Arc>>, url_for_queries: String, auth: String, } #[derive(Clone, Debug)] -pub enum InnerClient { +pub enum InnerClient { #[cfg(feature = "reqwest_backend")] - Reqwest(crate::reqwest::HttpClient), + Reqwest(crate::hyper::HttpClient), #[cfg(feature = "workers_backend")] Workers(crate::workers::HttpClient), #[cfg(feature = "spin_backend")] @@ -34,7 +38,13 @@ pub enum InnerClient { Default, } -impl InnerClient { +impl InnerClient +where + C: Service + Send + Clone + Sync + 'static, + C::Response: HyperConnection + AsyncRead + AsyncWrite + Send + Unpin + 'static, + C::Future: Send + 'static, + C::Error: Into>, +{ pub async fn send( &self, url: String, @@ -53,13 +63,13 @@ impl InnerClient { } } -impl Client { +impl Client { /// Creates a database client with JWT authentication. /// /// # Arguments /// * `url` - URL of the database endpoint /// * `token` - auth token - pub fn new(inner: InnerClient, url: impl Into, token: impl Into) -> Self { + pub fn new(inner: InnerClient, url: impl Into, token: impl Into) -> Self { let token = token.into(); let url = url.into(); // Auto-update the URL to start with https:// if no protocol was specified @@ -78,7 +88,7 @@ impl Client { } /// Establishes a database client from a `Config` object - pub fn from_config(inner: InnerClient, config: Config) -> anyhow::Result { + pub fn from_config(inner: InnerClient, config: Config) -> anyhow::Result { Ok(Self::new( inner, config.url, @@ -86,7 +96,7 @@ impl Client { )) } - pub fn from_env(inner: InnerClient) -> anyhow::Result { + pub fn from_env(inner: InnerClient) -> anyhow::Result { let url = std::env::var("LIBSQL_CLIENT_URL").map_err(|_| { anyhow::anyhow!("LIBSQL_CLIENT_URL variable should point to your sqld database") })?; @@ -96,7 +106,13 @@ impl Client { } } -impl Client { +impl Client +where + C: Service + Send + Clone + Sync + 'static, + C::Response: HyperConnection + AsyncRead + AsyncWrite + Send + Unpin + 'static, + C::Future: Send + 'static, + C::Error: Into>, +{ fn into_hrana(stmt: Statement) -> crate::proto::Stmt { let mut hrana_stmt = crate::proto::Stmt::new(stmt.sql, true); for param in stmt.args { diff --git a/src/hyper.rs b/src/hyper.rs new file mode 100644 index 0000000..ee0fdf5 --- /dev/null +++ b/src/hyper.rs @@ -0,0 +1,79 @@ +use anyhow::Result; +use hyper::{Request, Body, StatusCode, Uri, }; +use hyper::body::{to_bytes, HttpBody}; +use hyper::client::{HttpConnector, connect::Connection}; +use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tower::Service; + +use crate::proto::pipeline; + +#[derive(Clone, Debug)] +pub struct HttpClient { + inner: hyper::client::Client>, +} + +pub async fn to_text(body: T) -> anyhow::Result +where + T: HttpBody, + T::Error: std::error::Error + Sync + Send + 'static, +{ + let bytes = to_bytes(body).await?; + Ok(String::from_utf8(bytes.to_vec())?) +} + +impl HttpClient { + pub fn new() -> Self { + let connector = HttpConnector::new(); + Self::with_connector(connector) + } +} + +impl HttpClient +where + C: Service + Send + Clone + Sync + 'static, + C::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static, + C::Future: Send + 'static, + C::Error: Into>, +{ + /// Creates an HttpClient using the provided connector. + pub fn with_connector(connector: C) -> Self { + let connector = HttpsConnectorBuilder::new() + .with_native_roots() + .https_or_http() + .enable_http2() + .wrap_connector(connector); + + let builder = hyper::client::Client::builder(); + let inner = builder.build(connector); + + Self { inner } + } + + pub async fn send( + &self, + url: String, + auth: String, + body: String, + ) -> Result { + let request = Request::post(url) + .header("Authorization", auth) + .body(Body::from(body))?; + + let response = self.inner.request(request).await?; + if response.status() != StatusCode::OK { + let status = response.status(); + let txt = to_text(response.into_body()).await?; + anyhow::bail!("{status}: {txt}"); + } + let resp = to_text(response.into_body()).await?; + let response: pipeline::ServerMsg = serde_json::from_str(&resp)?; + Ok(response) + } +} + +impl Default for HttpClient { + fn default() -> Self { + Self::new() + } +} diff --git a/src/lib.rs b/src/lib.rs index 087acb8..097e32e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,7 +167,7 @@ pub use transaction::{SyncTransaction, Transaction}; pub mod workers; #[cfg(feature = "reqwest_backend")] -pub mod reqwest; +pub mod hyper; #[cfg(feature = "local_backend")] pub mod local; diff --git a/src/local.rs b/src/local.rs index 7298c5a..7baecad 100644 --- a/src/local.rs +++ b/src/local.rs @@ -93,7 +93,7 @@ impl Client { /// .raw_batch(["CREATE TABLE t(id)", "INSERT INTO t VALUES (42)"]); /// # } /// ``` - pub fn raw_batch( + pub async fn raw_batch( &self, stmts: impl IntoIterator>, ) -> anyhow::Result { @@ -109,7 +109,7 @@ impl Client { .map(libsql::Value::from) .collect::>() .into(); - let stmt = self.conn.prepare(sql_string)?; + let stmt = self.conn.prepare(sql_string).await?; let cols: Vec = stmt .columns() .into_iter() @@ -118,7 +118,7 @@ impl Client { }) .collect(); let mut rows = Vec::new(); - let input_rows = match stmt.query(¶ms) { + let mut input_rows = match stmt.query(¶ms).await { Ok(rows) => rows, Err(e) => { step_results.push(None); @@ -172,7 +172,7 @@ impl Client { /// /// # Arguments /// * `stmts` - SQL statements - pub fn batch( + pub async fn batch( &self, stmts: impl IntoIterator + Send> + Send, ) -> Result> { @@ -180,7 +180,7 @@ impl Client { std::iter::once(Statement::new("BEGIN")) .chain(stmts.into_iter().map(|s| s.into())) .chain(std::iter::once(Statement::new("END"))), - )?; + ).await?; let step_error: Option = batch_results .step_errors .into_iter() @@ -207,8 +207,8 @@ impl Client { /// # Arguments /// * `stmt` - the SQL statement - pub fn execute(&self, stmt: impl Into + Send) -> Result { - let results = self.raw_batch(std::iter::once(stmt))?; + pub async fn execute(&self, stmt: impl Into + Send) -> Result { + let results = self.raw_batch(std::iter::once(stmt)).await?; match (results.step_results.first(), results.step_errors.first()) { (Some(Some(result)), Some(None)) => Ok(ResultSet::from(result.clone())), (Some(None), Some(Some(err))) => Err(anyhow::anyhow!(err.message.clone())), @@ -216,15 +216,15 @@ impl Client { } } - pub fn execute_in_transaction(&self, _tx_id: u64, stmt: Statement) -> Result { - self.execute(stmt) + pub async fn execute_in_transaction(&self, _tx_id: u64, stmt: Statement) -> Result { + self.execute(stmt).await } - pub fn commit_transaction(&self, _tx_id: u64) -> Result<()> { - self.execute("COMMIT").map(|_| ()) + pub async fn commit_transaction(&self, _tx_id: u64) -> Result<()> { + self.execute("COMMIT").await.map(|_| ()) } - pub fn rollback_transaction(&self, _tx_id: u64) -> Result<()> { - self.execute("ROLLBACK").map(|_| ()) + pub async fn rollback_transaction(&self, _tx_id: u64) -> Result<()> { + self.execute("ROLLBACK").await.map(|_| ()) } } diff --git a/src/reqwest.rs b/src/reqwest.rs deleted file mode 100644 index 988c1eb..0000000 --- a/src/reqwest.rs +++ /dev/null @@ -1,45 +0,0 @@ -use anyhow::Result; - -use crate::proto::pipeline; - -#[derive(Clone, Debug)] -pub struct HttpClient { - inner: reqwest::Client, -} - -impl HttpClient { - pub fn new() -> Self { - Self { - inner: reqwest::Client::new(), - } - } - - pub async fn send( - &self, - url: String, - auth: String, - body: String, - ) -> Result { - let response = self - .inner - .post(url) - .body(body) - .header("Authorization", auth) - .send() - .await?; - if response.status() != reqwest::StatusCode::OK { - let status = response.status(); - let txt = response.text().await.unwrap_or_default(); - anyhow::bail!("{status}: {txt}"); - } - let resp: String = response.text().await?; - let response: pipeline::ServerMsg = serde_json::from_str(&resp)?; - Ok(response) - } -} - -impl Default for HttpClient { - fn default() -> Self { - Self::new() - } -}