From 6c94b3d1e8d680b9e909b6ee59fcd951362bcb2a Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Wed, 1 Jan 2025 16:19:20 +1100 Subject: [PATCH 1/6] Refactor retrieval of int decoders Getting unsigned/signed integer decoder shouldn't need knowledge of a column; only cares about if you need V1 or V2, so refactor to accomodate this. Also simplifies some methods by removing the check for invalid column encoding as this shouldn't be checked at decoder retrieval time anyway. --- src/array_decoder/decimal.rs | 10 ++--- src/array_decoder/list.rs | 4 +- src/array_decoder/map.rs | 4 +- src/array_decoder/mod.rs | 12 ++--- src/array_decoder/string.rs | 8 ++-- src/array_decoder/timestamp.rs | 80 ++++++++++++++++++---------------- src/column.rs | 12 ++++- src/encoding/integer/mod.rs | 51 +++++++++------------- src/error.rs | 9 ---- 9 files changed, 93 insertions(+), 97 deletions(-) diff --git a/src/array_decoder/decimal.rs b/src/array_decoder/decimal.rs index 40e766e..9ecf274 100644 --- a/src/array_decoder/decimal.rs +++ b/src/array_decoder/decimal.rs @@ -24,7 +24,7 @@ use arrow::datatypes::Decimal128Type; use snafu::ResultExt; use crate::encoding::decimal::UnboundedVarintStreamDecoder; -use crate::encoding::integer::get_rle_reader; +use crate::encoding::integer::get_signed_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::ArrowSnafu; use crate::proto::stream::Kind; @@ -38,13 +38,13 @@ pub fn new_decimal_decoder( stripe: &Stripe, precision: u32, fixed_scale: u32, -) -> Result> { +) -> Box { let varint_iter = stripe.stream_map().get(column, Kind::Data); let varint_iter = Box::new(UnboundedVarintStreamDecoder::new(varint_iter)); // Scale is specified on a per varint basis (in addition to being encoded in the type) let scale_iter = stripe.stream_map().get(column, Kind::Secondary); - let scale_iter = get_rle_reader::(column, scale_iter)?; + let scale_iter = get_signed_int_decoder::(scale_iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); @@ -55,12 +55,12 @@ pub fn new_decimal_decoder( }; let iter = Box::new(iter); - Ok(Box::new(DecimalArrayDecoder::new( + Box::new(DecimalArrayDecoder::new( precision as u8, fixed_scale as i8, iter, present, - ))) + )) } /// Wrapper around PrimitiveArrayDecoder to allow specifying the precision and scale diff --git a/src/array_decoder/list.rs b/src/array_decoder/list.rs index 7a5ccfb..1301ee4 100644 --- a/src/array_decoder/list.rs +++ b/src/array_decoder/list.rs @@ -24,7 +24,7 @@ use snafu::ResultExt; use crate::array_decoder::derive_present_vec; use crate::column::Column; -use crate::encoding::integer::get_unsigned_rle_reader; +use crate::encoding::integer::get_unsigned_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::proto::stream::Kind; @@ -48,7 +48,7 @@ impl ListArrayDecoder { let inner = array_decoder_factory(child, field.clone(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); - let lengths = get_unsigned_rle_reader(column, reader); + let lengths = get_unsigned_int_decoder(reader, column.rle_version()); Ok(Self { inner, diff --git a/src/array_decoder/map.rs b/src/array_decoder/map.rs index 7c01988..fb00f12 100644 --- a/src/array_decoder/map.rs +++ b/src/array_decoder/map.rs @@ -24,7 +24,7 @@ use snafu::ResultExt; use crate::array_decoder::derive_present_vec; use crate::column::Column; -use crate::encoding::integer::get_unsigned_rle_reader; +use crate::encoding::integer::get_unsigned_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::{ArrowSnafu, Result}; use crate::proto::stream::Kind; @@ -56,7 +56,7 @@ impl MapArrayDecoder { let values = array_decoder_factory(values_column, values_field.clone(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); - let lengths = get_unsigned_rle_reader(column, reader); + let lengths = get_unsigned_int_decoder(reader, column.rle_version()); let fields = Fields::from(vec![keys_field, values_field]); diff --git a/src/array_decoder/mod.rs b/src/array_decoder/mod.rs index 695beed..32ae92a 100644 --- a/src/array_decoder/mod.rs +++ b/src/array_decoder/mod.rs @@ -32,7 +32,7 @@ use crate::column::Column; use crate::encoding::boolean::BooleanDecoder; use crate::encoding::byte::ByteRleDecoder; use crate::encoding::float::FloatDecoder; -use crate::encoding::integer::get_rle_reader; +use crate::encoding::integer::get_signed_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::{ self, MismatchedSchemaSnafu, Result, UnexpectedSnafu, UnsupportedTypeVariantSnafu, @@ -277,19 +277,19 @@ pub fn array_decoder_factory( } (DataType::Short { .. }, ArrowDataType::Int16) => { let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(Int16ArrayDecoder::new(iter, present)) } (DataType::Int { .. }, ArrowDataType::Int32) => { let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(Int32ArrayDecoder::new(iter, present)) } (DataType::Long { .. }, ArrowDataType::Int64) => { let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(Int64ArrayDecoder::new(iter, present)) } @@ -315,7 +315,7 @@ pub fn array_decoder_factory( }, ArrowDataType::Decimal128(a_precision, a_scale), ) if *precision as u8 == *a_precision && *scale as i8 == *a_scale => { - new_decimal_decoder(column, stripe, *precision, *scale)? + new_decimal_decoder(column, stripe, *precision, *scale) } (DataType::Timestamp { .. }, field_type) => { new_timestamp_decoder(column, field_type.clone(), stripe)? @@ -326,7 +326,7 @@ pub fn array_decoder_factory( (DataType::Date { .. }, ArrowDataType::Date32) => { // TODO: allow Date64 let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(DateArrayDecoder::new(iter, present)) } diff --git a/src/array_decoder/string.rs b/src/array_decoder/string.rs index dda72f0..0518990 100644 --- a/src/array_decoder/string.rs +++ b/src/array_decoder/string.rs @@ -28,7 +28,7 @@ use snafu::ResultExt; use crate::array_decoder::derive_present_vec; use crate::column::Column; use crate::compression::Decompressor; -use crate::encoding::integer::get_unsigned_rle_reader; +use crate::encoding::integer::get_unsigned_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::{ArrowSnafu, IoSnafu, Result}; use crate::proto::column_encoding::Kind as ColumnEncodingKind; @@ -42,7 +42,7 @@ pub fn new_binary_decoder(column: &Column, stripe: &Stripe) -> Result Result { @@ -72,7 +72,7 @@ pub fn new_string_decoder(column: &Column, stripe: &Stripe) -> Result( column: &Column, stripe: &Stripe, seconds_since_unix_epoch: i64, -) -> Result> { +) -> PrimitiveArrayDecoder { let data = stripe.stream_map().get(column, Kind::Data); - let data = get_rle_reader(column, data)?; + let data = get_signed_int_decoder(data, column.rle_version()); let secondary = stripe.stream_map().get(column, Kind::Secondary); - let secondary = get_unsigned_rle_reader(column, secondary); + let secondary = get_unsigned_int_decoder(secondary, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); @@ -68,29 +68,29 @@ fn get_inner_timestamp_decoder( data, secondary, )); - Ok(PrimitiveArrayDecoder::::new(iter, present)) + PrimitiveArrayDecoder::::new(iter, present) } fn get_timestamp_decoder( column: &Column, stripe: &Stripe, seconds_since_unix_epoch: i64, -) -> Result> { - let inner = get_inner_timestamp_decoder::(column, stripe, seconds_since_unix_epoch)?; +) -> Box { + let inner = get_inner_timestamp_decoder::(column, stripe, seconds_since_unix_epoch); match stripe.writer_tz() { - Some(writer_tz) => Ok(Box::new(TimestampOffsetArrayDecoder { inner, writer_tz })), - None => Ok(Box::new(inner)), + Some(writer_tz) => Box::new(TimestampOffsetArrayDecoder { inner, writer_tz }), + None => Box::new(inner), } } fn get_timestamp_instant_decoder( column: &Column, stripe: &Stripe, -) -> Result> { +) -> Box { // TIMESTAMP_INSTANT is encoded as UTC so we don't check writer timezone in stripe let inner = - get_inner_timestamp_decoder::(column, stripe, ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH)?; - Ok(Box::new(TimestampInstantArrayDecoder(inner))) + get_inner_timestamp_decoder::(column, stripe, ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH); + Box::new(TimestampInstantArrayDecoder(inner)) } fn decimal128_decoder( @@ -98,12 +98,12 @@ fn decimal128_decoder( stripe: &Stripe, seconds_since_unix_epoch: i64, writer_tz: Option, -) -> Result { +) -> DecimalArrayDecoder { let data = stripe.stream_map().get(column, Kind::Data); - let data = get_rle_reader(column, data)?; + let data = get_signed_int_decoder(data, column.rle_version()); let secondary = stripe.stream_map().get(column, Kind::Secondary); - let secondary = get_unsigned_rle_reader(column, secondary); + let secondary = get_unsigned_int_decoder(secondary, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); @@ -114,12 +114,12 @@ fn decimal128_decoder( Some(writer_tz) => Box::new(TimestampNanosecondAsDecimalWithTzDecoder(iter, writer_tz)), }; - Ok(DecimalArrayDecoder::new( + DecimalArrayDecoder::new( Decimal128Type::MAX_PRECISION, NANOSECOND_DIGITS, iter, present, - )) + ) } /// Decodes a TIMESTAMP column stripe into batches of Timestamp{Nano,Micro,Milli,}secondArrays @@ -148,28 +148,32 @@ pub fn new_timestamp_decoder( match field_type { ArrowDataType::Timestamp(TimeUnit::Second, None) => { - get_timestamp_decoder::(column, stripe, seconds_since_unix_epoch) + Ok(get_timestamp_decoder::( + column, + stripe, + seconds_since_unix_epoch, + )) } ArrowDataType::Timestamp(TimeUnit::Millisecond, None) => { - get_timestamp_decoder::( + Ok(get_timestamp_decoder::( column, stripe, seconds_since_unix_epoch, - ) + )) } ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => { - get_timestamp_decoder::( + Ok(get_timestamp_decoder::( column, stripe, seconds_since_unix_epoch, - ) + )) } ArrowDataType::Timestamp(TimeUnit::Nanosecond, None) => { - get_timestamp_decoder::( + Ok(get_timestamp_decoder::( column, stripe, seconds_since_unix_epoch, - ) + )) } ArrowDataType::Decimal128(Decimal128Type::MAX_PRECISION, NANOSECOND_DIGITS) => { Ok(Box::new(decimal128_decoder( @@ -177,7 +181,7 @@ pub fn new_timestamp_decoder( stripe, seconds_since_unix_epoch, stripe.writer_tz(), - )?)) + ))) } _ => MismatchedSchemaSnafu { orc_type: column.data_type().clone(), @@ -195,18 +199,18 @@ pub fn new_timestamp_instant_decoder( stripe: &Stripe, ) -> Result> { match field_type { - ArrowDataType::Timestamp(TimeUnit::Second, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } - ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } - ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } - ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } + ArrowDataType::Timestamp(TimeUnit::Second, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), + ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), + ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), + ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), ArrowDataType::Timestamp(_, Some(_)) => UnsupportedTypeVariantSnafu { msg: "Non-UTC Arrow timestamps", } @@ -217,7 +221,7 @@ pub fn new_timestamp_instant_decoder( stripe, ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH, None, - )?)) + ))) } _ => MismatchedSchemaSnafu { orc_type: column.data_type().clone(), diff --git a/src/column.rs b/src/column.rs index aaacb31..bfb3f26 100644 --- a/src/column.rs +++ b/src/column.rs @@ -20,8 +20,9 @@ use std::sync::Arc; use bytes::Bytes; use snafu::ResultExt; +use crate::encoding::integer::RleVersion; use crate::error::{IoSnafu, Result}; -use crate::proto::{ColumnEncoding, StripeFooter}; +use crate::proto::{column_encoding::Kind as ProtoColumnKind, ColumnEncoding, StripeFooter}; use crate::reader::ChunkReader; use crate::schema::DataType; @@ -53,6 +54,15 @@ impl Column { self.footer.columns[column].clone() } + pub fn rle_version(&self) -> RleVersion { + // TODO: Validity check for this? e.g. ensure INT column isn't dictionary encoded. + // Or maybe check that at init time; to ensure we catch at earliest opportunity? + match self.encoding().kind() { + ProtoColumnKind::Direct | ProtoColumnKind::Dictionary => RleVersion::V1, + ProtoColumnKind::DirectV2 | ProtoColumnKind::DictionaryV2 => RleVersion::V2, + } + } + pub fn data_type(&self) -> &DataType { &self.data_type } diff --git a/src/encoding/integer/mod.rs b/src/encoding/integer/mod.rs index f652d4e..9751f3a 100644 --- a/src/encoding/integer/mod.rs +++ b/src/encoding/integer/mod.rs @@ -31,11 +31,7 @@ use util::{ get_closest_aligned_bit_width, signed_msb_decode, signed_zigzag_decode, signed_zigzag_encode, }; -use crate::{ - column::Column, - error::{InvalidColumnEncodingSnafu, IoSnafu, Result}, - proto::column_encoding::Kind as ProtoColumnKind, -}; +use crate::error::{IoSnafu, Result}; use super::PrimitiveValueDecoder; @@ -46,34 +42,29 @@ mod util; // TODO: consider having a separate varint.rs pub use util::read_varint_zigzagged; -pub fn get_unsigned_rle_reader( - column: &Column, - reader: R, -) -> Box + Send> { - match column.encoding().kind() { - ProtoColumnKind::Direct | ProtoColumnKind::Dictionary => { - Box::new(RleV1Decoder::::new(reader)) - } - ProtoColumnKind::DirectV2 | ProtoColumnKind::DictionaryV2 => { - Box::new(RleV2Decoder::::new(reader)) - } +#[derive(Debug, Clone, Copy)] +pub enum RleVersion { + V1, + V2, +} + +pub fn get_signed_int_decoder( + reader: impl Read + Send + 'static, + rle_version: RleVersion, +) -> Box + Send> { + match rle_version { + RleVersion::V1 => Box::new(RleV1Decoder::::new(reader)), + RleVersion::V2 => Box::new(RleV2Decoder::::new(reader)), } } -pub fn get_rle_reader( - column: &Column, - reader: R, -) -> Result + Send>> { - match column.encoding().kind() { - ProtoColumnKind::Direct => Ok(Box::new(RleV1Decoder::::new(reader))), - ProtoColumnKind::DirectV2 => { - Ok(Box::new(RleV2Decoder::::new(reader))) - } - k => InvalidColumnEncodingSnafu { - name: column.name(), - encoding: k, - } - .fail(), +pub fn get_unsigned_int_decoder( + reader: impl Read + Send + 'static, + rle_version: RleVersion, +) -> Box + Send> { + match rle_version { + RleVersion::V1 => Box::new(RleV1Decoder::::new(reader)), + RleVersion::V2 => Box::new(RleV2Decoder::::new(reader)), } } diff --git a/src/error.rs b/src/error.rs index 02713e2..9cf8caf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,7 +23,6 @@ use arrow::error::ArrowError; use snafu::prelude::*; use snafu::Location; -use crate::proto; use crate::schema::DataType; // TODO: consolidate error types? better to have a smaller set? @@ -103,14 +102,6 @@ pub enum OrcError { arrow_type: ArrowDataType, }, - #[snafu(display("Invalid encoding for column '{}': {:?}", name, encoding))] - InvalidColumnEncoding { - #[snafu(implicit)] - location: Location, - name: String, - encoding: proto::column_encoding::Kind, - }, - #[snafu(display("Failed to convert to record batch: {}", source))] ConvertRecordBatch { #[snafu(implicit)] From 6df9d3196a6c31d49f776c6ebe22094459ee69d8 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Wed, 1 Jan 2025 17:53:05 +1100 Subject: [PATCH 2/6] Remove read_stream methods from Column These are just indirections and serve no purpose as an abstraction, as they don't actually rely on fields from Column. --- src/column.rs | 18 ------------------ src/stripe.rs | 7 +++++-- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/src/column.rs b/src/column.rs index bfb3f26..b22ee22 100644 --- a/src/column.rs +++ b/src/column.rs @@ -17,13 +17,8 @@ use std::sync::Arc; -use bytes::Bytes; -use snafu::ResultExt; - use crate::encoding::integer::RleVersion; -use crate::error::{IoSnafu, Result}; use crate::proto::{column_encoding::Kind as ProtoColumnKind, ColumnEncoding, StripeFooter}; -use crate::reader::ChunkReader; use crate::schema::DataType; #[derive(Clone, Debug)] @@ -135,17 +130,4 @@ impl Column { } } } - - pub fn read_stream(reader: &mut R, start: u64, length: u64) -> Result { - reader.get_bytes(start, length).context(IoSnafu) - } - - #[cfg(feature = "async")] - pub async fn read_stream_async( - reader: &mut R, - start: u64, - length: u64, - ) -> Result { - reader.get_bytes(start, length).await.context(IoSnafu) - } } diff --git a/src/stripe.rs b/src/stripe.rs index fd09559..a12f5a3 100644 --- a/src/stripe.rs +++ b/src/stripe.rs @@ -154,7 +154,7 @@ impl Stripe { let column_id = stream.column(); if column_ids.contains(&column_id) { let kind = stream.kind(); - let data = Column::read_stream(reader, stream_offset, length)?; + let data = reader.get_bytes(stream_offset, length).context(IoSnafu)?; stream_map.insert((column_id, kind), data); } stream_offset += length; @@ -207,7 +207,10 @@ impl Stripe { let column_id = stream.column(); if column_ids.contains(&column_id) { let kind = stream.kind(); - let data = Column::read_stream_async(reader, stream_offset, length).await?; + let data = reader + .get_bytes(stream_offset, length) + .await + .context(IoSnafu)?; stream_map.insert((column_id, kind), data); } From 279efea68bcd252300ce474069a73753e58d439b Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Thu, 2 Jan 2025 14:57:19 +1100 Subject: [PATCH 3/6] Minor refactoring --- src/array_decoder/list.rs | 2 +- src/array_decoder/map.rs | 4 ++-- src/array_decoder/mod.rs | 18 +++++++----------- src/array_decoder/struct_decoder.rs | 4 ++-- src/array_decoder/union.rs | 2 +- 5 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/array_decoder/list.rs b/src/array_decoder/list.rs index 1301ee4..e70e059 100644 --- a/src/array_decoder/list.rs +++ b/src/array_decoder/list.rs @@ -45,7 +45,7 @@ impl ListArrayDecoder { let present = PresentDecoder::from_stripe(stripe, column); let child = &column.children()[0]; - let inner = array_decoder_factory(child, field.clone(), stripe)?; + let inner = array_decoder_factory(child, field.data_type(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); let lengths = get_unsigned_int_decoder(reader, column.rle_version()); diff --git a/src/array_decoder/map.rs b/src/array_decoder/map.rs index fb00f12..175b09a 100644 --- a/src/array_decoder/map.rs +++ b/src/array_decoder/map.rs @@ -50,10 +50,10 @@ impl MapArrayDecoder { let present = PresentDecoder::from_stripe(stripe, column); let keys_column = &column.children()[0]; - let keys = array_decoder_factory(keys_column, keys_field.clone(), stripe)?; + let keys = array_decoder_factory(keys_column, keys_field.data_type(), stripe)?; let values_column = &column.children()[1]; - let values = array_decoder_factory(values_column, values_field.clone(), stripe)?; + let values = array_decoder_factory(values_column, values_field.data_type(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); let lengths = get_unsigned_int_decoder(reader, column.rle_version()); diff --git a/src/array_decoder/mod.rs b/src/array_decoder/mod.rs index 32ae92a..00db5ff 100644 --- a/src/array_decoder/mod.rs +++ b/src/array_decoder/mod.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray} use arrow::buffer::NullBuffer; use arrow::datatypes::ArrowNativeTypeOp; use arrow::datatypes::ArrowPrimitiveType; -use arrow::datatypes::{DataType as ArrowDataType, Field}; +use arrow::datatypes::DataType as ArrowDataType; use arrow::datatypes::{ Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef, }; @@ -258,10 +258,10 @@ impl Iterator for NaiveStripeDecoder { pub fn array_decoder_factory( column: &Column, - field: Arc, + hinted_arrow_type: &ArrowDataType, stripe: &Stripe, ) -> Result> { - let decoder: Box = match (column.data_type(), field.data_type()) { + let decoder: Box = match (column.data_type(), hinted_arrow_type) { // TODO: try make branches more generic, reduce duplication (DataType::Boolean { .. }, ArrowDataType::Boolean) => { let iter = stripe.stream_map().get(column, Kind::Data); @@ -433,17 +433,13 @@ impl NaiveStripeDecoder { } pub fn new(stripe: Stripe, schema_ref: SchemaRef, batch_size: usize) -> Result { - let mut decoders = Vec::with_capacity(stripe.columns().len()); let number_of_rows = stripe.number_of_rows(); - - for (col, field) in stripe + let decoders = stripe .columns() .iter() - .zip(schema_ref.fields.iter().cloned()) - { - let decoder = array_decoder_factory(col, field, &stripe)?; - decoders.push(decoder); - } + .zip(schema_ref.fields.iter()) + .map(|(col, field)| array_decoder_factory(col, field.data_type(), &stripe)) + .collect::>>()?; Ok(Self { stripe, diff --git a/src/array_decoder/struct_decoder.rs b/src/array_decoder/struct_decoder.rs index b09bb3b..e8486b2 100644 --- a/src/array_decoder/struct_decoder.rs +++ b/src/array_decoder/struct_decoder.rs @@ -43,8 +43,8 @@ impl StructArrayDecoder { let decoders = column .children() .iter() - .zip(fields.iter().cloned()) - .map(|(child, field)| array_decoder_factory(child, field, stripe)) + .zip(fields.iter()) + .map(|(child, field)| array_decoder_factory(child, field.data_type(), stripe)) .collect::>>()?; Ok(Self { diff --git a/src/array_decoder/union.rs b/src/array_decoder/union.rs index 39af4ea..a674832 100644 --- a/src/array_decoder/union.rs +++ b/src/array_decoder/union.rs @@ -53,7 +53,7 @@ impl UnionArrayDecoder { .children() .iter() .zip(fields.iter()) - .map(|(child, (_id, field))| array_decoder_factory(child, field.clone(), stripe)) + .map(|(child, (_, field))| array_decoder_factory(child, field.data_type(), stripe)) .collect::>>()?; Ok(Self { From 8c0bd7712a61fe1cd0c2147c42933518d2b361ec Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Thu, 2 Jan 2025 15:21:03 +1100 Subject: [PATCH 4/6] Minor refactoring --- src/encoding/integer/mod.rs | 5 ++--- src/encoding/mod.rs | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/encoding/integer/mod.rs b/src/encoding/integer/mod.rs index 9751f3a..de323ef 100644 --- a/src/encoding/integer/mod.rs +++ b/src/encoding/integer/mod.rs @@ -210,10 +210,9 @@ impl VarintSerde for i128 { } } -// We only implement for i16, i32, i64 and u64. +// We only implement for i16, i32 and i64. // ORC supports only signed Short, Integer and Long types for its integer types, -// and i8 is encoded as bytes. u64 is used for other encodings such as Strings -// (to encode length, etc.). +// and i8 is encoded as bytes. impl NInt for i16 { type Bytes = [u8; 2]; diff --git a/src/encoding/mod.rs b/src/encoding/mod.rs index 871ae0b..7a5dd08 100644 --- a/src/encoding/mod.rs +++ b/src/encoding/mod.rs @@ -33,10 +33,7 @@ mod util; /// Encodes primitive values into an internal buffer, usually with a specialized run length /// encoding for better compression. -pub trait PrimitiveValueEncoder: EstimateMemory -where - V: Copy, -{ +pub trait PrimitiveValueEncoder: EstimateMemory { fn new() -> Self; fn write_one(&mut self, value: V); From 205951505da059a56c222667bc4c3467f0527019 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Thu, 2 Jan 2025 18:56:29 +1100 Subject: [PATCH 5/6] Minor refactoring around stripe/column --- src/column.rs | 8 +++--- src/stripe.rs | 73 ++++++++++++++++++++++++++++++--------------------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/column.rs b/src/column.rs index b22ee22..6757e2f 100644 --- a/src/column.rs +++ b/src/column.rs @@ -29,11 +29,11 @@ pub struct Column { } impl Column { - pub fn new(name: &str, data_type: &DataType, footer: &Arc) -> Self { + pub fn new(name: String, data_type: DataType, footer: Arc) -> Self { Self { - footer: footer.clone(), - data_type: data_type.clone(), - name: name.to_string(), + footer, + data_type, + name, } } diff --git a/src/stripe.rs b/src/stripe.rs index a12f5a3..3384eb1 100644 --- a/src/stripe.rs +++ b/src/stripe.rs @@ -85,19 +85,19 @@ impl TryFrom<(&proto::StripeInformation, &proto::StripeStatistics)> for StripeMe type Error = error::OrcError; fn try_from(value: (&proto::StripeInformation, &proto::StripeStatistics)) -> Result { - let column_statistics = value - .1 + let (info, statistics) = value; + let column_statistics = statistics .col_stats .iter() .map(TryFrom::try_from) .collect::>>()?; Ok(Self { column_statistics, - offset: value.0.offset(), - index_length: value.0.index_length(), - data_length: value.0.data_length(), - footer_length: value.0.footer_length(), - number_of_rows: value.0.number_of_rows(), + offset: info.offset(), + index_length: info.index_length(), + data_length: info.data_length(), + footer_length: info.footer_length(), + number_of_rows: info.number_of_rows(), }) } } @@ -120,8 +120,7 @@ impl TryFrom<&proto::StripeInformation> for StripeMetadata { #[derive(Debug)] pub struct Stripe { columns: Vec, - /// <(ColumnId, Kind), Bytes> - stream_map: Arc, + stream_map: StreamMap, number_of_rows: usize, tz: Option, } @@ -129,21 +128,28 @@ pub struct Stripe { impl Stripe { pub fn new( reader: &mut R, - file_metadata: &Arc, + file_metadata: &FileMetadata, projected_data_type: &RootDataType, info: &StripeMetadata, ) -> Result { - let compression = file_metadata.compression(); - let footer = reader .get_bytes(info.footer_offset(), info.footer_length()) .context(IoSnafu)?; - let footer = Arc::new(deserialize_stripe_footer(footer, compression)?); + let footer = Arc::new(deserialize_stripe_footer( + footer, + file_metadata.compression(), + )?); let columns: Vec = projected_data_type .children() .iter() - .map(|col| Column::new(col.name(), col.data_type(), &footer)) + .map(|col| { + Column::new( + col.name().to_string(), + col.data_type().clone(), + footer.clone(), + ) + }) .collect(); let column_ids = collect_required_column_ids(&columns); @@ -160,7 +166,7 @@ impl Stripe { stream_offset += length; } - let tz: Option = footer + let tz = footer .writer_timezone .as_ref() // TODO: make this return error @@ -168,10 +174,10 @@ impl Stripe { Ok(Self { columns, - stream_map: Arc::new(StreamMap { + stream_map: StreamMap { inner: stream_map, - compression, - }), + compression: file_metadata.compression(), + }, number_of_rows: info.number_of_rows() as usize, tz, }) @@ -181,22 +187,29 @@ impl Stripe { #[cfg(feature = "async")] pub async fn new_async( reader: &mut R, - file_metadata: &Arc, + file_metadata: &FileMetadata, projected_data_type: &RootDataType, info: &StripeMetadata, ) -> Result { - let compression = file_metadata.compression(); - let footer = reader .get_bytes(info.footer_offset(), info.footer_length()) .await .context(IoSnafu)?; - let footer = Arc::new(deserialize_stripe_footer(footer, compression)?); + let footer = Arc::new(deserialize_stripe_footer( + footer, + file_metadata.compression(), + )?); let columns: Vec = projected_data_type .children() .iter() - .map(|col| Column::new(col.name(), col.data_type(), &footer)) + .map(|col| { + Column::new( + col.name().to_string(), + col.data_type().clone(), + footer.clone(), + ) + }) .collect(); let column_ids = collect_required_column_ids(&columns); @@ -217,7 +230,7 @@ impl Stripe { stream_offset += length; } - let tz: Option = footer + let tz = footer .writer_timezone .as_ref() // TODO: make this return error @@ -225,10 +238,10 @@ impl Stripe { Ok(Self { columns, - stream_map: Arc::new(StreamMap { + stream_map: StreamMap { inner: stream_map, - compression, - }), + compression: file_metadata.compression(), + }, number_of_rows: info.number_of_rows() as usize, tz, }) @@ -238,7 +251,6 @@ impl Stripe { self.number_of_rows } - /// Fetch the stream map pub fn stream_map(&self) -> &StreamMap { &self.stream_map } @@ -254,8 +266,9 @@ impl Stripe { #[derive(Debug)] pub struct StreamMap { - pub inner: HashMap<(u32, Kind), Bytes>, - pub compression: Option, + /// <(ColumnId, Kind), Bytes> + inner: HashMap<(u32, Kind), Bytes>, + compression: Option, } impl StreamMap { From 49892857717a9739752a4251f7637fabe7147712 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Thu, 2 Jan 2025 19:44:55 +1100 Subject: [PATCH 6/6] Allow RootDataType to lookup all transitive column indices --- src/schema.rs | 41 +++++++++++++++++++++++++++++++---------- src/stripe.rs | 16 ++-------------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index 8dcbb0e..d1fc241 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt::Display; use std::sync::Arc; @@ -41,6 +41,7 @@ use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, Union #[derive(Debug, Clone)] pub struct RootDataType { children: Vec, + all_children: HashSet, } impl RootDataType { @@ -54,6 +55,12 @@ impl RootDataType { &self.children } + /// If specified column index is one of the projected columns in this root data type, + /// considering transitive children of compound types. + pub fn contains_column_index(&self, index: usize) -> bool { + self.all_children.contains(&index) + } + /// Convert into an Arrow schema. pub fn create_arrow_schema(&self, user_metadata: &HashMap) -> Schema { let fields = self @@ -76,14 +83,22 @@ impl RootDataType { .filter(|col| mask.is_index_projected(col.data_type().column_index())) .map(|col| col.to_owned()) .collect::>(); - Self { children } + let all_children = get_all_children_indices_set(&children); + Self { + children, + all_children, + } } /// Construct from protobuf types. pub(crate) fn from_proto(types: &[proto::Type]) -> Result { ensure!(!types.is_empty(), NoTypesSnafu {}); let children = parse_struct_children_from_proto(types, 0)?; - Ok(Self { children }) + let all_children = get_all_children_indices_set(&children); + Ok(Self { + children, + all_children, + }) } } @@ -97,6 +112,13 @@ impl Display for RootDataType { } } +fn get_all_children_indices_set(columns: &[NamedColumn]) -> HashSet { + let mut set = HashSet::new(); + set.insert(0); + set.extend(columns.iter().flat_map(|c| c.data_type().all_indices())); + set +} + #[derive(Debug, Clone)] pub struct NamedColumn { name: String, @@ -285,18 +307,17 @@ impl DataType { | DataType::Date { .. } => vec![], DataType::Struct { children, .. } => children .iter() - .flat_map(|col| col.data_type().children_indices()) + .flat_map(|col| col.data_type().all_indices()) .collect(), DataType::List { child, .. } => child.all_indices(), DataType::Map { key, value, .. } => { - let mut indices = key.children_indices(); - indices.extend(value.children_indices()); + let mut indices = key.all_indices(); + indices.extend(value.all_indices()); indices } - DataType::Union { variants, .. } => variants - .iter() - .flat_map(|dt| dt.children_indices()) - .collect(), + DataType::Union { variants, .. } => { + variants.iter().flat_map(|dt| dt.all_indices()).collect() + } } } diff --git a/src/stripe.rs b/src/stripe.rs index 3384eb1..3703fab 100644 --- a/src/stripe.rs +++ b/src/stripe.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; use std::{collections::HashMap, io::Read, sync::Arc}; use bytes::Bytes; @@ -151,14 +150,13 @@ impl Stripe { ) }) .collect(); - let column_ids = collect_required_column_ids(&columns); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); for stream in &footer.streams { let length = stream.length(); let column_id = stream.column(); - if column_ids.contains(&column_id) { + if projected_data_type.contains_column_index(column_id as usize) { let kind = stream.kind(); let data = reader.get_bytes(stream_offset, length).context(IoSnafu)?; stream_map.insert((column_id, kind), data); @@ -211,14 +209,13 @@ impl Stripe { ) }) .collect(); - let column_ids = collect_required_column_ids(&columns); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); for stream in &footer.streams { let length = stream.length(); let column_id = stream.column(); - if column_ids.contains(&column_id) { + if projected_data_type.contains_column_index(column_id as usize) { let kind = stream.kind(); let data = reader .get_bytes(stream_offset, length) @@ -301,12 +298,3 @@ fn deserialize_stripe_footer( .context(error::IoSnafu)?; StripeFooter::decode(buffer.as_slice()).context(error::DecodeProtoSnafu) } - -fn collect_required_column_ids(columns: &[Column]) -> HashSet { - let mut set = HashSet::new(); - for column in columns { - set.insert(column.column_id()); - set.extend(collect_required_column_ids(&column.children())); - } - set -}