Skip to content

Commit 91c5823

Browse files
authored
Merge pull request #3 from spiceai/qianqian/v1
Update prices functions
2 parents a1dd5ab + f268283 commit 91c5823

12 files changed

Lines changed: 271 additions & 208 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@ Cargo.lock
1212

1313
# MSVC Windows builds of rustc generate these, which store debugging information
1414
*.pdb
15+
16+
# Store local variables for test
17+
.env.local

Cargo.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@ version = "0.1.0"
44
edition = "2021"
55

66
[dependencies]
7-
arrow-flight = { version = "47.0.0", features = ["flight-sql-experimental"] }
7+
arrow-flight = { version = "48.0.0", features = ["flight-sql-experimental"] }
88
bytes = "1.5.0"
99
prost = "0.12.1"
1010
prost-types = "0.12.1"
1111
rustls = "0.21.7"
1212
tokio = "1.32.0"
1313
rustls-native-certs = "0.6.3"
14-
tonic = { version = "0.10.0", default-features = false, features = ["transport", "tls"] }
14+
tonic = { version = "0.10.0", default-features = false, features = ["transport", "tls", "tls-roots"] }
1515
rustls-pemfile = "1.0.3"
1616
reqwest = { version = "0.11.21", features = ["json"] }
1717
serde = "1.0.188"
1818
serde_derive = "1.0.188"
1919
serde_json = "1.0.107"
2020
chrono = { version = "0.4.31", features = ["serde"] }
21-
21+
dotenv = "0.15.0"
22+
arrow = "48.0.0"
23+
futures = "0.3.28"

