@@ -73,21 +73,19 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
7373 } ;
7474
7575 // load and delete
76- let mut fields = vec ! [ quote! { filter: Option <Filter >, } ] ;
76+ let mut packed_fields = vec ! [ quote! { filter: Option <Filter >, } ] ;
7777 let mut load_checks = Vec :: new ( ) ;
7878 let mut load_where = Vec :: new ( ) ;
7979 let mut load_params = Vec :: new ( ) ;
80- let mut load_fields = Vec :: new ( ) ;
8180 let mut bind = 1 ;
82- let mut index: usize = 0 ;
8381 let mut timestamp_ty = None ;
8482 let mut using_chrono = false ;
8583 for field in model. fields . iter ( ) {
8684 let ident = field. ident . clone ( ) . unwrap ( ) ;
8785 let ty = field. ty . clone ( ) ;
8886 let name = format ! ( "{ident}" ) ;
8987 if group_by. iter ( ) . any ( |i| * i == ident) {
90- fields . push ( quote ! { #ident: #ty, } ) ;
88+ packed_fields . push ( quote ! { #ident: #ty, } ) ;
9189 load_checks. push ( quote ! {
9290 if filter. #ident. is_empty( ) {
9391 return Err ( anyhow:: Error :: msg( #name. to_string( ) + " is required" ) ) ;
@@ -99,7 +97,7 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
9997 } else if timestamp. as_ref ( ) . map ( |t| * t == ident) . unwrap_or ( false ) {
10098 using_chrono = !ty. to_token_stream ( ) . to_string ( ) . contains ( "SystemTime" ) ;
10199 timestamp_ty = Some ( ty. clone ( ) ) ;
102- fields . push ( quote ! { #ident: Vec <u8 >, } ) ;
100+ packed_fields . push ( quote ! { #ident: Vec <u8 >, } ) ;
103101 load_checks. push ( quote ! {
104102 if filter. #ident. is_none( ) {
105103 return Err ( anyhow:: Error :: msg( #name. to_string( ) + " is required" ) ) ;
@@ -115,21 +113,13 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
115113 filter. #ident. as_ref( ) . unwrap( ) . end( ) ,
116114 } ) ;
117115 } else {
118- fields . push ( quote ! { #ident: Vec <u8 >, } ) ;
116+ packed_fields . push ( quote ! { #ident: Vec <u8 >, } ) ;
119117 }
120- if timestamp. as_ref ( ) . map ( |t| * t == ident) . unwrap_or ( false ) {
121- index += 2 ; // Skip over start_at and end_at
122- }
123- load_fields. push ( quote ! { #ident: row. get( #index) , } ) ;
124- index += 1 ;
125118 }
126- let fields = tokens ( fields ) ;
119+ let packed_fields = tokens ( packed_fields ) ;
127120 let load_checks = tokens ( load_checks) ;
128121 let load_where = if load_where. is_empty ( ) { "true" . to_string ( ) } else { load_where. join ( " AND " ) } ;
129122 let load_params = tokens ( load_params) ;
130- let load_fields = tokens ( load_fields) ;
131- let load_sql = format ! ( "SELECT * FROM {table_name} WHERE {load_where}" ) ;
132- let delete_sql = format ! ( "DELETE FROM {table_name} WHERE {load_where} RETURNING *" ) ;
133123
134124 // decompress
135125 let mut decompress_fields = Vec :: new ( ) ;
@@ -263,38 +253,53 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
263253 } ;
264254 let store_sql = format ! ( "COPY {table_name} ({store_fields}) FROM STDIN BINARY" ) ;
265255
266- let filter = filter ( model, args, using_chrono, & timestamp_ty) ;
256+ let filter = filter ( model. clone ( ) , args. clone ( ) , using_chrono, & timestamp_ty) ;
257+ let fields = fields ( model, args, packed_name. clone ( ) ) ;
267258 let deserialize_time_range = timestamp_ty. map ( |t| deserialize_time_range ( & t) ) ;
268259
269260 quote ! {
261+ use serde:: Deserialize as _;
262+
270263 #item
271264
272265 #[ doc=concat!( " Generated by pco_store to store and load compressed versions of [" , stringify!( #name) , "]" ) ]
273266 pub struct #packed_name {
274- #fields
267+ #packed_fields
275268 }
276269
277270 impl #packed_name {
278271 /// Loads data for the specified filters.
279- pub async fn load( db: & impl :: std:: ops:: Deref <Target = deadpool_postgres:: ClientWrapper >, mut filter: Filter ) -> anyhow:: Result <Vec <#packed_name>> {
272+ pub async fn load(
273+ db: & impl :: std:: ops:: Deref <Target = deadpool_postgres:: ClientWrapper >,
274+ mut filter: Filter ,
275+ fields: impl TryInto <Fields >
276+ ) -> anyhow:: Result <Vec <#packed_name>> {
277+ let mut fields = fields. try_into( ) . map_err( |_| anyhow:: Error :: msg( "unknown field" ) ) ?;
278+ fields. merge_filter( & filter) ;
280279 #load_checks
281- let sql = #load_sql ;
280+ let sql = "SELECT " . to_string ( ) + fields . select ( ) . as_str ( ) + " FROM " + #table_name + " WHERE " + #load_where ;
282281 let mut results = Vec :: new( ) ;
283282 for row in db. query( & db. prepare_cached( & sql) . await ?, & [ #load_params] ) . await ? {
284- results. push( #packed_name { filter : Some ( filter. clone( ) ) , #load_fields } ) ;
283+ results. push( fields . load_from_row ( row , Some ( filter. clone( ) ) ) ? ) ;
285284 }
286285 Ok ( results)
287286 }
288287
289288 /// Deletes data for the specified filters, returning it to the caller.
290289 ///
291290 /// Note that all rows are returned from [decompress][Self::decompress] even if post-decompress filters would normally apply.
292- pub async fn delete( db: & impl :: std:: ops:: Deref <Target = deadpool_postgres:: ClientWrapper >, mut filter: Filter ) -> anyhow:: Result <Vec <#packed_name>> {
291+ pub async fn delete(
292+ db: & impl :: std:: ops:: Deref <Target = deadpool_postgres:: ClientWrapper >,
293+ mut filter: Filter ,
294+ fields: impl TryInto <Fields >
295+ ) -> anyhow:: Result <Vec <#packed_name>> {
296+ let mut fields = fields. try_into( ) . map_err( |_| anyhow:: Error :: msg( "unknown field" ) ) ?;
297+ fields. merge_filter( & filter) ;
293298 #load_checks
294- let sql = #delete_sql ;
299+ let sql = "DELETE FROM " . to_string ( ) + #table_name + " WHERE " + #load_where + " RETURNING " + fields . select ( ) . as_str ( ) ;
295300 let mut results = Vec :: new( ) ;
296301 for row in db. query( & db. prepare_cached( & sql) . await ?, & [ #load_params] ) . await ? {
297- results. push( #packed_name { filter : None , #load_fields } ) ;
302+ results. push( fields . load_from_row ( row , None ) ? ) ;
298303 }
299304 Ok ( results)
300305 }
@@ -370,6 +375,7 @@ pub fn store(args: TokenStream, item: TokenStream) -> TokenStream {
370375 }
371376
372377 #filter
378+ #fields
373379 #deserialize_time_range
374380 }
375381 . into ( )
@@ -514,6 +520,162 @@ fn filter(model: ItemStruct, args: Arguments, using_chrono: bool, timestamp_ty:
514520 }
515521}
516522
523+ fn fields ( model : ItemStruct , args : Arguments , packed_name : Ident ) -> proc_macro2:: TokenStream {
524+ let name = model. ident . clone ( ) ;
525+ let Arguments { timestamp, group_by, .. } = args;
526+ let mut fields = Vec :: new ( ) ;
527+ let mut required = Vec :: new ( ) ;
528+ let mut merge_filter = Vec :: new ( ) ;
529+ let mut select = Vec :: new ( ) ;
530+ let mut load = Vec :: new ( ) ;
531+ let mut default = Vec :: new ( ) ;
532+ let mut from = Vec :: new ( ) ;
533+ for field in model. fields . iter ( ) {
534+ let ident = field. ident . clone ( ) . unwrap ( ) ;
535+ let name = format ! ( "{ident}" ) ;
536+ let is_timestamp = timestamp. as_ref ( ) . map ( |t| * t == ident) . unwrap_or ( false ) ;
537+ fields. push ( quote ! { #ident: bool , } ) ;
538+ if group_by. iter ( ) . any ( |i| * i == ident) || is_timestamp {
539+ required. push ( quote ! { #ident: true , } ) ;
540+ } else {
541+ required. push ( quote ! { #ident: false , } ) ;
542+ merge_filter. push ( quote ! {
543+ ( !filter. #ident. is_empty( ) ) . then( || self . #ident = true ) ;
544+ } ) ;
545+ }
546+ select. push ( quote ! { self . #ident. then( || fields. push( #name) ) ; } ) ;
547+ load. push ( quote ! { #ident: if self . #ident {
548+ let v = row. get( index) ;
549+ index += 1 ;
550+ v
551+ } else {
552+ Default :: default ( )
553+ } ,
554+ } ) ;
555+ default. push ( quote ! { #ident: true , } ) ;
556+ from. push ( quote ! { #name => fields. #ident = true , } ) ;
557+ }
558+ let fields = tokens ( fields) ;
559+ let required = tokens ( required) ;
560+ let merge_filter = tokens ( merge_filter) ;
561+ let select = tokens ( select) ;
562+ let load = tokens ( load) ;
563+ let default = tokens ( default) ;
564+ let from = tokens ( from) ;
565+ quote ! {
566+ #[ derive( Clone , Copy , Debug , PartialEq ) ]
567+ #[ doc=concat!( " Generated by pco_store to choose which fields to decompress when loading [" , stringify!( #name) , "]" ) ]
568+ pub struct Fields {
569+ #fields
570+ }
571+
572+ impl Fields {
573+ pub fn new( fields: & [ & str ] ) -> anyhow:: Result <Self > {
574+ fields. try_into( ) . map_err( |e| anyhow:: Error :: msg( e) )
575+ }
576+
577+ pub fn required( ) -> Self {
578+ Self { #required }
579+ }
580+
581+ fn merge_filter( & mut self , filter: & Filter ) {
582+ #merge_filter
583+ }
584+
585+ fn select( & self ) -> String {
586+ let mut fields = Vec :: new( ) ;
587+ #select
588+ fields. join( ", " )
589+ }
590+
591+ fn load_from_row( & self , row: tokio_postgres:: Row , filter: Option <Filter >) -> anyhow:: Result <#packed_name> {
592+ let mut index = 0 ;
593+ Ok ( #packed_name {
594+ filter,
595+ #load
596+ } )
597+ }
598+ }
599+
600+ impl Default for Fields {
601+ fn default ( ) -> Self {
602+ Self { #default }
603+ }
604+ }
605+
606+ impl TryFrom <& [ & str ] > for Fields {
607+ type Error = & ' static str ;
608+ fn try_from( input: & [ & str ] ) -> Result <Self , Self :: Error > {
609+ let mut fields = Self :: required( ) ;
610+ for s in input {
611+ match * s {
612+ #from
613+ _ => return Err ( "unknown field" ) ,
614+ }
615+ }
616+ Ok ( fields)
617+ }
618+ }
619+
620+ impl <const N : usize > TryFrom <& [ & str ; N ] > for Fields {
621+ type Error = & ' static str ;
622+ fn try_from( input: & [ & str ; N ] ) -> Result <Self , Self :: Error > {
623+ Self :: try_from( & input[ ..] )
624+ }
625+ }
626+
627+ impl TryFrom <Vec <String >> for Fields {
628+ type Error = & ' static str ;
629+ fn try_from( input: Vec <String >) -> Result <Self , Self :: Error > {
630+ let input: Vec <_> = input. iter( ) . map( |s| s. as_str( ) ) . collect( ) ;
631+ Self :: try_from( input. as_slice( ) )
632+ }
633+ }
634+
635+ impl From <( ) > for Fields {
636+ fn from( _: ( ) ) -> Self {
637+ Self :: default ( )
638+ }
639+ }
640+
641+ impl <' de> serde:: Deserialize <' de> for Fields {
642+ fn deserialize<D >( deserializer: D ) -> Result <Self , D :: Error >
643+ where
644+ D : serde:: Deserializer <' de>,
645+ {
646+ deserializer. deserialize_any( FieldsVisitor )
647+ }
648+ }
649+
650+ struct FieldsVisitor ;
651+ impl <' de> serde:: de:: Visitor <' de> for FieldsVisitor {
652+ type Value = Fields ;
653+
654+ fn expecting( & self , formatter: & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
655+ formatter. write_str( "an array of strings matching the struct fields" )
656+ }
657+
658+ fn visit_seq<A >( self , mut seq: A ) -> Result <Self :: Value , A :: Error >
659+ where
660+ A : serde:: de:: SeqAccess <' de>,
661+ {
662+ let mut fields = Vec :: new( ) ;
663+ while let Some ( field) = seq. next_element( ) ? {
664+ fields. push( field) ;
665+ }
666+ Fields :: try_from( fields) . map_err( serde:: de:: Error :: custom)
667+ }
668+
669+ fn visit_unit<E >( self ) -> Result <Self :: Value , E >
670+ where
671+ E : serde:: de:: Error ,
672+ {
673+ Ok ( Fields :: default ( ) )
674+ }
675+ }
676+ }
677+ }
678+
517679fn deserialize_time_range ( timestamp_ty : & Type ) -> proc_macro2:: TokenStream {
518680 quote ! {
519681 /// Deserializes many different time range formats:
@@ -527,8 +689,6 @@ fn deserialize_time_range(timestamp_ty: &Type) -> proc_macro2::TokenStream {
527689 Ok ( TimeRange :: deserialize( deserializer) ?. 0 )
528690 }
529691
530- use serde:: Deserialize as _;
531-
532692 #[ derive( Debug , PartialEq ) ]
533693 struct TimeRange ( Option <std:: ops:: RangeInclusive <#timestamp_ty>>) ;
534694 impl <' de> serde:: Deserialize <' de> for TimeRange {
0 commit comments