Skip to content

Commit 206f628

Browse files
author
xiadong
committed
feat: add writer compression and row indexes
1 parent c1e4fe4 commit 206f628

5 files changed

Lines changed: 733 additions & 15 deletions

File tree

src/arrow_writer.rs

Lines changed: 153 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use prost::Message;
2525
use snafu::{ensure, ResultExt};
2626

2727
use crate::{
28+
compression::{compress_stream, WriterCompression, DEFAULT_COMPRESSION_BLOCK_SIZE},
2829
error::{IoSnafu, Result, UnexpectedSnafu},
2930
memory::EstimateMemory,
3031
proto,
@@ -38,6 +39,9 @@ pub struct ArrowWriterBuilder<W> {
3839
schema: SchemaRef,
3940
batch_size: usize,
4041
stripe_byte_size: usize,
42+
compression: WriterCompression,
43+
compression_block_size: usize,
44+
row_index_stride: Option<usize>,
4145
}
4246

4347
impl<W: Write> ArrowWriterBuilder<W> {
@@ -50,6 +54,9 @@ impl<W: Write> ArrowWriterBuilder<W> {
5054
batch_size: 1024,
5155
// 64 MiB
5256
stripe_byte_size: 64 * 1024 * 1024,
57+
compression: WriterCompression::None,
58+
compression_block_size: DEFAULT_COMPRESSION_BLOCK_SIZE as usize,
59+
row_index_stride: None,
5360
}
5461
}
5562

@@ -66,17 +73,66 @@ impl<W: Write> ArrowWriterBuilder<W> {
6673
self
6774
}
6875

76+
/// Compress ORC streams and metadata with the provided writer codec.
77+
pub fn with_compression(mut self, compression: WriterCompression) -> Self {
78+
self.compression = compression;
79+
self
80+
}
81+
82+
/// Enable ORC `ZLIB` compression.
83+
///
84+
/// ORC does not have a separate protobuf value for `GZIP`; common "gzip"
85+
/// writer options map to ORC `ZLIB`, so this is a naming convenience.
86+
pub fn with_gzip_compression(mut self) -> Self {
87+
self.compression = WriterCompression::Zlib;
88+
self
89+
}
90+
91+
/// The maximum uncompressed size of each ORC compression block.
92+
pub fn with_compression_block_size(mut self, compression_block_size: usize) -> Self {
93+
self.compression_block_size = compression_block_size;
94+
self
95+
}
96+
97+
/// Enable writer row indexes with `rows_per_group` rows per row group.
98+
pub fn with_row_index_stride(mut self, rows_per_group: usize) -> Self {
99+
self.row_index_stride = Some(rows_per_group);
100+
self
101+
}
102+
69103
/// Construct an [`ArrowWriter`] ready to encode [`RecordBatch`]es into
70104
/// an ORC file.
71105
pub fn try_build(mut self) -> Result<ArrowWriter<W>> {
106+
ensure!(
107+
self.compression_block_size > 0,
108+
UnexpectedSnafu {
109+
msg: "compression block size must be greater than zero"
110+
}
111+
);
112+
ensure!(
113+
self.row_index_stride.map_or(true, |stride| stride > 0),
114+
UnexpectedSnafu {
115+
msg: "row index stride must be greater than zero"
116+
}
117+
);
118+
72119
// Required magic "ORC" bytes at start of file
73120
self.writer.write_all(b"ORC").context(IoSnafu)?;
74-
let writer = StripeWriter::new(self.writer, &self.schema);
121+
let writer = StripeWriter::new(
122+
self.writer,
123+
&self.schema,
124+
self.compression,
125+
self.compression_block_size,
126+
self.row_index_stride,
127+
);
75128
Ok(ArrowWriter {
76129
writer,
77130
schema: self.schema,
78131
batch_size: self.batch_size,
79132
stripe_byte_size: self.stripe_byte_size,
133+
compression: self.compression,
134+
compression_block_size: self.compression_block_size,
135+
row_index_stride: self.row_index_stride,
80136
written_stripes: vec![],
81137
// Accounting for the 3 magic bytes above
82138
total_bytes_written: 3,
@@ -92,6 +148,9 @@ pub struct ArrowWriter<W> {
92148
schema: SchemaRef,
93149
batch_size: usize,
94150
stripe_byte_size: usize,
151+
compression: WriterCompression,
152+
compression_block_size: usize,
153+
row_index_stride: Option<usize>,
95154
written_stripes: Vec<StripeInformation>,
96155
/// Used to keep track of progress in file so far (instead of needing Seek on the writer)
97156
total_bytes_written: u64,
@@ -138,9 +197,18 @@ impl<W: Write> ArrowWriter<W> {
138197
if self.writer.row_count > 0 {
139198
self.flush_stripe()?;
140199
}
141-
let footer = serialize_footer(&self.written_stripes, &self.schema);
200+
let footer = serialize_footer(&self.written_stripes, &self.schema, self.row_index_stride);
142201
let footer = footer.encode_to_vec();
143-
let postscript = serialize_postscript(footer.len() as u64);
202+
let footer = compress_stream(
203+
bytes::Bytes::from(footer),
204+
self.compression,
205+
self.compression_block_size,
206+
)?;
207+
let postscript = serialize_postscript(
208+
footer.len() as u64,
209+
self.compression,
210+
self.compression_block_size,
211+
);
144212
let postscript = postscript.encode_to_vec();
145213
let postscript_len = postscript.len() as u8;
146214

@@ -221,7 +289,11 @@ fn serialize_schema(schema: &SchemaRef) -> Vec<proto::Type> {
221289
types
222290
}
223291

224-
fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto::Footer {
292+
fn serialize_footer(
293+
stripes: &[StripeInformation],
294+
schema: &SchemaRef,
295+
row_index_stride: Option<usize>,
296+
) -> proto::Footer {
225297
let body_length = stripes
226298
.iter()
227299
.map(|s| s.index_length + s.data_length + s.footer_length)
@@ -237,19 +309,23 @@ fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto:
237309
metadata: vec![],
238310
number_of_rows: Some(number_of_rows),
239311
statistics: vec![],
240-
row_index_stride: None,
312+
row_index_stride: row_index_stride.map(|stride| stride as u32),
241313
writer: Some(u32::MAX),
242314
encryption: None,
243315
calendar: None,
244316
software_version: None,
245317
}
246318
}
247319

248-
fn serialize_postscript(footer_length: u64) -> proto::PostScript {
320+
fn serialize_postscript(
321+
footer_length: u64,
322+
compression: WriterCompression,
323+
compression_block_size: usize,
324+
) -> proto::PostScript {
249325
proto::PostScript {
250326
footer_length: Some(footer_length),
251-
compression: Some(proto::CompressionKind::None.into()), // TODO: support compression
252-
compression_block_size: None,
327+
compression: Some(compression.to_proto().into()),
328+
compression_block_size: (!compression.is_none()).then_some(compression_block_size as u64),
253329
version: vec![0, 12],
254330
metadata_length: Some(0), // TODO: statistics
255331
writer_version: Some(u32::MAX), // TODO: check which version to use
@@ -274,7 +350,7 @@ mod tests {
274350
};
275351
use bytes::Bytes;
276352

277-
use crate::{stripe::Stripe, ArrowReaderBuilder};
353+
use crate::{statistics::TypeStatistics, stripe::Stripe, ArrowReaderBuilder};
278354

279355
use super::*;
280356

@@ -293,6 +369,25 @@ mod tests {
293369
reader.collect::<Result<Vec<_>, _>>().unwrap()
294370
}
295371

372+
fn write_to_bytes(
373+
batch: &RecordBatch,
374+
gzip_compression: bool,
375+
row_index_stride: Option<usize>,
376+
) -> Bytes {
377+
let mut f = vec![];
378+
let mut builder = ArrowWriterBuilder::new(&mut f, batch.schema());
379+
if gzip_compression {
380+
builder = builder.with_gzip_compression();
381+
}
382+
if let Some(row_index_stride) = row_index_stride {
383+
builder = builder.with_row_index_stride(row_index_stride);
384+
}
385+
let mut writer = builder.try_build().unwrap();
386+
writer.write(batch).unwrap();
387+
writer.close().unwrap();
388+
Bytes::from(f)
389+
}
390+
296391
#[test]
297392
fn test_roundtrip_write() {
298393
let f32_array = Arc::new(Float32Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
@@ -354,6 +449,55 @@ mod tests {
354449
assert_eq!(batch, rows[0]);
355450
}
356451

452+
/// Writing with gzip (ORC `ZLIB`) compression must produce a file whose
/// metadata reports a compression codec and whose first batch reads back
/// equal to the input.
#[test]
fn test_roundtrip_write_gzip_compression() {
    let values: Vec<i64> = (0..1024).collect();
    let column = Arc::new(Int64Array::from(values));
    let schema = Arc::new(Schema::new(vec![Field::new(
        "int64",
        ArrowDataType::Int64,
        false,
    )]));
    let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

    let orc_bytes = write_to_bytes(&batch, true, None);
    let reader_builder = ArrowReaderBuilder::try_new(orc_bytes).unwrap();
    // Inspect the metadata before `build()` consumes the builder.
    assert!(reader_builder.file_metadata().compression().is_some());

    let batches = reader_builder
        .build()
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    assert_eq!(batch, batches[0]);
}
465+
466+
/// Writing 12 rows with a row index stride of 5 must record the stride in
/// the file metadata and yield three row groups for the column, the first of
/// which covers the values 0 through 4.
#[test]
fn test_write_row_indexes() {
    let values: Vec<i64> = (0..12).collect();
    let column = Arc::new(Int64Array::from(values));
    let schema = Arc::new(Schema::new(vec![Field::new(
        "int64",
        ArrowDataType::Int64,
        false,
    )]));
    let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

    let mut orc_bytes = write_to_bytes(&batch, false, Some(5));
    let reader_builder = ArrowReaderBuilder::try_new(orc_bytes.clone()).unwrap();
    let metadata = reader_builder.file_metadata();
    assert_eq!(metadata.row_index_stride(), Some(5));

    let stripe = Stripe::new(
        &mut orc_bytes,
        metadata,
        metadata.root_data_type(),
        &metadata.stripe_metadatas()[0],
    )
    .unwrap();
    let row_index = stripe.read_row_indexes(metadata).unwrap();
    let column_index = row_index.column(1).unwrap();

    // 12 rows at 5 rows per group -> groups of 5, 5 and 2.
    assert_eq!(column_index.num_row_groups(), 3);
    assert_eq!(row_index.total_rows(), 12);
    assert_eq!(row_index.rows_per_group(), 5);

    let first_group = column_index.row_group_stats(0).unwrap();
    assert_eq!(first_group.number_of_values(), 5);
    assert!(!first_group.has_null());

    let type_stats = first_group.type_statistics().unwrap();
    if let TypeStatistics::Integer { min, max, sum } = type_stats {
        // First group holds 0..=4, so the sum is 10.
        assert_eq!((*min, *max, *sum), (0, 4, Some(10)));
    } else {
        panic!("expected integer stats, got {type_stats:?}");
    }
}
500+
357501
#[test]
358502
fn test_roundtrip_write_large_type() {
359503
let large_utf8_array = Arc::new(LargeStringArray::from(vec![

0 commit comments

Comments
 (0)