Skip to content

Commit f199084

Browse files
Add ClientBuilder and support for connecting to local spice runtime (#38)
* Add `ClientBuilder` and add support for local spice connection * Update ci tests * cleanup * add docs comments * update docs * Update src/client.rs Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech> * Update src/flight.rs Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech> * Update src/client.rs Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech> * Update src/client.rs Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech> * Update src/client.rs Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech> * Update src/client.rs Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech> * update comments * update builder api usage * add builder method * fix --------- Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech>
1 parent ef59c61 commit f199084

9 files changed

Lines changed: 224 additions & 49 deletions

File tree

.github/workflows/build.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@ jobs:
4848
- name: Build
4949
run: cargo build --verbose
5050

51+
- name: install Spice (https://install.spiceai.org)
52+
run: |
53+
curl https://install.spiceai.org | /bin/bash
54+
echo "$HOME/.spice/bin" >> $GITHUB_PATH
55+
56+
- name: Init and start spice app
57+
run: |
58+
spice init spice_qs
59+
cd spice_qs
60+
spice add spiceai/quickstart
61+
spice run &> spice.log &
62+
# time to initialize added dataset
63+
sleep 10
64+
5165
- name: Run tests
5266
run: cargo test --verbose
5367
env:

README.md

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,38 @@ cargo add spiceai
1414

1515
<!-- NOTE: If you're changing the code examples below, make sure you update `tests/readme_test.rs`. -->
1616

17-
### New client
17+
### Usage with locally running [spice runtime](https://github.com/spiceai/spiceai)
18+
19+
Follow the [quiqstart guide](https://github.com/spiceai/spiceai?tab=readme-ov-file#%EF%B8%8F-quickstart-local-machine) to install and run spice locally
20+
21+
```rust
22+
use spiceai::ClientBuilder;
23+
24+
#[tokio::main]
25+
async fn main() {
26+
let mut client = ClientBuilder::new()
27+
.flight_url("http://localhost:50051")
28+
.build()
29+
.await
30+
.unwrap();
31+
32+
let data = client.query("SELECT trip_distance, total_amount FROM taxi_trips ORDER BY trip_distance DESC LIMIT 10;").await;
33+
}
34+
```
35+
36+
### New client with https://spice.ai cloud
1837

1938
```rust
20-
use spiceai::Client;
39+
use spiceai::ClientBuilder;
2140

2241
#[tokio::main]
2342
async fn main() {
24-
let mut client = Client::new("API_KEY").await.unwrap();
43+
let mut client = ClientBuilder::new()
44+
.api_key("API_KEY")
45+
.use_spiceai_cloud()
46+
.build()
47+
.await
48+
.unwrap();
2549
}
2650
```
2751

@@ -30,11 +54,17 @@ async fn main() {
3054
SQL Query
3155

3256
```rust
33-
use spiceai::Client;
57+
use spiceai::ClientBuilder;
3458

3559
#[tokio::main]
3660
async fn main() {
37-
let mut client = Client::new("API_KEY").await.unwrap();
61+
let mut client = ClientBuilder::new()
62+
.api_key("API_KEY")
63+
.use_spiceai_cloud()
64+
.build()
65+
.await
66+
.unwrap();
67+
3868
let data = client.query("SELECT * FROM eth.recent_blocks LIMIT 10;").await;
3969
}
4070

@@ -45,11 +75,17 @@ async fn main() {
4575
Firecache SQL Query
4676

4777
```rust
48-
use spiceai::Client;
78+
use spiceai::ClientBuilder;
4979

5080
#[tokio::main]
5181
async fn main() {
52-
let mut client = Client::new("API_KEY").await.unwrap();
82+
let mut client = ClientBuilder::new()
83+
.api_key("API_KEY")
84+
.use_spiceai_cloud()
85+
.build()
86+
.await
87+
.unwrap();
88+
5389
let data = client.fire_query("SELECT * FROM eth.recent_blocks LIMIT 10;").await;
5490
}
5591

src/client.rs

Lines changed: 108 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::{
2-
config::{FIRECACHE_ADDR, FLIGHT_ADDR, HTTPS_ADDR},
2+
config::{SPICE_CLOUD_FIRECACHE_ADDR, SPICE_CLOUD_FLIGHT_ADDR, SPICE_LOCAL_FLIGHT_ADDR},
33
flight::SqlFlightClient,
44
tls::new_tls_flight_channel,
55
};
@@ -9,31 +9,25 @@ use std::error::Error;
99
use tonic::transport::Channel;
1010

1111
struct SpiceClientConfig {
12-
https_addr: String,
1312
flight_channel: Channel,
1413
firecache_channel: Channel,
1514
}
1615

1716
impl SpiceClientConfig {
18-
fn new(https_addr: String, flight_channel: Channel, firecache_channel: Channel) -> Self {
17+
fn new(flight_channel: Channel, firecache_channel: Channel) -> Self {
1918
SpiceClientConfig {
20-
https_addr,
2119
flight_channel,
2220
firecache_channel,
2321
}
2422
}
2523

2624
pub async fn load_from_default() -> Result<SpiceClientConfig, Box<dyn Error>> {
2725
let (flight_chan, firecache_chan) = try_join!(
28-
new_tls_flight_channel(FLIGHT_ADDR),
29-
new_tls_flight_channel(FIRECACHE_ADDR)
26+
new_tls_flight_channel(SPICE_CLOUD_FLIGHT_ADDR),
27+
new_tls_flight_channel(SPICE_CLOUD_FIRECACHE_ADDR)
3028
)?;
3129

32-
Ok(SpiceClientConfig::new(
33-
HTTPS_ADDR.to_string(),
34-
flight_chan,
35-
firecache_chan,
36-
))
30+
Ok(SpiceClientConfig::new(flight_chan, firecache_chan))
3731
}
3832
}
3933

@@ -59,11 +53,15 @@ impl SpiceClient {
5953
let config = SpiceClientConfig::load_from_default().await?;
6054

6155
Ok(Self {
62-
flight: SqlFlightClient::new(config.flight_channel, api_key.to_string()),
63-
firecache: SqlFlightClient::new(config.firecache_channel, api_key.to_string()),
56+
flight: SqlFlightClient::new(config.flight_channel, Some(api_key.to_string())),
57+
firecache: SqlFlightClient::new(config.firecache_channel, Some(api_key.to_string())),
6458
})
6559
}
6660

61+
pub fn builder() -> SpiceClientBuilder {
62+
SpiceClientBuilder::new()
63+
}
64+
6765
/// Queries the Spice Flight endpoint with the given SQL query.
6866
/// ```
6967
/// # use spiceai::Client;
@@ -95,3 +93,100 @@ impl SpiceClient {
9593
self.firecache.query(query).await
9694
}
9795
}
96+
97+
/// Builder for creating a `SpiceClient`.
98+
///
99+
/// By default the `SpiceClient` will use local spice runtime flight endpoint.
100+
/// Follow [spiceai quickstart](https://github.com/spiceai/spiceai?tab=readme-ov-file#%EF%B8%8F-quickstart-local-machine) to setup local spice runtime.
101+
/// ```
102+
/// # use spiceai::ClientBuilder;
103+
/// #
104+
/// # #[tokio::main]
105+
/// # async fn main() {
106+
/// # let mut client = ClientBuilder::new()
107+
/// # .build()
108+
/// # .await
109+
/// # .unwrap();
110+
/// # }
111+
/// ```
112+
/// To use default Spice.ai Cloud endpoints, you can use the `with_spiceai_cloud()` method.
113+
///
114+
/// ```
115+
/// # use spiceai::ClientBuilder;
116+
/// #
117+
/// # #[tokio::main]
118+
/// # async fn main() {
119+
/// # let mut client = ClientBuilder::new()
120+
/// # .api_key("API_KEY")
121+
/// # .use_spiceai_cloud()
122+
/// # .build()
123+
/// # .await
124+
/// # .unwrap();
125+
/// # }
126+
/// ```
127+
///
128+
pub struct SpiceClientBuilder {
129+
api_key: Option<String>,
130+
firecache_url: Option<String>,
131+
flight_url: Option<String>,
132+
}
133+
134+
impl Default for SpiceClientBuilder {
135+
fn default() -> Self {
136+
Self::new()
137+
}
138+
}
139+
140+
impl SpiceClientBuilder {
141+
pub fn new() -> Self {
142+
Self {
143+
api_key: None,
144+
firecache_url: None,
145+
flight_url: None,
146+
}
147+
}
148+
149+
/// Configures the `SpiceClient` to use the given API key.
150+
pub fn api_key(mut self, api_key: &str) -> Self {
151+
self.api_key = Some(api_key.to_string());
152+
self
153+
}
154+
155+
/// Configures the `SpiceClient` to use the given Spice Firecache endpoint.
156+
pub fn firecache_url(mut self, firecache_url: &str) -> Self {
157+
self.firecache_url = Some(firecache_url.to_string());
158+
self
159+
}
160+
161+
/// Configures the `SpiceClient` to use the given Spice Flight endpoint.
162+
pub fn flight_url(mut self, flight_url: &str) -> Self {
163+
self.flight_url = Some(flight_url.to_string());
164+
self
165+
}
166+
167+
/// Configures the `SpiceClient` to use default Spice.ai Cloud endpoints.
168+
/// Equivalent to calling `.firecache_url("https://firecache.spiceai.io")` and `.flight_url("https://flight.spiceai.io")`.
169+
pub fn use_spiceai_cloud(mut self) -> Self {
170+
self.flight_url = Some(SPICE_CLOUD_FLIGHT_ADDR.to_string());
171+
self.firecache_url = Some(SPICE_CLOUD_FIRECACHE_ADDR.to_string());
172+
self
173+
}
174+
175+
/// Builds the `SpiceClient` with the specified configuration.
176+
pub async fn build(self) -> Result<SpiceClient, Box<dyn Error>> {
177+
let flight_channel = match self.flight_url {
178+
Some(url) => new_tls_flight_channel(&url).await?,
179+
None => new_tls_flight_channel(SPICE_LOCAL_FLIGHT_ADDR).await?,
180+
};
181+
182+
let firecache_channel = match self.firecache_url {
183+
Some(url) => new_tls_flight_channel(&url).await?,
184+
None => new_tls_flight_channel(SPICE_CLOUD_FIRECACHE_ADDR).await?,
185+
};
186+
187+
Ok(SpiceClient {
188+
flight: SqlFlightClient::new(flight_channel, self.api_key.clone()),
189+
firecache: SqlFlightClient::new(firecache_channel, self.api_key.clone()),
190+
})
191+
}
192+
}

src/config.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
pub const HTTPS_ADDR: &str = "https://data.spiceai.io";
2-
pub const FLIGHT_ADDR: &str = "https://flight.spiceai.io";
3-
pub const FIRECACHE_ADDR: &str = "https://firecache.spiceai.io";
1+
pub const SPICE_CLOUD_FLIGHT_ADDR: &str = "https://flight.spiceai.io";
2+
pub const SPICE_CLOUD_FIRECACHE_ADDR: &str = "https://firecache.spiceai.io";
3+
4+
// default address for local spice runtime
5+
pub const SPICE_LOCAL_FLIGHT_ADDR: &str = "http://localhost:50051";

src/flight.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ pub struct SqlFlightClient {
2121
token: Option<String>,
2222
headers: HashMap<String, String>,
2323
client: FlightServiceClient<Channel>,
24-
api_key: String,
24+
api_key: Option<String>,
2525
}
2626

2727
fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
2828
ArrowError::IpcError(format!("{status:?}"))
2929
}
3030

3131
impl SqlFlightClient {
32-
pub fn new(chan: Channel, api_key: String) -> Self {
32+
pub fn new(chan: Channel, api_key: Option<String>) -> Self {
3333
SqlFlightClient {
3434
api_key,
3535
client: FlightServiceClient::new(chan),
@@ -81,11 +81,11 @@ impl SqlFlightClient {
8181
Ok(resp)
8282
}
8383

84-
async fn authenticate(&mut self) -> std::result::Result<(), Box<dyn Error>> {
85-
if self.api_key.split('|').collect::<String>().len() < 2 {
84+
async fn authenticate(&mut self, api_key: &str) -> std::result::Result<(), Box<dyn Error>> {
85+
if api_key.split('|').collect::<String>().len() < 2 {
8686
return Err("Invalid API key format".into());
8787
}
88-
self.handshake("", &self.api_key.to_string()).await?;
88+
self.handshake("", api_key).await?;
8989
Ok(())
9090
}
9191

@@ -115,7 +115,10 @@ impl SqlFlightClient {
115115
&mut self,
116116
query: &str,
117117
) -> std::result::Result<FlightRecordBatchStream, Box<dyn Error>> {
118-
self.authenticate().await?;
118+
let api_key = self.api_key.clone();
119+
if let Some(api_key) = api_key {
120+
self.authenticate(&api_key).await?;
121+
}
119122

120123
let descriptor = FlightDescriptor::new_cmd(query.to_string());
121124
let req = self.set_request_headers(descriptor.into_request())?;

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod flight;
66
mod tls;
77

88
pub use client::SpiceClient as Client;
9+
pub use client::SpiceClientBuilder as ClientBuilder;
910

1011
// Further public exports and integrations
1112
pub use futures::StreamExt;

tests/client_test.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
#[cfg(test)]
22
mod tests {
33
use futures::stream::StreamExt;
4-
use spiceai::Client;
4+
use spiceai::{Client, ClientBuilder};
55
use std::env;
66
use std::path::Path;
77

88
async fn new_client() -> Client {
99
dotenv::from_path(Path::new(".env.local")).ok();
1010
let api_key = env::var("API_KEY").expect("API_KEY not found");
11-
Client::new(&api_key)
11+
ClientBuilder::new()
12+
.api_key(&api_key)
13+
.use_spiceai_cloud()
14+
.build()
1215
.await
1316
.expect("Failed to create client")
1417
}
@@ -18,6 +21,11 @@ mod tests {
1821
new_client().await;
1922
}
2023

24+
#[tokio::test]
25+
async fn test_new_client_builder() {
26+
new_client().await;
27+
}
28+
2129
#[tokio::test]
2230
async fn test_query() {
2331
let mut spice_client = new_client().await;

tests/price_test.rs

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)