Skip to content

Commit 479b94b

Browse files
authored
Support loading only the requested fields (#31)
1 parent 8ac281f commit 479b94b

File tree

13 files changed

+1522
-166
lines changed

13 files changed

+1522
-166
lines changed

README.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ async fn example() -> anyhow::Result<()> {
9494
// Read
9595
let mut calls = 0;
9696
let filter = Filter::new(&[database_id], &[granularity], start..=end);
97-
for group in CompressedQueryStats::load(db, filter.clone()).await? {
97+
for group in CompressedQueryStats::load(db, filter.clone(), ()).await? {
9898
for stat in group.decompress()? {
9999
calls += stat.calls;
100100
}
@@ -106,7 +106,7 @@ async fn example() -> anyhow::Result<()> {
106106
assert_eq!(2, db.query_one("SELECT count(*) FROM query_stats", &[]).await?.get::<_, i64>(0));
107107
transaction!(db, {
108108
let mut stats = Vec::new();
109-
for group in CompressedQueryStats::delete(db, filter.clone()).await? {
109+
for group in CompressedQueryStats::delete(db, filter.clone(), ()).await? {
110110
stats.extend(group.decompress()?);
111111
}
112112
assert_eq!(0, db.query_one("SELECT count(*) FROM query_stats", &[]).await?.get::<_, i64>(0));
@@ -117,7 +117,7 @@ async fn example() -> anyhow::Result<()> {
117117
.await?;
118118
});
119119
assert_eq!(1, db.query_one("SELECT count(*) FROM query_stats", &[]).await?.get::<_, i64>(0));
120-
let group = CompressedQueryStats::load(db, filter).await?.remove(0);
120+
let group = CompressedQueryStats::load(db, filter, ()).await?.remove(0);
121121
assert_eq!(group.start_at, end - Duration::from_secs(120));
122122
assert_eq!(group.end_at, end - Duration::from_secs(60));
123123
let stats = group.decompress()?;
@@ -189,6 +189,15 @@ Timestamps support multiple formats:
189189
- `range_duration` returns the duration of the filter's time range
190190
- `range_shift` mutably shifts the time range's start and end by a certain amount, e.g. to filter for "today, 7 days ago"
191191

192+
## Loading a subset of fields
193+
194+
Read requests that don't need all fields in a struct can be optimized by declaring which fields they need, allowing pco_store to skip the others. Fields can be specified in several ways:
195+
- `()` or `Fields::default()`: load all fields
196+
- `&[]` or `Fields::required()`: load only the required fields from `group_by` and `timestamp`
197+
- `&["other"]` or `Fields::new(&["other"])?`: load extra fields in addition to the required ones; an unknown field name results in an error
198+
199+
Note that when optional filters are set, the fields needed by those filters are automatically added to the fields to be loaded, even if only `Fields::required()` or an explicit subset was requested.
200+
192201
## Contributions are welcome to
193202

194203
- support decompression of only the fields requested at runtime

benches/comparison/pco_store.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ pub async fn load() -> Result<()> {
9191
let database_ids: Vec<i64> = db.query_one("SELECT array_agg(DISTINCT database_id) FROM comparison_pco_stores", &[]).await?.get(0);
9292
let mut stats = Vec::new();
9393
let filter = Filter::new(&database_ids, SystemTime::UNIX_EPOCH..=SystemTime::now());
94-
for group in CompressedQueryStats::load(db, filter).await? {
94+
for group in CompressedQueryStats::load(db, filter, ()).await? {
9595
for stat in group.decompress()? {
9696
stats.push(stat);
9797
}

rust-toolchain.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[toolchain]
2+
channel = "1.93.0"

src/lib.rs

Lines changed: 185 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,19 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
7373
};
7474

7575
// load and delete
76-
let mut fields = vec![quote! { filter: Option<Filter>, }];
76+
let mut packed_fields = vec![quote! { filter: Option<Filter>, }];
7777
let mut load_checks = Vec::new();
7878
let mut load_where = Vec::new();
7979
let mut load_params = Vec::new();
80-
let mut load_fields = Vec::new();
8180
let mut bind = 1;
82-
let mut index: usize = 0;
8381
let mut timestamp_ty = None;
8482
let mut using_chrono = false;
8583
for field in model.fields.iter() {
8684
let ident = field.ident.clone().unwrap();
8785
let ty = field.ty.clone();
8886
let name = format!("{ident}");
8987
if group_by.iter().any(|i| *i == ident) {
90-
fields.push(quote! { #ident: #ty, });
88+
packed_fields.push(quote! { #ident: #ty, });
9189
load_checks.push(quote! {
9290
if filter.#ident.is_empty() {
9391
return Err(anyhow::Error::msg(#name.to_string() + " is required"));
@@ -99,7 +97,7 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
9997
} else if timestamp.as_ref().map(|t| *t == ident).unwrap_or(false) {
10098
using_chrono = !ty.to_token_stream().to_string().contains("SystemTime");
10199
timestamp_ty = Some(ty.clone());
102-
fields.push(quote! { #ident: Vec<u8>, });
100+
packed_fields.push(quote! { #ident: Vec<u8>, });
103101
load_checks.push(quote! {
104102
if filter.#ident.is_none() {
105103
return Err(anyhow::Error::msg(#name.to_string() + " is required"));
@@ -115,21 +113,13 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
115113
filter.#ident.as_ref().unwrap().end(),
116114
});
117115
} else {
118-
fields.push(quote! { #ident: Vec<u8>, });
116+
packed_fields.push(quote! { #ident: Vec<u8>, });
119117
}
120-
if timestamp.as_ref().map(|t| *t == ident).unwrap_or(false) {
121-
index += 2; // Skip over start_at and end_at
122-
}
123-
load_fields.push(quote! { #ident: row.get(#index), });
124-
index += 1;
125118
}
126-
let fields = tokens(fields);
119+
let packed_fields = tokens(packed_fields);
127120
let load_checks = tokens(load_checks);
128121
let load_where = if load_where.is_empty() { "true".to_string() } else { load_where.join(" AND ") };
129122
let load_params = tokens(load_params);
130-
let load_fields = tokens(load_fields);
131-
let load_sql = format!("SELECT * FROM {table_name} WHERE {load_where}");
132-
let delete_sql = format!("DELETE FROM {table_name} WHERE {load_where} RETURNING *");
133123

134124
// decompress
135125
let mut decompress_fields = Vec::new();
@@ -263,38 +253,53 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
263253
};
264254
let store_sql = format!("COPY {table_name} ({store_fields}) FROM STDIN BINARY");
265255

266-
let filter = filter(model, args, using_chrono, &timestamp_ty);
256+
let filter = filter(model.clone(), args.clone(), using_chrono, &timestamp_ty);
257+
let fields = fields(model, args, packed_name.clone());
267258
let deserialize_time_range = timestamp_ty.map(|t| deserialize_time_range(&t));
268259

269260
quote! {
261+
use serde::Deserialize as _;
262+
270263
#item
271264

272265
#[doc=concat!(" Generated by pco_store to store and load compressed versions of [", stringify!(#name), "]")]
273266
pub struct #packed_name {
274-
#fields
267+
#packed_fields
275268
}
276269

277270
impl #packed_name {
278271
/// Loads data for the specified filters.
279-
pub async fn load(db: &impl ::std::ops::Deref<Target = deadpool_postgres::ClientWrapper>, mut filter: Filter) -> anyhow::Result<Vec<#packed_name>> {
272+
pub async fn load(
273+
db: &impl ::std::ops::Deref<Target = deadpool_postgres::ClientWrapper>,
274+
mut filter: Filter,
275+
fields: impl TryInto<Fields>
276+
) -> anyhow::Result<Vec<#packed_name>> {
277+
let mut fields = fields.try_into().map_err(|_| anyhow::Error::msg("unknown field"))?;
278+
fields.merge_filter(&filter);
280279
#load_checks
281-
let sql = #load_sql;
280+
let sql = "SELECT ".to_string() + fields.select().as_str() + " FROM " + #table_name + " WHERE " + #load_where;
282281
let mut results = Vec::new();
283282
for row in db.query(&db.prepare_cached(&sql).await?, &[#load_params]).await? {
284-
results.push(#packed_name { filter: Some(filter.clone()), #load_fields });
283+
results.push(fields.load_from_row(row, Some(filter.clone()))?);
285284
}
286285
Ok(results)
287286
}
288287

289288
/// Deletes data for the specified filters, returning it to the caller.
290289
///
291290
/// Note that all rows are returned from [decompress][Self::decompress] even if post-decompress filters would normally apply.
292-
pub async fn delete(db: &impl ::std::ops::Deref<Target = deadpool_postgres::ClientWrapper>, mut filter: Filter) -> anyhow::Result<Vec<#packed_name>> {
291+
pub async fn delete(
292+
db: &impl ::std::ops::Deref<Target = deadpool_postgres::ClientWrapper>,
293+
mut filter: Filter,
294+
fields: impl TryInto<Fields>
295+
) -> anyhow::Result<Vec<#packed_name>> {
296+
let mut fields = fields.try_into().map_err(|_| anyhow::Error::msg("unknown field"))?;
297+
fields.merge_filter(&filter);
293298
#load_checks
294-
let sql = #delete_sql;
299+
let sql = "DELETE FROM ".to_string() + #table_name + " WHERE " + #load_where + " RETURNING " + fields.select().as_str();
295300
let mut results = Vec::new();
296301
for row in db.query(&db.prepare_cached(&sql).await?, &[#load_params]).await? {
297-
results.push(#packed_name { filter: None, #load_fields });
302+
results.push(fields.load_from_row(row, None)?);
298303
}
299304
Ok(results)
300305
}
@@ -370,6 +375,7 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
370375
}
371376

372377
#filter
378+
#fields
373379
#deserialize_time_range
374380
}
375381
.into()
@@ -514,6 +520,162 @@ fn filter(model: ItemStruct, args: Arguments, using_chrono: bool, timestamp_ty:
514520
}
515521
}
516522

523+
fn fields(model: ItemStruct, args: Arguments, packed_name: Ident) -> proc_macro2::TokenStream {
524+
let name = model.ident.clone();
525+
let Arguments { timestamp, group_by, .. } = args;
526+
let mut fields = Vec::new();
527+
let mut required = Vec::new();
528+
let mut merge_filter = Vec::new();
529+
let mut select = Vec::new();
530+
let mut load = Vec::new();
531+
let mut default = Vec::new();
532+
let mut from = Vec::new();
533+
for field in model.fields.iter() {
534+
let ident = field.ident.clone().unwrap();
535+
let name = format!("{ident}");
536+
let is_timestamp = timestamp.as_ref().map(|t| *t == ident).unwrap_or(false);
537+
fields.push(quote! { #ident: bool, });
538+
if group_by.iter().any(|i| *i == ident) || is_timestamp {
539+
required.push(quote! { #ident: true, });
540+
} else {
541+
required.push(quote! { #ident: false, });
542+
merge_filter.push(quote! {
543+
(!filter.#ident.is_empty()).then(|| self.#ident = true);
544+
});
545+
}
546+
select.push(quote! { self.#ident.then(|| fields.push(#name)); });
547+
load.push(quote! { #ident: if self.#ident {
548+
let v = row.get(index);
549+
index += 1;
550+
v
551+
} else {
552+
Default::default()
553+
},
554+
});
555+
default.push(quote! { #ident: true, });
556+
from.push(quote! { #name => fields.#ident = true, });
557+
}
558+
let fields = tokens(fields);
559+
let required = tokens(required);
560+
let merge_filter = tokens(merge_filter);
561+
let select = tokens(select);
562+
let load = tokens(load);
563+
let default = tokens(default);
564+
let from = tokens(from);
565+
quote! {
566+
#[derive(Clone, Copy, Debug, PartialEq)]
567+
#[doc=concat!(" Generated by pco_store to choose which fields to decompress when loading [", stringify!(#name), "]")]
568+
pub struct Fields {
569+
#fields
570+
}
571+
572+
impl Fields {
573+
pub fn new(fields: &[&str]) -> anyhow::Result<Self> {
574+
fields.try_into().map_err(|e| anyhow::Error::msg(e))
575+
}
576+
577+
pub fn required() -> Self {
578+
Self { #required }
579+
}
580+
581+
fn merge_filter(&mut self, filter: &Filter) {
582+
#merge_filter
583+
}
584+
585+
fn select(&self) -> String {
586+
let mut fields = Vec::new();
587+
#select
588+
fields.join(", ")
589+
}
590+
591+
fn load_from_row(&self, row: tokio_postgres::Row, filter: Option<Filter>) -> anyhow::Result<#packed_name> {
592+
let mut index = 0;
593+
Ok(#packed_name {
594+
filter,
595+
#load
596+
})
597+
}
598+
}
599+
600+
impl Default for Fields {
601+
fn default() -> Self {
602+
Self { #default }
603+
}
604+
}
605+
606+
impl TryFrom<&[&str]> for Fields {
607+
type Error = &'static str;
608+
fn try_from(input: &[&str]) -> Result<Self, Self::Error> {
609+
let mut fields = Self::required();
610+
for s in input {
611+
match *s {
612+
#from
613+
_ => return Err("unknown field"),
614+
}
615+
}
616+
Ok(fields)
617+
}
618+
}
619+
620+
impl<const N: usize> TryFrom<&[&str; N]> for Fields {
621+
type Error = &'static str;
622+
fn try_from(input: &[&str; N]) -> Result<Self, Self::Error> {
623+
Self::try_from(&input[..])
624+
}
625+
}
626+
627+
impl TryFrom<Vec<String>> for Fields {
628+
type Error = &'static str;
629+
fn try_from(input: Vec<String>) -> Result<Self, Self::Error> {
630+
let input: Vec<_> = input.iter().map(|s| s.as_str()).collect();
631+
Self::try_from(input.as_slice())
632+
}
633+
}
634+
635+
impl From<()> for Fields {
636+
fn from(_: ()) -> Self {
637+
Self::default()
638+
}
639+
}
640+
641+
impl<'de> serde::Deserialize<'de> for Fields {
642+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
643+
where
644+
D: serde::Deserializer<'de>,
645+
{
646+
deserializer.deserialize_any(FieldsVisitor)
647+
}
648+
}
649+
650+
struct FieldsVisitor;
651+
impl<'de> serde::de::Visitor<'de> for FieldsVisitor {
652+
type Value = Fields;
653+
654+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
655+
formatter.write_str("an array of strings matching the struct fields")
656+
}
657+
658+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
659+
where
660+
A: serde::de::SeqAccess<'de>,
661+
{
662+
let mut fields = Vec::new();
663+
while let Some(field) = seq.next_element()? {
664+
fields.push(field);
665+
}
666+
Fields::try_from(fields).map_err(serde::de::Error::custom)
667+
}
668+
669+
fn visit_unit<E>(self) -> Result<Self::Value, E>
670+
where
671+
E: serde::de::Error,
672+
{
673+
Ok(Fields::default())
674+
}
675+
}
676+
}
677+
}
678+
517679
fn deserialize_time_range(timestamp_ty: &Type) -> proc_macro2::TokenStream {
518680
quote! {
519681
/// Deserializes many different time range formats:
@@ -527,8 +689,6 @@ fn deserialize_time_range(timestamp_ty: &Type) -> proc_macro2::TokenStream {
527689
Ok(TimeRange::deserialize(deserializer)?.0)
528690
}
529691

530-
use serde::Deserialize as _;
531-
532692
#[derive(Debug, PartialEq)]
533693
struct TimeRange(Option<std::ops::RangeInclusive<#timestamp_ty>>);
534694
impl<'de> serde::Deserialize<'de> for TimeRange {

tests/chrono_tests/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async fn systemtime_chrono_interop() {
4747
{
4848
let mut actual: Vec<QueryStat> = vec![];
4949
let filter = Filter::new(&[database_id], start.into()..=end.into());
50-
for group in CompressedQueryStats::load(db, filter).await.unwrap() {
50+
for group in CompressedQueryStats::load(db, filter, ()).await.unwrap() {
5151
for stat in group.decompress().unwrap() {
5252
actual.push(stat.clone());
5353
}
@@ -70,7 +70,7 @@ async fn systemtime_chrono_interop() {
7070
// Read, using chrono::DateTime.
7171
let mut actual: Vec<QueryStat> = vec![];
7272
let filter = Filter::new(&[database_id], start..=end);
73-
for group in CompressedQueryStats::load(db, filter).await.unwrap() {
73+
for group in CompressedQueryStats::load(db, filter, ()).await.unwrap() {
7474
for stat in group.decompress().unwrap() {
7575
actual.push(stat.clone());
7676
}
@@ -90,7 +90,7 @@ async fn systemtime_chrono_interop() {
9090
// Read again, using SystemTime.
9191
let mut actual: Vec<QueryStat> = vec![];
9292
let filter = Filter::new(&[database_id], start.into()..=end.into());
93-
for group in CompressedQueryStats::load(db, filter).await.unwrap() {
93+
for group in CompressedQueryStats::load(db, filter, ()).await.unwrap() {
9494
for stat in group.decompress().unwrap() {
9595
actual.push(stat.clone());
9696
}

0 commit comments

Comments
 (0)