Skip to content

Commit 23284c9

Browse files
committed
Support filtering by any fields
1 parent cf33268 commit 23284c9

File tree

10 files changed

+5534
-738
lines changed

10 files changed

+5534
-738
lines changed

Cargo.lock

Lines changed: 241 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ futures = "0.3"
2020
pco = "0.4"
2121
proc-macro2 = "1.0"
2222
quote = "1.0"
23+
serde = { version = "1.0", features = ["derive"] }
24+
serde_json = { version = "1.0" }
25+
serde_with = "3.16"
2326
syn = "2.0"
2427
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4"] }
2528

src/lib.rs

Lines changed: 137 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -71,76 +71,60 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
7171
};
7272

7373
// load and delete
74-
let mut fields = Vec::new();
75-
let mut load_filters = Vec::new();
74+
let mut fields = vec![quote!{ pub filter: Option<&'a Filter>, }];
7675
let mut load_checks = Vec::new();
7776
let mut load_where = Vec::new();
7877
let mut load_params = Vec::new();
7978
let mut load_fields = Vec::new();
8079
let mut bind = 1;
8180
let mut index: usize = 0;
82-
let mut timestamp_ty: Option<_> = None;
81+
let mut timestamp_ty = None;
8382
let mut using_chrono = false;
8483
for field in model.fields.iter() {
8584
let ident = field.ident.clone().unwrap();
8685
let ty = field.ty.clone();
86+
let name = format!("{ident}");
8787
if group_by.iter().any(|i| *i == ident) {
8888
fields.push(quote! { pub #ident: #ty, });
89-
load_filters.push(quote! { #ident: &[#ty], });
90-
let name = format!("{ident}");
9189
load_checks.push(quote! {
92-
if #ident.is_empty() {
93-
return Err(anyhow::Error::msg(#name.to_string() + "must be specified"));
90+
if filter.#ident.is_empty() {
91+
return Err(anyhow::Error::msg(#name.to_string() + " is required"));
9492
}
9593
});
9694
load_where.push(format!("{ident} = ANY(${bind})"));
9795
bind += 1;
98-
load_params.push(quote! { &#ident, });
96+
load_params.push(quote! { &filter.#ident, });
9997
} else if timestamp.as_ref().map(|t| *t == ident).unwrap_or(false) {
100-
timestamp_ty = Some(ty.clone());
98+
timestamp_ty = Some(quote! { #ty });
10199
using_chrono = !ty.to_token_stream().to_string().contains("SystemTime");
102-
fields.push(quote! {
103-
pub filter: bool,
104-
pub filter_start: #ty,
105-
pub filter_end: #ty,
106-
pub start_at: #ty,
107-
pub end_at: #ty,
108-
#ident: Vec<u8>,
100+
fields.push(quote! { #ident: Vec<u8>, });
101+
load_checks.push(quote! {
102+
if filter.#ident.is_none() {
103+
return Err(anyhow::Error::msg(#name.to_string() + " is required"));
104+
}
109105
});
110106
load_where.push(format!("end_at >= ${bind}"));
111107
bind += 1;
112108
load_where.push(format!("start_at <= ${bind}"));
113109
bind += 1;
114-
load_params.push(quote! { &filter_start, &filter_end, });
110+
load_params.push(quote! {
111+
filter.#ident.as_ref().unwrap().start(),
112+
filter.#ident.as_ref().unwrap().end(),
113+
});
115114
} else {
116115
fields.push(quote! { #ident: Vec<u8>, });
117116
}
118117
if timestamp.as_ref().map(|t| *t == ident).unwrap_or(false) {
119-
load_fields.push(quote! { start_at: row.get(#index), });
120-
index += 1;
121-
load_fields.push(quote! { end_at: row.get(#index), });
122-
index += 1;
118+
index += 2; // Skip over start_at and end_at
123119
}
124120
load_fields.push(quote! { #ident: row.get(#index), });
125121
index += 1;
126122
}
127-
let mut delete_fields = load_fields.clone();
128-
if timestamp.is_some() {
129-
let ty = timestamp_ty.unwrap_or_else(|| Type::Verbatim(quote! { std::time::SystemTime }));
130-
load_filters.push(quote! {
131-
filter_start: #ty,
132-
filter_end: #ty,
133-
});
134-
load_fields.push(quote! { filter: true, filter_start, filter_end, });
135-
delete_fields.push(quote! { filter: false, filter_start, filter_end, });
136-
}
137123
let fields = tokens(fields);
138-
let load_filters = tokens(load_filters);
139124
let load_checks = tokens(load_checks);
140125
let load_where = if load_where.is_empty() { "true".to_string() } else { load_where.join(" AND ") };
141126
let load_params = tokens(load_params);
142127
let load_fields = tokens(load_fields);
143-
let delete_fields = tokens(delete_fields);
144128
let load_sql = format!("SELECT * FROM {table_name} WHERE {load_where}");
145129
let delete_sql = format!("DELETE FROM {table_name} WHERE {load_where} RETURNING *");
146130

@@ -204,11 +188,6 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
204188
let decompress_fields = tokens(decompress_fields);
205189
let compressed_field_sizes = tokens(compressed_field_sizes);
206190
let decompressed_fields = tokens(decompressed_fields);
207-
let in_time_range = if timestamp.is_some() {
208-
quote! { !self.filter || row.#timestamp >= self.filter_start && row.#timestamp <= self.filter_end }
209-
} else {
210-
quote! { true }
211-
};
212191

213192
// store
214193
let mut store_fields = Vec::new();
@@ -281,39 +260,67 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
281260
};
282261
let store_sql = format!("COPY {table_name} ({store_fields}) FROM STDIN BINARY");
283262

263+
let mut filter_fields = Vec::new();
264+
let mut filter_conditions = Vec::new();
265+
for field in model.fields.iter() {
266+
let ident = field.ident.clone().unwrap();
267+
let ty = field.ty.clone();
268+
let time = ty.to_token_stream().to_string().contains("Time");
269+
if time {
270+
filter_fields.push(quote! {
271+
#[serde(deserialize_with = "serde_extra::deserialize_time_range")]
272+
pub #ident: Option<std::ops::RangeInclusive<#ty>>,
273+
});
274+
} else {
275+
filter_fields.push(quote! {
276+
#[serde(default)]
277+
#[serde_as(deserialize_as = "serde_with::OneOrMany<_>")]
278+
pub #ident: Vec<#ty>,
279+
});
280+
}
281+
if time {
282+
filter_conditions.push(quote! {
283+
self.#ident.as_ref().map(|t| t.contains(&row.#ident)) != Some(false)
284+
});
285+
} else {
286+
filter_conditions.push(quote! {
287+
(self.#ident.is_empty() || self.#ident.contains(&row.#ident))
288+
});
289+
}
290+
}
291+
let filter_fields = tokens(filter_fields);
292+
293+
let serde_extra = timestamp_ty.map(|t| serde_extra(&t));
294+
284295
quote! {
285296
#item
286297

287298
#[doc=concat!(" Generated by pco_store to store and load compressed versions of [", stringify!(#name), "]")]
288-
pub struct #packed_name {
299+
pub struct #packed_name<'a> {
289300
#fields
290301
}
291302

292-
impl #packed_name {
303+
impl<'a> #packed_name<'a> {
293304
/// Loads data for the specified filters.
294-
///
295-
/// For models with a timestamp, [decompress][Self::decompress] automatically filters out
296-
/// rows outside the requested time range.
297-
pub async fn load(db: &deadpool_postgres::Object, #load_filters) -> anyhow::Result<Vec<#packed_name>> {
305+
pub async fn load(db: &deadpool_postgres::Object, filter: &'a Filter) -> anyhow::Result<Vec<#packed_name<'a>>> {
298306
#load_checks
299307
let sql = #load_sql;
300308
let mut results = Vec::new();
301309
for row in db.query(&db.prepare_cached(&sql).await?, &[#load_params]).await? {
302-
results.push(#packed_name { #load_fields });
310+
results.push(#packed_name { filter: Some(filter), #load_fields });
303311
}
304312
Ok(results)
305313
}
306314

307315
/// Deletes data for the specified filters, returning it to the caller.
308316
///
309-
/// For models with a timestamp, [decompress][Self::decompress] **will not** filter out
310-
/// rows outside the requested time range.
311-
pub async fn delete(db: &deadpool_postgres::Object, #load_filters) -> anyhow::Result<Vec<#packed_name>> {
317+
/// Note that all rows are returned from [decompress][Self::decompress] even if post-decompress filters would normally apply.
318+
pub async fn delete(db: &deadpool_postgres::Object, filter: &'a Filter) -> anyhow::Result<Vec<#packed_name<'a>>> {
312319
#load_checks
313320
let sql = #delete_sql;
314321
let mut results = Vec::new();
315322
for row in db.query(&db.prepare_cached(&sql).await?, &[#load_params]).await? {
316-
results.push(#packed_name { #delete_fields });
323+
results.push(#packed_name { filter: None, #load_fields });
317324
}
318325
Ok(results)
319326
}
@@ -325,7 +332,7 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
325332
let len = [#compressed_field_sizes].into_iter().max().unwrap_or(0);
326333
for index in 0..len {
327334
let row = #name { #decompressed_fields };
328-
if #in_time_range {
335+
if self.filter.as_ref().map(|f| f.filter(&row)) != Some(false) {
329336
results.push(row);
330337
}
331338
}
@@ -387,10 +394,90 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
387394
Ok(())
388395
}
389396
}
397+
398+
#[serde_with::serde_as]
399+
#[derive(Debug, Default, serde::Deserialize, Clone)]
400+
#[serde(deny_unknown_fields)]
401+
#[doc=concat!(" Generated by pco_store to specify filters when loading [", stringify!(#name), "]")]
402+
pub struct Filter {
403+
#filter_fields
404+
}
405+
406+
impl Filter {
407+
pub fn filter(&self, row: &#name) -> bool {
408+
#(#filter_conditions)&&*
409+
}
410+
}
411+
412+
#serde_extra
390413
}
391414
.into()
392415
}
393416

417+
fn serde_extra(timestamp_ty: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
418+
quote! {
419+
mod serde_extra {
420+
use super::*;
421+
use serde::de::{self, SeqAccess, Visitor};
422+
use serde::{Deserialize, Deserializer};
423+
use std::fmt;
424+
use std::ops::RangeInclusive;
425+
426+
pub(super) fn deserialize_time_range<'de, D>(deserializer: D) -> Result<Option<RangeInclusive<#timestamp_ty>>, D::Error>
427+
where
428+
D: Deserializer<'de>,
429+
{
430+
Ok(TimeRange::deserialize(deserializer)?.0)
431+
}
432+
433+
#[derive(Debug, PartialEq)]
434+
struct TimeRange(Option<RangeInclusive<#timestamp_ty>>);
435+
impl<'de> Deserialize<'de> for TimeRange {
436+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
437+
where
438+
D: Deserializer<'de>,
439+
{
440+
deserializer.deserialize_any(TimeRangeVisitor)
441+
}
442+
}
443+
444+
struct TimeRangeVisitor;
445+
impl<'de> Visitor<'de> for TimeRangeVisitor {
446+
type Value = TimeRange;
447+
448+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
449+
formatter.write_str("a single time string, or an array with 1-2 time strings")
450+
}
451+
452+
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
453+
where
454+
E: de::Error,
455+
{
456+
match serde_json::from_str::<#timestamp_ty>(value) {
457+
Ok(start) => Ok(TimeRange(Some(start..=start))),
458+
Err(err) => Err(E::custom(format!("invalid time format: {err}"))),
459+
}
460+
}
461+
462+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
463+
where
464+
A: SeqAccess<'de>,
465+
{
466+
let start = match seq.next_element::<Option<#timestamp_ty>>()? {
467+
Some(Some(time)) => time,
468+
Some(None) | None => return Ok(TimeRange(None)),
469+
};
470+
let end = match seq.next_element::<Option<#timestamp_ty>>()? {
471+
Some(Some(time)) => time,
472+
Some(None) | None => start,
473+
};
474+
Ok(TimeRange(Some(start..=end)))
475+
}
476+
}
477+
}
478+
}
479+
}
480+
394481
fn copy_type(rust_type: String) -> &'static str {
395482
match rust_type.as_str() {
396483
"f32" => "FLOAT4",

0 commit comments

Comments
 (0)