@@ -25,6 +25,7 @@ use prost::Message;
2525use snafu:: { ensure, ResultExt } ;
2626
2727use crate :: {
28+ compression:: { compress_stream, WriterCompression , DEFAULT_COMPRESSION_BLOCK_SIZE } ,
2829 error:: { IoSnafu , Result , UnexpectedSnafu } ,
2930 memory:: EstimateMemory ,
3031 proto,
@@ -38,6 +39,9 @@ pub struct ArrowWriterBuilder<W> {
3839 schema : SchemaRef ,
3940 batch_size : usize ,
4041 stripe_byte_size : usize ,
42+ compression : WriterCompression ,
43+ compression_block_size : usize ,
44+ row_index_stride : Option < usize > ,
4145}
4246
4347impl < W : Write > ArrowWriterBuilder < W > {
@@ -50,6 +54,9 @@ impl<W: Write> ArrowWriterBuilder<W> {
5054 batch_size : 1024 ,
5155 // 64 MiB
5256 stripe_byte_size : 64 * 1024 * 1024 ,
57+ compression : WriterCompression :: None ,
58+ compression_block_size : DEFAULT_COMPRESSION_BLOCK_SIZE as usize ,
59+ row_index_stride : None ,
5360 }
5461 }
5562
@@ -66,17 +73,66 @@ impl<W: Write> ArrowWriterBuilder<W> {
6673 self
6774 }
6875
76+ /// Compress ORC streams and metadata with the provided writer codec.
77+ pub fn with_compression ( mut self , compression : WriterCompression ) -> Self {
78+ self . compression = compression;
79+ self
80+ }
81+
82+ /// Enable ORC `ZLIB` compression.
83+ ///
84+ /// ORC does not have a separate protobuf value for `GZIP`; common "gzip"
85+ /// writer options map to ORC `ZLIB`, so this is a naming convenience.
86+ pub fn with_gzip_compression ( mut self ) -> Self {
87+ self . compression = WriterCompression :: Zlib ;
88+ self
89+ }
90+
91+ /// The maximum uncompressed size of each ORC compression block.
92+ pub fn with_compression_block_size ( mut self , compression_block_size : usize ) -> Self {
93+ self . compression_block_size = compression_block_size;
94+ self
95+ }
96+
97+ /// Enable writer row indexes with `rows_per_group` rows per row group.
98+ pub fn with_row_index_stride ( mut self , rows_per_group : usize ) -> Self {
99+ self . row_index_stride = Some ( rows_per_group) ;
100+ self
101+ }
102+
69103 /// Construct an [`ArrowWriter`] ready to encode [`RecordBatch`]es into
70104 /// an ORC file.
71105 pub fn try_build ( mut self ) -> Result < ArrowWriter < W > > {
106+ ensure ! (
107+ self . compression_block_size > 0 ,
108+ UnexpectedSnafu {
109+ msg: "compression block size must be greater than zero"
110+ }
111+ ) ;
112+ ensure ! (
113+ self . row_index_stride. map_or( true , |stride| stride > 0 ) ,
114+ UnexpectedSnafu {
115+ msg: "row index stride must be greater than zero"
116+ }
117+ ) ;
118+
72119 // Required magic "ORC" bytes at start of file
73120 self . writer . write_all ( b"ORC" ) . context ( IoSnafu ) ?;
74- let writer = StripeWriter :: new ( self . writer , & self . schema ) ;
121+ let writer = StripeWriter :: new (
122+ self . writer ,
123+ & self . schema ,
124+ self . compression ,
125+ self . compression_block_size ,
126+ self . row_index_stride ,
127+ ) ;
75128 Ok ( ArrowWriter {
76129 writer,
77130 schema : self . schema ,
78131 batch_size : self . batch_size ,
79132 stripe_byte_size : self . stripe_byte_size ,
133+ compression : self . compression ,
134+ compression_block_size : self . compression_block_size ,
135+ row_index_stride : self . row_index_stride ,
80136 written_stripes : vec ! [ ] ,
81137 // Accounting for the 3 magic bytes above
82138 total_bytes_written : 3 ,
@@ -92,6 +148,9 @@ pub struct ArrowWriter<W> {
92148 schema : SchemaRef ,
93149 batch_size : usize ,
94150 stripe_byte_size : usize ,
151+ compression : WriterCompression ,
152+ compression_block_size : usize ,
153+ row_index_stride : Option < usize > ,
95154 written_stripes : Vec < StripeInformation > ,
96155 /// Used to keep track of progress in file so far (instead of needing Seek on the writer)
97156 total_bytes_written : u64 ,
@@ -138,9 +197,18 @@ impl<W: Write> ArrowWriter<W> {
138197 if self . writer . row_count > 0 {
139198 self . flush_stripe ( ) ?;
140199 }
141- let footer = serialize_footer ( & self . written_stripes , & self . schema ) ;
200+ let footer = serialize_footer ( & self . written_stripes , & self . schema , self . row_index_stride ) ;
142201 let footer = footer. encode_to_vec ( ) ;
143- let postscript = serialize_postscript ( footer. len ( ) as u64 ) ;
202+ let footer = compress_stream (
203+ bytes:: Bytes :: from ( footer) ,
204+ self . compression ,
205+ self . compression_block_size ,
206+ ) ?;
207+ let postscript = serialize_postscript (
208+ footer. len ( ) as u64 ,
209+ self . compression ,
210+ self . compression_block_size ,
211+ ) ;
144212 let postscript = postscript. encode_to_vec ( ) ;
145213 let postscript_len = postscript. len ( ) as u8 ;
146214
@@ -221,7 +289,11 @@ fn serialize_schema(schema: &SchemaRef) -> Vec<proto::Type> {
221289 types
222290}
223291
224- fn serialize_footer ( stripes : & [ StripeInformation ] , schema : & SchemaRef ) -> proto:: Footer {
292+ fn serialize_footer (
293+ stripes : & [ StripeInformation ] ,
294+ schema : & SchemaRef ,
295+ row_index_stride : Option < usize > ,
296+ ) -> proto:: Footer {
225297 let body_length = stripes
226298 . iter ( )
227299 . map ( |s| s. index_length + s. data_length + s. footer_length )
@@ -237,19 +309,23 @@ fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto:
237309 metadata : vec ! [ ] ,
238310 number_of_rows : Some ( number_of_rows) ,
239311 statistics : vec ! [ ] ,
240- row_index_stride : None ,
312+ row_index_stride : row_index_stride . map ( |stride| stride as u32 ) ,
241313 writer : Some ( u32:: MAX ) ,
242314 encryption : None ,
243315 calendar : None ,
244316 software_version : None ,
245317 }
246318}
247319
248- fn serialize_postscript ( footer_length : u64 ) -> proto:: PostScript {
320+ fn serialize_postscript (
321+ footer_length : u64 ,
322+ compression : WriterCompression ,
323+ compression_block_size : usize ,
324+ ) -> proto:: PostScript {
249325 proto:: PostScript {
250326 footer_length : Some ( footer_length) ,
251- compression : Some ( proto :: CompressionKind :: None . into ( ) ) , // TODO: support compression
252- compression_block_size : None ,
327+ compression : Some ( compression . to_proto ( ) . into ( ) ) ,
328+ compression_block_size : ( !compression . is_none ( ) ) . then_some ( compression_block_size as u64 ) ,
253329 version : vec ! [ 0 , 12 ] ,
254330 metadata_length : Some ( 0 ) , // TODO: statistics
255331 writer_version : Some ( u32:: MAX ) , // TODO: check which version to use
@@ -274,7 +350,7 @@ mod tests {
274350 } ;
275351 use bytes:: Bytes ;
276352
277- use crate :: { stripe:: Stripe , ArrowReaderBuilder } ;
353+ use crate :: { statistics :: TypeStatistics , stripe:: Stripe , ArrowReaderBuilder } ;
278354
279355 use super :: * ;
280356
@@ -293,6 +369,25 @@ mod tests {
293369 reader. collect :: < Result < Vec < _ > , _ > > ( ) . unwrap ( )
294370 }
295371
372+ fn write_to_bytes (
373+ batch : & RecordBatch ,
374+ gzip_compression : bool ,
375+ row_index_stride : Option < usize > ,
376+ ) -> Bytes {
377+ let mut f = vec ! [ ] ;
378+ let mut builder = ArrowWriterBuilder :: new ( & mut f, batch. schema ( ) ) ;
379+ if gzip_compression {
380+ builder = builder. with_gzip_compression ( ) ;
381+ }
382+ if let Some ( row_index_stride) = row_index_stride {
383+ builder = builder. with_row_index_stride ( row_index_stride) ;
384+ }
385+ let mut writer = builder. try_build ( ) . unwrap ( ) ;
386+ writer. write ( batch) . unwrap ( ) ;
387+ writer. close ( ) . unwrap ( ) ;
388+ Bytes :: from ( f)
389+ }
390+
296391 #[ test]
297392 fn test_roundtrip_write ( ) {
298393 let f32_array = Arc :: new ( Float32Array :: from ( vec ! [ 0.0 , 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ) ) ;
@@ -354,6 +449,55 @@ mod tests {
354449 assert_eq ! ( batch, rows[ 0 ] ) ;
355450 }
356451
452+ #[ test]
453+ fn test_roundtrip_write_gzip_compression ( ) {
454+ let array = Arc :: new ( Int64Array :: from ( ( 0 ..1024 ) . collect :: < Vec < _ > > ( ) ) ) ;
455+ let schema = Schema :: new ( vec ! [ Field :: new( "int64" , ArrowDataType :: Int64 , false ) ] ) ;
456+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ array] ) . unwrap ( ) ;
457+
458+ let f = write_to_bytes ( & batch, true , None ) ;
459+ let builder = ArrowReaderBuilder :: try_new ( f) . unwrap ( ) ;
460+ assert ! ( builder. file_metadata( ) . compression( ) . is_some( ) ) ;
461+
462+ let rows = builder. build ( ) . collect :: < Result < Vec < _ > , _ > > ( ) . unwrap ( ) ;
463+ assert_eq ! ( batch, rows[ 0 ] ) ;
464+ }
465+
466+ #[ test]
467+ fn test_write_row_indexes ( ) {
468+ let array = Arc :: new ( Int64Array :: from ( ( 0 ..12 ) . collect :: < Vec < _ > > ( ) ) ) ;
469+ let schema = Schema :: new ( vec ! [ Field :: new( "int64" , ArrowDataType :: Int64 , false ) ] ) ;
470+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ array] ) . unwrap ( ) ;
471+
472+ let mut f = write_to_bytes ( & batch, false , Some ( 5 ) ) ;
473+ let builder = ArrowReaderBuilder :: try_new ( f. clone ( ) ) . unwrap ( ) ;
474+ assert_eq ! ( builder. file_metadata( ) . row_index_stride( ) , Some ( 5 ) ) ;
475+
476+ let stripe = Stripe :: new (
477+ & mut f,
478+ builder. file_metadata ( ) ,
479+ builder. file_metadata ( ) . root_data_type ( ) ,
480+ & builder. file_metadata ( ) . stripe_metadatas ( ) [ 0 ] ,
481+ )
482+ . unwrap ( ) ;
483+ let row_index = stripe. read_row_indexes ( builder. file_metadata ( ) ) . unwrap ( ) ;
484+ let column_index = row_index. column ( 1 ) . unwrap ( ) ;
485+
486+ assert_eq ! ( column_index. num_row_groups( ) , 3 ) ;
487+ assert_eq ! ( row_index. total_rows( ) , 12 ) ;
488+ assert_eq ! ( row_index. rows_per_group( ) , 5 ) ;
489+
490+ let stats = column_index. row_group_stats ( 0 ) . unwrap ( ) ;
491+ assert_eq ! ( stats. number_of_values( ) , 5 ) ;
492+ assert ! ( !stats. has_null( ) ) ;
493+ match stats. type_statistics ( ) . unwrap ( ) {
494+ TypeStatistics :: Integer { min, max, sum } => {
495+ assert_eq ! ( ( * min, * max, * sum) , ( 0 , 4 , Some ( 10 ) ) ) ;
496+ }
497+ other => panic ! ( "expected integer stats, got {other:?}" ) ,
498+ }
499+ }
500+
357501 #[ test]
358502 fn test_roundtrip_write_large_type ( ) {
359503 let large_utf8_array = Arc :: new ( LargeStringArray :: from ( vec ! [
0 commit comments