src/client.rs

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,45 @@
1-
2-
use crate::{flight::{SqlFlightClient}, prices::PricesClient, tls::new_tls_flight_channel};
1+
use crate::{flight::SqlFlightClient, prices::PricesClient, tls::new_tls_flight_channel};
2+
use arrow_flight::decode::FlightRecordBatchStream;
33
use std::error::Error;
4-
use arrow_flight::sql::client::FlightSqlServiceClient;
5-
use tonic::{transport::{Channel}};
6-
use tonic::{Streaming};
7-
use arrow_flight::{FlightData};
8-
9-
10-
pub async fn new_spice_client(api_key: String, http_addr: String, flight_addr: String, firecache_addr: String) -> Result<SpiceClient, Box<dyn Error>> {
11-
let flight_chan = new_tls_flight_channel(flight_addr).await;
12-
if flight_chan.is_err() {
13-
return Err(flight_chan.err().expect("").into())
14-
}
4+
use tonic::transport::Channel;
5+
6+
pub async fn new_spice_client(api_key: String) -> Result<SpiceClient, Box<dyn Error>> {
7+
return new_spice_client_with_address(
8+
api_key.to_string(),
9+
"https://data.spiceai.io".to_string(),
10+
"https://flight.spiceai.io".to_string(),
11+
)
12+
.await;
13+
}
1514

16-
match new_tls_flight_channel(firecache_addr).await {
15+
pub async fn new_spice_client_with_address(
16+
api_key: String,
17+
http_addr: String,
18+
flight_addr: String,
19+
) -> Result<SpiceClient, Box<dyn Error>> {
20+
match new_tls_flight_channel(flight_addr).await {
1721
Err(e) => Err(e.into()),
18-
Ok(firecache_chan) => Ok(SpiceClient::new(
19-
http_addr,
20-
api_key,
21-
flight_chan.expect(""),
22-
firecache_chan
23-
))
22+
Ok(flight_chan) => Ok(SpiceClient::new(http_addr, api_key, flight_chan)),
2423
}
2524
}
2625

2726
pub struct SpiceClient {
2827
flight: SqlFlightClient,
29-
firecache: SqlFlightClient,
30-
pub prices: PricesClient
28+
pub prices: PricesClient,
3129
}
3230

3331
impl SpiceClient {
34-
pub fn new(http_addr: String, api_key: String, flight: Channel, firecache: Channel) -> Self {
32+
pub fn new(http_addr: String, api_key: String, flight: Channel) -> Self {
3533
Self {
3634
flight: SqlFlightClient::new(flight, api_key.clone()),
37-
firecache: SqlFlightClient::new(firecache, api_key.clone()),
38-
prices: PricesClient::new(Some(http_addr), api_key)
35+
prices: PricesClient::new(Some(http_addr), api_key),
3936
}
4037
}
41-
pub async fn query(&mut self, query: String, timeout: Option<u32>) -> Result<Streaming<FlightData>, Box<dyn Error>> {
38+
pub async fn query(
39+
&mut self,
40+
query: String,
41+
timeout: Option<u32>,
42+
) -> Result<FlightRecordBatchStream, Box<dyn Error>> {
4243
self.flight.query(query, timeout).await
4344
}
44-
45-
pub async fn firecache_query(&mut self, query: String, timeout: Option<u32>) -> Result<Streaming<FlightData>, Box<dyn Error>> {
46-
self.firecache.query(query, timeout).await
47-
}
48-
49-
50-
}
45+
}

src/flight.rs

Lines changed: 19 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1+
use arrow_flight::decode::FlightRecordBatchStream;
2+
use arrow_flight::sql::client::FlightSqlServiceClient;
13
use std::error::Error;
2-
use prost::Message;
3-
use tonic::{transport::Channel, Streaming, Request};
4-
use tonic::transport::channel::{Endpoint, ClientTlsConfig,};
5-
use arrow_flight::sql::{client::FlightSqlServiceClient, CommandStatementQuery, ProstMessageExt};
6-
use arrow_flight::FlightData;
7-
use arrow_flight::FlightDescriptor;
4+
use tonic::transport::Channel;
85

96
pub struct SqlFlightClient {
107
client: FlightSqlServiceClient<Channel>,
@@ -15,55 +12,39 @@ impl SqlFlightClient {
1512
pub fn new(chan: Channel, api_key: String) -> Self {
1613
SqlFlightClient {
1714
api_key: api_key,
18-
client: FlightSqlServiceClient::new(chan)
15+
client: FlightSqlServiceClient::new(chan),
1916
}
2017
}
2118

22-
pub async fn authenticate(&mut self) -> Result<(), Box<dyn Error>> {
23-
let parts: Vec<&str> = self.api_key.split("|").collect();
24-
if parts.len() < 2 {
19+
pub async fn authenticate(&mut self) -> std::result::Result<(), Box<dyn Error>> {
20+
if self.api_key.split("|").collect::<String>().len() < 2 {
2521
return Err("Invalid API key format".into());
2622
}
27-
match self.client.handshake(parts[0], parts[1]).await {
23+
match self.client.handshake("", &self.api_key.clone()).await {
2824
Err(e) => Err(e.into()),
29-
Ok(v) => {
30-
self.client.set_token(String::from_utf8(v.to_vec()).expect("something"));
31-
Ok(())
32-
}
25+
Ok(_) => Ok(()),
3326
}
3427
}
3528

36-
pub async fn query(&mut self, query: String, _timeout: Option<u32>) -> Result<Streaming<FlightData>, Box<dyn Error>> {
29+
pub async fn query(
30+
&mut self,
31+
query: String,
32+
_timeout: Option<u32>,
33+
) -> std::result::Result<FlightRecordBatchStream, Box<dyn Error>> {
3734
match self.authenticate().await {
38-
Err(e) => {
39-
return Err(e.into())
40-
},
41-
Ok(()) => {},
42-
};
43-
44-
let cmd = CommandStatementQuery {
45-
query: query.clone(),
46-
..Default::default()
35+
Err(e) => return Err(e.into()),
36+
Ok(()) => {}
4737
};
48-
let fd = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
49-
let req = Request::new(fd);
50-
51-
match self.client.inner_mut().get_flight_info(req).await {
38+
match self.client.execute(query, Option::None).await {
5239
Ok(resp) => {
53-
let flight_info = resp.into_inner();
54-
for ep in flight_info.endpoint {
40+
for ep in resp.endpoint {
5541
if let Some(tkt) = ep.ticket {
5642
return self.client.do_get(tkt).await.map_err(|e| e.into());
5743
}
5844
}
5945
Err("no tickets for flight endpoint".into())
60-
},
61-
// Err(e) => {
62-
// // Handle re-authentication similar to the Python client and then retry the request.
63-
// self.authenticate().await?;
64-
// self.query(query, _timeout)
65-
// },
46+
}
6647
Err(e) => Err(e.into()),
6748
}
6849
}
69-
}
50+
}

src/lib.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
21
mod client;
32
mod flight;
4-
mod prices;
3+
mod prices;
54
mod tls;
65

7-
pub use client::{SpiceClient as Client};
8-
pub use flight::{SqlFlightClient};
6+
pub use client::{new_spice_client, SpiceClient as Client};
97

108
// Further public exports and integrations

0 commit comments

Comments
 (0)