@@ -71,76 +71,60 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
7171 } ;
7272
7373 // load and delete
74- let mut fields = Vec :: new ( ) ;
75- let mut load_filters = Vec :: new ( ) ;
74+ let mut fields = vec ! [ quote!{ pub filter: Option <& ' a Filter >, } ] ;
7675 let mut load_checks = Vec :: new ( ) ;
7776 let mut load_where = Vec :: new ( ) ;
7877 let mut load_params = Vec :: new ( ) ;
7978 let mut load_fields = Vec :: new ( ) ;
8079 let mut bind = 1 ;
8180 let mut index: usize = 0 ;
82- let mut timestamp_ty: Option < _ > = None ;
81+ let mut timestamp_ty = None ;
8382 let mut using_chrono = false ;
8483 for field in model. fields . iter ( ) {
8584 let ident = field. ident . clone ( ) . unwrap ( ) ;
8685 let ty = field. ty . clone ( ) ;
86+ let name = format ! ( "{ident}" ) ;
8787 if group_by. iter ( ) . any ( |i| * i == ident) {
8888 fields. push ( quote ! { pub #ident: #ty, } ) ;
89- load_filters. push ( quote ! { #ident: & [ #ty] , } ) ;
90- let name = format ! ( "{ident}" ) ;
9189 load_checks. push ( quote ! {
92- if #ident. is_empty( ) {
93- return Err ( anyhow:: Error :: msg( #name. to_string( ) + "must be specified " ) ) ;
90+ if filter . #ident. is_empty( ) {
91+ return Err ( anyhow:: Error :: msg( #name. to_string( ) + " is required " ) ) ;
9492 }
9593 } ) ;
9694 load_where. push ( format ! ( "{ident} = ANY(${bind})" ) ) ;
9795 bind += 1 ;
98- load_params. push ( quote ! { & #ident, } ) ;
96+ load_params. push ( quote ! { & filter . #ident, } ) ;
9997 } else if timestamp. as_ref ( ) . map ( |t| * t == ident) . unwrap_or ( false ) {
100- timestamp_ty = Some ( ty . clone ( ) ) ;
98+ timestamp_ty = Some ( quote ! { #ty } ) ;
10199 using_chrono = !ty. to_token_stream ( ) . to_string ( ) . contains ( "SystemTime" ) ;
102- fields. push ( quote ! {
103- pub filter: bool ,
104- pub filter_start: #ty,
105- pub filter_end: #ty,
106- pub start_at: #ty,
107- pub end_at: #ty,
108- #ident: Vec <u8 >,
100+ fields. push ( quote ! { #ident: Vec <u8 >, } ) ;
101+ load_checks. push ( quote ! {
102+ if filter. #ident. is_none( ) {
103+ return Err ( anyhow:: Error :: msg( #name. to_string( ) + " is required" ) ) ;
104+ }
109105 } ) ;
110106 load_where. push ( format ! ( "end_at >= ${bind}" ) ) ;
111107 bind += 1 ;
112108 load_where. push ( format ! ( "start_at <= ${bind}" ) ) ;
113109 bind += 1 ;
114- load_params. push ( quote ! { & filter_start, & filter_end, } ) ;
110+ load_params. push ( quote ! {
111+ filter. #ident. as_ref( ) . unwrap( ) . start( ) ,
112+ filter. #ident. as_ref( ) . unwrap( ) . end( ) ,
113+ } ) ;
115114 } else {
116115 fields. push ( quote ! { #ident: Vec <u8 >, } ) ;
117116 }
118117 if timestamp. as_ref ( ) . map ( |t| * t == ident) . unwrap_or ( false ) {
119- load_fields. push ( quote ! { start_at: row. get( #index) , } ) ;
120- index += 1 ;
121- load_fields. push ( quote ! { end_at: row. get( #index) , } ) ;
122- index += 1 ;
118+ index += 2 ; // Skip over start_at and end_at
123119 }
124120 load_fields. push ( quote ! { #ident: row. get( #index) , } ) ;
125121 index += 1 ;
126122 }
127- let mut delete_fields = load_fields. clone ( ) ;
128- if timestamp. is_some ( ) {
129- let ty = timestamp_ty. unwrap_or_else ( || Type :: Verbatim ( quote ! { std:: time:: SystemTime } ) ) ;
130- load_filters. push ( quote ! {
131- filter_start: #ty,
132- filter_end: #ty,
133- } ) ;
134- load_fields. push ( quote ! { filter: true , filter_start, filter_end, } ) ;
135- delete_fields. push ( quote ! { filter: false , filter_start, filter_end, } ) ;
136- }
137123 let fields = tokens ( fields) ;
138- let load_filters = tokens ( load_filters) ;
139124 let load_checks = tokens ( load_checks) ;
140125 let load_where = if load_where. is_empty ( ) { "true" . to_string ( ) } else { load_where. join ( " AND " ) } ;
141126 let load_params = tokens ( load_params) ;
142127 let load_fields = tokens ( load_fields) ;
143- let delete_fields = tokens ( delete_fields) ;
144128 let load_sql = format ! ( "SELECT * FROM {table_name} WHERE {load_where}" ) ;
145129 let delete_sql = format ! ( "DELETE FROM {table_name} WHERE {load_where} RETURNING *" ) ;
146130
@@ -204,11 +188,6 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
204188 let decompress_fields = tokens ( decompress_fields) ;
205189 let compressed_field_sizes = tokens ( compressed_field_sizes) ;
206190 let decompressed_fields = tokens ( decompressed_fields) ;
207- let in_time_range = if timestamp. is_some ( ) {
208- quote ! { !self . filter || row. #timestamp >= self . filter_start && row. #timestamp <= self . filter_end }
209- } else {
210- quote ! { true }
211- } ;
212191
213192 // store
214193 let mut store_fields = Vec :: new ( ) ;
@@ -281,39 +260,67 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
281260 } ;
282261 let store_sql = format ! ( "COPY {table_name} ({store_fields}) FROM STDIN BINARY" ) ;
283262
263+ let mut filter_fields = Vec :: new ( ) ;
264+ let mut filter_conditions = Vec :: new ( ) ;
265+ for field in model. fields . iter ( ) {
266+ let ident = field. ident . clone ( ) . unwrap ( ) ;
267+ let ty = field. ty . clone ( ) ;
268+ let time = ty. to_token_stream ( ) . to_string ( ) . contains ( "Time" ) ;
269+ if time {
270+ filter_fields. push ( quote ! {
271+ #[ serde( deserialize_with = "serde_extra::deserialize_time_range" ) ]
272+ pub #ident: Option <std:: ops:: RangeInclusive <#ty>>,
273+ } ) ;
274+ } else {
275+ filter_fields. push ( quote ! {
276+ #[ serde( default ) ]
277+ #[ serde_as( deserialize_as = "serde_with::OneOrMany<_>" ) ]
278+ pub #ident: Vec <#ty>,
279+ } ) ;
280+ }
281+ if time {
282+ filter_conditions. push ( quote ! {
283+ self . #ident. as_ref( ) . map( |t| t. contains( & row. #ident) ) != Some ( false )
284+ } ) ;
285+ } else {
286+ filter_conditions. push ( quote ! {
287+ ( self . #ident. is_empty( ) || self . #ident. contains( & row. #ident) )
288+ } ) ;
289+ }
290+ }
291+ let filter_fields = tokens ( filter_fields) ;
292+
293+ let serde_extra = timestamp_ty. map ( |t| serde_extra ( & t) ) ;
294+
284295 quote ! {
285296 #item
286297
287298 #[ doc=concat!( " Generated by pco_store to store and load compressed versions of [" , stringify!( #name) , "]" ) ]
288- pub struct #packed_name {
299+ pub struct #packed_name< ' a> {
289300 #fields
290301 }
291302
292- impl #packed_name {
303+ impl < ' a> #packed_name< ' a> {
293304 /// Loads data for the specified filters.
294- ///
295- /// For models with a timestamp, [decompress][Self::decompress] automatically filters out
296- /// rows outside the requested time range.
297- pub async fn load( db: & deadpool_postgres:: Object , #load_filters) -> anyhow:: Result <Vec <#packed_name>> {
305+ pub async fn load( db: & deadpool_postgres:: Object , filter: & ' a Filter ) -> anyhow:: Result <Vec <#packed_name<' a>>> {
298306 #load_checks
299307 let sql = #load_sql;
300308 let mut results = Vec :: new( ) ;
301309 for row in db. query( & db. prepare_cached( & sql) . await ?, & [ #load_params] ) . await ? {
302- results. push( #packed_name { #load_fields } ) ;
310+ results. push( #packed_name { filter : Some ( filter ) , #load_fields } ) ;
303311 }
304312 Ok ( results)
305313 }
306314
307315 /// Deletes data for the specified filters, returning it to the caller.
308316 ///
309- /// For models with a timestamp, [decompress][Self::decompress] **will not** filter out
310- /// rows outside the requested time range.
311- pub async fn delete( db: & deadpool_postgres:: Object , #load_filters) -> anyhow:: Result <Vec <#packed_name>> {
317+ /// Note that all rows are returned from [decompress][Self::decompress] even if post-decompress filters would normally apply.
318+ pub async fn delete( db: & deadpool_postgres:: Object , filter: & ' a Filter ) -> anyhow:: Result <Vec <#packed_name<' a>>> {
312319 #load_checks
313320 let sql = #delete_sql;
314321 let mut results = Vec :: new( ) ;
315322 for row in db. query( & db. prepare_cached( & sql) . await ?, & [ #load_params] ) . await ? {
316- results. push( #packed_name { #delete_fields } ) ;
323+ results. push( #packed_name { filter : None , #load_fields } ) ;
317324 }
318325 Ok ( results)
319326 }
@@ -325,7 +332,7 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
325332 let len = [ #compressed_field_sizes] . into_iter( ) . max( ) . unwrap_or( 0 ) ;
326333 for index in 0 ..len {
327334 let row = #name { #decompressed_fields } ;
328- if #in_time_range {
335+ if self . filter . as_ref ( ) . map ( |f| f . filter ( & row ) ) != Some ( false ) {
329336 results. push( row) ;
330337 }
331338 }
@@ -387,10 +394,90 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
387394 Ok ( ( ) )
388395 }
389396 }
397+
398+ #[ serde_with:: serde_as]
399+ #[ derive( Debug , Default , serde:: Deserialize , Clone ) ]
400+ #[ serde( deny_unknown_fields) ]
401+ #[ doc=concat!( " Generated by pco_store to specify filters when loading [" , stringify!( #name) , "]" ) ]
402+ pub struct Filter {
403+ #filter_fields
404+ }
405+
406+ impl Filter {
407+ pub fn filter( & self , row: & #name) -> bool {
408+ #( #filter_conditions) &&*
409+ }
410+ }
411+
412+ #serde_extra
390413 }
391414 . into ( )
392415}
393416
417+ fn serde_extra ( timestamp_ty : & proc_macro2:: TokenStream ) -> proc_macro2:: TokenStream {
418+ quote ! {
419+ mod serde_extra {
420+ use super :: * ;
421+ use serde:: de:: { self , SeqAccess , Visitor } ;
422+ use serde:: { Deserialize , Deserializer } ;
423+ use std:: fmt;
424+ use std:: ops:: RangeInclusive ;
425+
426+ pub ( super ) fn deserialize_time_range<' de, D >( deserializer: D ) -> Result <Option <RangeInclusive <#timestamp_ty>>, D :: Error >
427+ where
428+ D : Deserializer <' de>,
429+ {
430+ Ok ( TimeRange :: deserialize( deserializer) ?. 0 )
431+ }
432+
433+ #[ derive( Debug , PartialEq ) ]
434+ struct TimeRange ( Option <RangeInclusive <#timestamp_ty>>) ;
435+ impl <' de> Deserialize <' de> for TimeRange {
436+ fn deserialize<D >( deserializer: D ) -> Result <Self , D :: Error >
437+ where
438+ D : Deserializer <' de>,
439+ {
440+ deserializer. deserialize_any( TimeRangeVisitor )
441+ }
442+ }
443+
444+ struct TimeRangeVisitor ;
445+ impl <' de> Visitor <' de> for TimeRangeVisitor {
446+ type Value = TimeRange ;
447+
448+ fn expecting( & self , formatter: & mut fmt:: Formatter ) -> fmt:: Result {
449+ formatter. write_str( "a single time string, or an array with 1-2 time strings" )
450+ }
451+
452+ fn visit_str<E >( self , value: & str ) -> Result <Self :: Value , E >
453+ where
454+ E : de:: Error ,
455+ {
456+ match serde_json:: from_str:: <#timestamp_ty>( value) {
457+ Ok ( start) => Ok ( TimeRange ( Some ( start..=start) ) ) ,
458+ Err ( err) => Err ( E :: custom( format!( "invalid time format: {err}" ) ) ) ,
459+ }
460+ }
461+
462+ fn visit_seq<A >( self , mut seq: A ) -> Result <Self :: Value , A :: Error >
463+ where
464+ A : SeqAccess <' de>,
465+ {
466+ let start = match seq. next_element:: <Option <#timestamp_ty>>( ) ? {
467+ Some ( Some ( time) ) => time,
468+ Some ( None ) | None => return Ok ( TimeRange ( None ) ) ,
469+ } ;
470+ let end = match seq. next_element:: <Option <#timestamp_ty>>( ) ? {
471+ Some ( Some ( time) ) => time,
472+ Some ( None ) | None => start,
473+ } ;
474+ Ok ( TimeRange ( Some ( start..=end) ) )
475+ }
476+ }
477+ }
478+ }
479+ }
480+
394481fn copy_type ( rust_type : String ) -> & ' static str {
395482 match rust_type. as_str ( ) {
396483 "f32" => "FLOAT4" ,
0 commit comments