Skip to content

Commit 5810507

Browse files
authored
Improve SQLite-typed input for API (#82)
* sqlite-input-improvements * fix unused imports
1 parent 94bd76a commit 5810507

File tree

9 files changed

+199
-45
lines changed

9 files changed

+199
-45
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ rustls = { version = "0.21.0", features = ["dangerous_configuration", "quic"] }
5454
rustls-pemfile = "1.0.2"
5555
seahash = "4.1.0"
5656
serde = "1.0.159"
57-
serde_json = "1.0.95"
57+
serde_json = { version = "1.0.95", features = ["raw_value"] }
5858
serde_with = "2.3.2"
5959
smallvec = { version = "1.11.0", features = ["serde", "write", "union"] }
6060
speedy = { version = "0.8.7", features = ["uuid", "smallvec"], package = "corro-speedy" }

crates/corro-agent/src/agent.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2736,7 +2736,7 @@ pub mod tests {
27362736
use serde_json::json;
27372737
use spawn::wait_for_all_pending_handles;
27382738
use tokio::time::{sleep, timeout, MissedTickBehavior};
2739-
use tracing::{info, info_span};
2739+
use tracing::info_span;
27402740
use tripwire::Tripwire;
27412741

27422742
use super::*;

crates/corro-agent/src/api/public/mod.rs

+34-16
Original file line numberDiff line numberDiff line change
@@ -308,16 +308,25 @@ where
308308

309309
#[tracing::instrument(skip_all, err)]
310310
fn execute_statement(tx: &Transaction, stmt: &Statement) -> rusqlite::Result<usize> {
311-
let mut prepped = match &stmt {
312-
Statement::Simple(q) => tx.prepare(q),
313-
Statement::WithParams(q, _) => tx.prepare(q),
314-
Statement::WithNamedParams(q, _) => tx.prepare(q),
315-
}?;
311+
let mut prepped = tx.prepare(stmt.query())?;
316312

317313
match stmt {
318-
Statement::Simple(_) => prepped.execute([]),
319-
Statement::WithParams(_, params) => prepped.execute(params_from_iter(params)),
320-
Statement::WithNamedParams(_, params) => prepped.execute(
314+
Statement::Simple(_)
315+
| Statement::Verbose {
316+
params: None,
317+
named_params: None,
318+
..
319+
} => prepped.execute([]),
320+
Statement::WithParams(_, params)
321+
| Statement::Verbose {
322+
params: Some(params),
323+
..
324+
} => prepped.execute(params_from_iter(params)),
325+
Statement::WithNamedParams(_, params)
326+
| Statement::Verbose {
327+
named_params: Some(params),
328+
..
329+
} => prepped.execute(
321330
params
322331
.iter()
323332
.map(|(k, v)| (k.as_str(), v as &dyn ToSql))
@@ -429,11 +438,7 @@ async fn build_query_rows_response(
429438
}
430439
};
431440

432-
let prepped_res = block_in_place(|| match &stmt {
433-
Statement::Simple(q) => conn.prepare(q),
434-
Statement::WithParams(q, _) => conn.prepare(q),
435-
Statement::WithNamedParams(q, _) => conn.prepare(q),
436-
});
441+
let prepped_res = block_in_place(|| conn.prepare(stmt.query()));
437442

438443
let mut prepped = match prepped_res {
439444
Ok(prepped) => prepped,
@@ -476,9 +481,22 @@ async fn build_query_rows_response(
476481
let start = Instant::now();
477482

478483
let query = match stmt {
479-
Statement::Simple(_) => prepped.query(()),
480-
Statement::WithParams(_, params) => prepped.query(params_from_iter(params)),
481-
Statement::WithNamedParams(_, params) => prepped.query(
484+
Statement::Simple(_)
485+
| Statement::Verbose {
486+
params: None,
487+
named_params: None,
488+
..
489+
} => prepped.query(()),
490+
Statement::WithParams(_, params)
491+
| Statement::Verbose {
492+
params: Some(params),
493+
..
494+
} => prepped.query(params_from_iter(params)),
495+
Statement::WithNamedParams(_, params)
496+
| Statement::Verbose {
497+
named_params: Some(params),
498+
..
499+
} => prepped.query(
482500
params
483501
.iter()
484502
.map(|(k, v)| (k.as_str(), v as &dyn ToSql))

crates/corro-agent/src/api/public/pubsub.rs

+20-5
Original file line numberDiff line numberDiff line change
@@ -202,16 +202,31 @@ pub async fn process_sub_channel(
202202

203203
fn expanded_statement(conn: &Connection, stmt: &Statement) -> rusqlite::Result<Option<String>> {
204204
Ok(match stmt {
205-
Statement::Simple(q) => conn.prepare(q)?.expanded_sql(),
206-
Statement::WithParams(q, params) => {
207-
let mut prepped = conn.prepare(q)?;
205+
Statement::Simple(query)
206+
| Statement::Verbose {
207+
query,
208+
params: None,
209+
named_params: None,
210+
} => conn.prepare(query)?.expanded_sql(),
211+
Statement::WithParams(query, params)
212+
| Statement::Verbose {
213+
query,
214+
params: Some(params),
215+
..
216+
} => {
217+
let mut prepped = conn.prepare(query)?;
208218
for (i, param) in params.iter().enumerate() {
209219
prepped.raw_bind_parameter(i + 1, param)?;
210220
}
211221
prepped.expanded_sql()
212222
}
213-
Statement::WithNamedParams(q, params) => {
214-
let mut prepped = conn.prepare(q)?;
223+
Statement::WithNamedParams(query, params)
224+
| Statement::Verbose {
225+
query,
226+
named_params: Some(params),
227+
..
228+
} => {
229+
let mut prepped = conn.prepare(query)?;
215230
for (k, v) in params.iter() {
216231
let idx = match prepped.parameter_index(k)? {
217232
Some(idx) => idx,

crates/corro-api-types/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ compact_str = { workspace = true }
1313
hex = { workspace = true }
1414
rusqlite = { workspace = true }
1515
serde = { workspace = true }
16+
serde_json = { workspace = true }
1617
smallvec = { workspace = true }
1718
speedy = { workspace = true }
1819
strum = { workspace = true }
1920
thiserror = { workspace = true }
20-
tokio = { workspace = true }
21+
tokio = { workspace = true }

crates/corro-api-types/src/lib.rs

+125-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rusqlite::{
1111
Row, ToSql,
1212
};
1313
use serde::{Deserialize, Serialize};
14+
use serde_json::value::RawValue;
1415
use smallvec::{SmallVec, ToSmallVec};
1516
use speedy::{Context, Readable, Reader, Writable, Writer};
1617
use sqlite::ChangeType;
@@ -120,9 +121,25 @@ impl ToSql for ChangeId {
120121
#[derive(Debug, Clone, Serialize, Deserialize)]
121122
#[serde(untagged)]
122123
pub enum Statement {
124+
Verbose {
125+
query: String,
126+
params: Option<Vec<SqliteParam>>,
127+
named_params: Option<HashMap<String, SqliteParam>>,
128+
},
123129
Simple(String),
124-
WithParams(String, Vec<SqliteValue>),
125-
WithNamedParams(String, HashMap<String, SqliteValue>),
130+
WithParams(String, Vec<SqliteParam>),
131+
WithNamedParams(String, HashMap<String, SqliteParam>),
132+
}
133+
134+
impl Statement {
135+
pub fn query(&self) -> &str {
136+
match self {
137+
Statement::Verbose { query, .. }
138+
| Statement::Simple(query)
139+
| Statement::WithParams(query, _)
140+
| Statement::WithNamedParams(query, _) => query,
141+
}
142+
}
126143
}
127144

128145
impl From<&str> for Statement {
@@ -292,6 +309,76 @@ impl FromSql for ColumnType {
292309
}
293310
}
294311

312+
#[allow(clippy::large_enum_variant)]
313+
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
314+
#[serde(untagged)]
315+
pub enum SqliteParam {
316+
#[default]
317+
Null,
318+
Bool(bool),
319+
Integer(i64),
320+
Real(f64),
321+
Text(CompactString),
322+
Blob(SmallVec<[u8; 512]>),
323+
Json(Box<RawValue>),
324+
}
325+
326+
impl From<&str> for SqliteParam {
327+
fn from(value: &str) -> Self {
328+
Self::Text(value.into())
329+
}
330+
}
331+
332+
impl From<Vec<u8>> for SqliteParam {
333+
fn from(value: Vec<u8>) -> Self {
334+
Self::Blob(value.into())
335+
}
336+
}
337+
338+
impl From<String> for SqliteParam {
339+
fn from(value: String) -> Self {
340+
Self::Text(value.into())
341+
}
342+
}
343+
344+
impl From<u16> for SqliteParam {
345+
fn from(value: u16) -> Self {
346+
Self::Integer(value as i64)
347+
}
348+
}
349+
350+
impl From<i64> for SqliteParam {
351+
fn from(value: i64) -> Self {
352+
Self::Integer(value)
353+
}
354+
}
355+
356+
impl ToSql for SqliteParam {
357+
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
358+
Ok(match self {
359+
SqliteParam::Null => ToSqlOutput::Owned(Value::Null),
360+
SqliteParam::Bool(v) => ToSqlOutput::Owned(Value::Integer(*v as i64)),
361+
SqliteParam::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)),
362+
SqliteParam::Real(f) => ToSqlOutput::Owned(Value::Real(*f)),
363+
SqliteParam::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())),
364+
SqliteParam::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
365+
SqliteParam::Json(map) => ToSqlOutput::Borrowed(ValueRef::Text(map.get().as_bytes())),
366+
})
367+
}
368+
}
369+
370+
impl<'a> ToSql for SqliteValueRef<'a> {
371+
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'a>> {
372+
Ok(match self {
373+
SqliteValueRef::Null => ToSqlOutput::Owned(Value::Null),
374+
SqliteValueRef::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)),
375+
SqliteValueRef::Real(f) => ToSqlOutput::Owned(Value::Real(*f)),
376+
SqliteValueRef::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())),
377+
SqliteValueRef::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
378+
})
379+
}
380+
}
381+
295382
#[allow(clippy::large_enum_variant)]
296383
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Hash)]
297384
#[serde(untagged)]
@@ -655,3 +742,39 @@ impl ToSql for ColumnName {
655742
self.0.as_str().to_sql()
656743
}
657744
}
745+
746+
#[cfg(test)]
747+
mod tests {
748+
use super::*;
749+
750+
#[test]
751+
fn test_statement_serialization() {
752+
let s = serde_json::to_string(&vec![Statement::WithParams(
753+
"select 1
754+
from table
755+
where column = ?"
756+
.into(),
757+
vec!["my-value".into()],
758+
)])
759+
.unwrap();
760+
println!("{s}");
761+
762+
let stmts: Vec<Statement> = serde_json::from_str(&s).unwrap();
763+
println!("stmts: {stmts:?}");
764+
765+
let json = r#"[["some statement",[1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]]]"#;
766+
767+
let value: serde_json::Value = serde_json::from_str(json).unwrap();
768+
println!("value: {value:#?}");
769+
770+
let stmts: Vec<Statement> = serde_json::from_str(json).unwrap();
771+
println!("stmts: {stmts:?}");
772+
773+
let json = r#"[{"query": "some statement", "params": [1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]}]"#;
774+
let value: serde_json::Value = serde_json::from_str(json).unwrap();
775+
println!("value: {value:#?}");
776+
777+
let stmts: Vec<Statement> = serde_json::from_str(json).unwrap();
778+
println!("stmts: {stmts:?}");
779+
}
780+
}

crates/corro-tpl/src/lib.rs

+12-16
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use compact_str::ToCompactString;
1212
use corro_client::sub::SubscriptionStream;
1313
use corro_client::CorrosionApiClient;
1414
use corro_types::api::QueryEvent;
15+
use corro_types::api::SqliteParam;
1516
use corro_types::api::Statement;
1617
use corro_types::change::SqliteValue;
1718
use futures::StreamExt;
@@ -536,33 +537,28 @@ impl Engine {
536537
}
537538
});
538539

539-
fn dyn_to_sql(v: Dynamic) -> Result<SqliteValue, Box<EvalAltResult>> {
540+
fn dyn_to_sql(v: Dynamic) -> Result<SqliteParam, Box<EvalAltResult>> {
540541
Ok(match v.type_name() {
541-
"()" => SqliteValue::Null,
542-
"i64" => SqliteValue::Integer(
542+
"()" => SqliteParam::Null,
543+
"i64" => SqliteParam::Integer(
543544
v.as_int()
544545
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to i64")))?,
545546
),
546-
"f64" => SqliteValue::Real(corro_types::api::Real(
547+
"f64" => SqliteParam::Real(
547548
v.as_float()
548549
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to f64")))?,
549-
)),
550-
"bool" => {
551-
if v.as_bool()
552-
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))?
553-
{
554-
SqliteValue::Integer(1)
555-
} else {
556-
SqliteValue::Integer(0)
557-
}
558-
}
559-
"blob" => SqliteValue::Blob(
550+
),
551+
"bool" => SqliteParam::Bool(
552+
v.as_bool()
553+
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))?,
554+
),
555+
"blob" => SqliteParam::Blob(
560556
v.into_blob()
561557
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to blob")))?
562558
.into(),
563559
),
564560
// convert everything else into a string, including a string
565-
_ => SqliteValue::Text(v.to_compact_string()),
561+
_ => SqliteParam::Text(v.to_compact_string()),
566562
})
567563
}
568564

crates/corrosion/src/main.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use command::{
1212
tls::{generate_ca, generate_client_cert, generate_server_cert},
1313
tpl::TemplateFlags,
1414
};
15-
use corro_api_types::SqliteValue;
15+
use corro_api_types::SqliteParam;
1616
use corro_client::CorrosionApiClient;
1717
use corro_types::{
1818
api::{ExecResult, QueryEvent, Statement},
@@ -301,7 +301,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> {
301301
} else {
302302
Statement::WithParams(
303303
query.clone(),
304-
param.iter().map(|p| SqliteValue::Text(p.into())).collect(),
304+
param.iter().map(|p| SqliteParam::Text(p.into())).collect(),
305305
)
306306
};
307307

@@ -359,7 +359,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> {
359359
} else {
360360
Statement::WithParams(
361361
query.clone(),
362-
param.iter().map(|p| SqliteValue::Text(p.into())).collect(),
362+
param.iter().map(|p| SqliteParam::Text(p.into())).collect(),
363363
)
364364
};
365365

0 commit comments

Comments
 (0)