diff --git a/src/array_decoder/decimal.rs b/src/array_decoder/decimal.rs index 40e766e..9ecf274 100644 --- a/src/array_decoder/decimal.rs +++ b/src/array_decoder/decimal.rs @@ -24,7 +24,7 @@ use arrow::datatypes::Decimal128Type; use snafu::ResultExt; use crate::encoding::decimal::UnboundedVarintStreamDecoder; -use crate::encoding::integer::get_rle_reader; +use crate::encoding::integer::get_signed_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::ArrowSnafu; use crate::proto::stream::Kind; @@ -38,13 +38,13 @@ pub fn new_decimal_decoder( stripe: &Stripe, precision: u32, fixed_scale: u32, -) -> Result> { +) -> Box { let varint_iter = stripe.stream_map().get(column, Kind::Data); let varint_iter = Box::new(UnboundedVarintStreamDecoder::new(varint_iter)); // Scale is specified on a per varint basis (in addition to being encoded in the type) let scale_iter = stripe.stream_map().get(column, Kind::Secondary); - let scale_iter = get_rle_reader::(column, scale_iter)?; + let scale_iter = get_signed_int_decoder::(scale_iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); @@ -55,12 +55,12 @@ pub fn new_decimal_decoder( }; let iter = Box::new(iter); - Ok(Box::new(DecimalArrayDecoder::new( + Box::new(DecimalArrayDecoder::new( precision as u8, fixed_scale as i8, iter, present, - ))) + )) } /// Wrapper around PrimitiveArrayDecoder to allow specifying the precision and scale diff --git a/src/array_decoder/list.rs b/src/array_decoder/list.rs index 7a5ccfb..e70e059 100644 --- a/src/array_decoder/list.rs +++ b/src/array_decoder/list.rs @@ -24,7 +24,7 @@ use snafu::ResultExt; use crate::array_decoder::derive_present_vec; use crate::column::Column; -use crate::encoding::integer::get_unsigned_rle_reader; +use crate::encoding::integer::get_unsigned_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::proto::stream::Kind; @@ -45,10 +45,10 @@ impl ListArrayDecoder { let present = PresentDecoder::from_stripe(stripe, column); let child = &column.children()[0]; - let inner = array_decoder_factory(child, field.clone(), stripe)?; + let inner = array_decoder_factory(child, field.data_type(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); - let lengths = get_unsigned_rle_reader(column, reader); + let lengths = get_unsigned_int_decoder(reader, column.rle_version()); Ok(Self { inner, diff --git a/src/array_decoder/map.rs b/src/array_decoder/map.rs index 7c01988..175b09a 100644 --- a/src/array_decoder/map.rs +++ b/src/array_decoder/map.rs @@ -24,7 +24,7 @@ use snafu::ResultExt; use crate::array_decoder::derive_present_vec; use crate::column::Column; -use crate::encoding::integer::get_unsigned_rle_reader; +use crate::encoding::integer::get_unsigned_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::{ArrowSnafu, Result}; use crate::proto::stream::Kind; @@ -50,13 +50,13 @@ impl MapArrayDecoder { let present = PresentDecoder::from_stripe(stripe, column); let keys_column = &column.children()[0]; - let keys = array_decoder_factory(keys_column, keys_field.clone(), stripe)?; + let keys = array_decoder_factory(keys_column, keys_field.data_type(), stripe)?; let values_column = &column.children()[1]; - let values = array_decoder_factory(values_column, values_field.clone(), stripe)?; + let values = array_decoder_factory(values_column, values_field.data_type(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); - let lengths = get_unsigned_rle_reader(column, reader); + let lengths = get_unsigned_int_decoder(reader, column.rle_version()); let fields = Fields::from(vec![keys_field, values_field]); diff --git a/src/array_decoder/mod.rs b/src/array_decoder/mod.rs index 695beed..00db5ff 100644 --- a/src/array_decoder/mod.rs +++ b/src/array_decoder/mod.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray} use arrow::buffer::NullBuffer; use arrow::datatypes::ArrowNativeTypeOp; use arrow::datatypes::ArrowPrimitiveType; -use arrow::datatypes::{DataType as ArrowDataType, Field}; +use arrow::datatypes::DataType as ArrowDataType; use arrow::datatypes::{ Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef, }; @@ -32,7 +32,7 @@ use crate::column::Column; use crate::encoding::boolean::BooleanDecoder; use crate::encoding::byte::ByteRleDecoder; use crate::encoding::float::FloatDecoder; -use crate::encoding::integer::get_rle_reader; +use crate::encoding::integer::get_signed_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::{ self, MismatchedSchemaSnafu, Result, UnexpectedSnafu, UnsupportedTypeVariantSnafu, @@ -258,10 +258,10 @@ impl Iterator for NaiveStripeDecoder { pub fn array_decoder_factory( column: &Column, - field: Arc, + hinted_arrow_type: &ArrowDataType, stripe: &Stripe, ) -> Result> { - let decoder: Box = match (column.data_type(), field.data_type()) { + let decoder: Box = match (column.data_type(), hinted_arrow_type) { // TODO: try make branches more generic, reduce duplication (DataType::Boolean { .. }, ArrowDataType::Boolean) => { let iter = stripe.stream_map().get(column, Kind::Data); @@ -277,19 +277,19 @@ pub fn array_decoder_factory( } (DataType::Short { .. }, ArrowDataType::Int16) => { let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(Int16ArrayDecoder::new(iter, present)) } (DataType::Int { .. }, ArrowDataType::Int32) => { let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(Int32ArrayDecoder::new(iter, present)) } (DataType::Long { .. }, ArrowDataType::Int64) => { let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(Int64ArrayDecoder::new(iter, present)) } @@ -315,7 +315,7 @@ pub fn array_decoder_factory( }, ArrowDataType::Decimal128(a_precision, a_scale), ) if *precision as u8 == *a_precision && *scale as i8 == *a_scale => { - new_decimal_decoder(column, stripe, *precision, *scale)? + new_decimal_decoder(column, stripe, *precision, *scale) } (DataType::Timestamp { .. }, field_type) => { new_timestamp_decoder(column, field_type.clone(), stripe)? @@ -326,7 +326,7 @@ pub fn array_decoder_factory( (DataType::Date { .. }, ArrowDataType::Date32) => { // TODO: allow Date64 let iter = stripe.stream_map().get(column, Kind::Data); - let iter = get_rle_reader(column, iter)?; + let iter = get_signed_int_decoder(iter, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); Box::new(DateArrayDecoder::new(iter, present)) } @@ -433,17 +433,13 @@ impl NaiveStripeDecoder { } pub fn new(stripe: Stripe, schema_ref: SchemaRef, batch_size: usize) -> Result { - let mut decoders = Vec::with_capacity(stripe.columns().len()); let number_of_rows = stripe.number_of_rows(); - - for (col, field) in stripe + let decoders = stripe .columns() .iter() - .zip(schema_ref.fields.iter().cloned()) - { - let decoder = array_decoder_factory(col, field, &stripe)?; - decoders.push(decoder); - } + .zip(schema_ref.fields.iter()) + .map(|(col, field)| array_decoder_factory(col, field.data_type(), &stripe)) + .collect::>>()?; Ok(Self { stripe, diff --git a/src/array_decoder/string.rs b/src/array_decoder/string.rs index dda72f0..0518990 100644 --- a/src/array_decoder/string.rs +++ b/src/array_decoder/string.rs @@ -28,7 +28,7 @@ use snafu::ResultExt; use crate::array_decoder::derive_present_vec; use crate::column::Column; use crate::compression::Decompressor; -use crate::encoding::integer::get_unsigned_rle_reader; +use crate::encoding::integer::get_unsigned_int_decoder; use crate::encoding::PrimitiveValueDecoder; use crate::error::{ArrowSnafu, IoSnafu, Result}; use crate::proto::column_encoding::Kind as ColumnEncodingKind; @@ -42,7 +42,7 @@ pub fn new_binary_decoder(column: &Column, stripe: &Stripe) -> Result Result { @@ -72,7 +72,7 @@ pub fn new_string_decoder(column: &Column, stripe: &Stripe) -> Result>>()?; Ok(Self { diff --git a/src/array_decoder/timestamp.rs b/src/array_decoder/timestamp.rs index 1bad694..ee3bd8d 100644 --- a/src/array_decoder/timestamp.rs +++ b/src/array_decoder/timestamp.rs @@ -21,7 +21,7 @@ use crate::{ array_decoder::ArrowDataType, column::Column, encoding::{ - integer::{get_rle_reader, get_unsigned_rle_reader}, + integer::{get_signed_int_decoder, get_unsigned_int_decoder}, timestamp::{TimestampDecoder, TimestampNanosecondAsDecimalDecoder}, PrimitiveValueDecoder, }, @@ -54,12 +54,12 @@ fn get_inner_timestamp_decoder( column: &Column, stripe: &Stripe, seconds_since_unix_epoch: i64, -) -> Result> { +) -> PrimitiveArrayDecoder { let data = stripe.stream_map().get(column, Kind::Data); - let data = get_rle_reader(column, data)?; + let data = get_signed_int_decoder(data, column.rle_version()); let secondary = stripe.stream_map().get(column, Kind::Secondary); - let secondary = get_unsigned_rle_reader(column, secondary); + let secondary = get_unsigned_int_decoder(secondary, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); @@ -68,29 +68,29 @@ fn get_inner_timestamp_decoder( data, secondary, )); - Ok(PrimitiveArrayDecoder::::new(iter, present)) + PrimitiveArrayDecoder::::new(iter, present) } fn get_timestamp_decoder( column: &Column, stripe: &Stripe, seconds_since_unix_epoch: i64, -) -> Result> { - let inner = get_inner_timestamp_decoder::(column, stripe, seconds_since_unix_epoch)?; +) -> Box { + let inner = get_inner_timestamp_decoder::(column, stripe, seconds_since_unix_epoch); match stripe.writer_tz() { - Some(writer_tz) => Ok(Box::new(TimestampOffsetArrayDecoder { inner, writer_tz })), - None => Ok(Box::new(inner)), + Some(writer_tz) => Box::new(TimestampOffsetArrayDecoder { inner, writer_tz }), + None => Box::new(inner), } } fn get_timestamp_instant_decoder( column: &Column, stripe: &Stripe, -) -> Result> { +) -> Box { // TIMESTAMP_INSTANT is encoded as UTC so we don't check writer timezone in stripe let inner = - get_inner_timestamp_decoder::(column, stripe, ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH)?; - Ok(Box::new(TimestampInstantArrayDecoder(inner))) + get_inner_timestamp_decoder::(column, stripe, ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH); + Box::new(TimestampInstantArrayDecoder(inner)) } fn decimal128_decoder( @@ -98,12 +98,12 @@ fn decimal128_decoder( stripe: &Stripe, seconds_since_unix_epoch: i64, writer_tz: Option, -) -> Result { +) -> DecimalArrayDecoder { let data = stripe.stream_map().get(column, Kind::Data); - let data = get_rle_reader(column, data)?; + let data = get_signed_int_decoder(data, column.rle_version()); let secondary = stripe.stream_map().get(column, Kind::Secondary); - let secondary = get_unsigned_rle_reader(column, secondary); + let secondary = get_unsigned_int_decoder(secondary, column.rle_version()); let present = PresentDecoder::from_stripe(stripe, column); @@ -114,12 +114,12 @@ fn decimal128_decoder( Some(writer_tz) => Box::new(TimestampNanosecondAsDecimalWithTzDecoder(iter, writer_tz)), }; - Ok(DecimalArrayDecoder::new( + DecimalArrayDecoder::new( Decimal128Type::MAX_PRECISION, NANOSECOND_DIGITS, iter, present, - )) + ) } /// Decodes a TIMESTAMP column stripe into batches of Timestamp{Nano,Micro,Milli,}secondArrays @@ -148,28 +148,32 @@ pub fn new_timestamp_decoder( match field_type { ArrowDataType::Timestamp(TimeUnit::Second, None) => { - get_timestamp_decoder::(column, stripe, seconds_since_unix_epoch) + Ok(get_timestamp_decoder::( + column, + stripe, + seconds_since_unix_epoch, + )) } ArrowDataType::Timestamp(TimeUnit::Millisecond, None) => { - get_timestamp_decoder::( + Ok(get_timestamp_decoder::( column, stripe, seconds_since_unix_epoch, - ) + )) } ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => { - get_timestamp_decoder::( + Ok(get_timestamp_decoder::( column, stripe, seconds_since_unix_epoch, - ) + )) } ArrowDataType::Timestamp(TimeUnit::Nanosecond, None) => { - get_timestamp_decoder::( + Ok(get_timestamp_decoder::( column, stripe, seconds_since_unix_epoch, - ) + )) } ArrowDataType::Decimal128(Decimal128Type::MAX_PRECISION, NANOSECOND_DIGITS) => { Ok(Box::new(decimal128_decoder( @@ -177,7 +181,7 @@ pub fn new_timestamp_decoder( stripe, seconds_since_unix_epoch, stripe.writer_tz(), - )?)) + ))) } _ => MismatchedSchemaSnafu { orc_type: column.data_type().clone(), @@ -195,18 +199,18 @@ pub fn new_timestamp_instant_decoder( stripe: &Stripe, ) -> Result> { match field_type { - ArrowDataType::Timestamp(TimeUnit::Second, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } - ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } - ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } - ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) if tz.as_ref() == "UTC" => { - get_timestamp_instant_decoder::(column, stripe) - } + ArrowDataType::Timestamp(TimeUnit::Second, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), + ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), + ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), + ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) if tz.as_ref() == "UTC" => Ok( + get_timestamp_instant_decoder::(column, stripe), + ), ArrowDataType::Timestamp(_, Some(_)) => UnsupportedTypeVariantSnafu { msg: "Non-UTC Arrow timestamps", } @@ -217,7 +221,7 @@ pub fn new_timestamp_instant_decoder( stripe, ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH, None, - )?)) + ))) } _ => MismatchedSchemaSnafu { orc_type: column.data_type().clone(), diff --git a/src/array_decoder/union.rs b/src/array_decoder/union.rs index 39af4ea..a674832 100644 --- a/src/array_decoder/union.rs +++ b/src/array_decoder/union.rs @@ -53,7 +53,7 @@ impl UnionArrayDecoder { .children() .iter() .zip(fields.iter()) - .map(|(child, (_id, field))| array_decoder_factory(child, field.clone(), stripe)) + .map(|(child, (_, field))| array_decoder_factory(child, field.data_type(), stripe)) .collect::>>()?; Ok(Self { diff --git a/src/column.rs b/src/column.rs index aaacb31..6757e2f 100644 --- a/src/column.rs +++ b/src/column.rs @@ -17,12 +17,8 @@ use std::sync::Arc; -use bytes::Bytes; -use snafu::ResultExt; - -use crate::error::{IoSnafu, Result}; -use crate::proto::{ColumnEncoding, StripeFooter}; -use crate::reader::ChunkReader; +use crate::encoding::integer::RleVersion; +use crate::proto::{column_encoding::Kind as ProtoColumnKind, ColumnEncoding, StripeFooter}; use crate::schema::DataType; #[derive(Clone, Debug)] @@ -33,11 +29,11 @@ pub struct Column { } impl Column { - pub fn new(name: &str, data_type: &DataType, footer: &Arc) -> Self { + pub fn new(name: String, data_type: DataType, footer: Arc) -> Self { Self { - footer: footer.clone(), - data_type: data_type.clone(), - name: name.to_string(), + footer, + data_type, + name, } } @@ -53,6 +49,15 @@ impl Column { self.footer.columns[column].clone() } + pub fn rle_version(&self) -> RleVersion { + // TODO: Validity check for this? e.g. ensure INT column isn't dictionary encoded. + // Or maybe check that at init time; to ensure we catch at earliest opportunity? + match self.encoding().kind() { + ProtoColumnKind::Direct | ProtoColumnKind::Dictionary => RleVersion::V1, + ProtoColumnKind::DirectV2 | ProtoColumnKind::DictionaryV2 => RleVersion::V2, + } + } + pub fn data_type(&self) -> &DataType { &self.data_type } @@ -125,17 +130,4 @@ impl Column { } } } - - pub fn read_stream(reader: &mut R, start: u64, length: u64) -> Result { - reader.get_bytes(start, length).context(IoSnafu) - } - - #[cfg(feature = "async")] - pub async fn read_stream_async( - reader: &mut R, - start: u64, - length: u64, - ) -> Result { - reader.get_bytes(start, length).await.context(IoSnafu) - } } diff --git a/src/encoding/integer/mod.rs b/src/encoding/integer/mod.rs index f652d4e..de323ef 100644 --- a/src/encoding/integer/mod.rs +++ b/src/encoding/integer/mod.rs @@ -31,11 +31,7 @@ use util::{ get_closest_aligned_bit_width, signed_msb_decode, signed_zigzag_decode, signed_zigzag_encode, }; -use crate::{ - column::Column, - error::{InvalidColumnEncodingSnafu, IoSnafu, Result}, - proto::column_encoding::Kind as ProtoColumnKind, -}; +use crate::error::{IoSnafu, Result}; use super::PrimitiveValueDecoder; @@ -46,34 +42,29 @@ mod util; // TODO: consider having a separate varint.rs pub use util::read_varint_zigzagged; -pub fn get_unsigned_rle_reader( - column: &Column, - reader: R, -) -> Box + Send> { - match column.encoding().kind() { - ProtoColumnKind::Direct | ProtoColumnKind::Dictionary => { - Box::new(RleV1Decoder::::new(reader)) - } - ProtoColumnKind::DirectV2 | ProtoColumnKind::DictionaryV2 => { - Box::new(RleV2Decoder::::new(reader)) - } +#[derive(Debug, Clone, Copy)] +pub enum RleVersion { + V1, + V2, +} + +pub fn get_signed_int_decoder( + reader: impl Read + Send + 'static, + rle_version: RleVersion, +) -> Box + Send> { + match rle_version { + RleVersion::V1 => Box::new(RleV1Decoder::::new(reader)), + RleVersion::V2 => Box::new(RleV2Decoder::::new(reader)), } } -pub fn get_rle_reader( - column: &Column, - reader: R, -) -> Result + Send>> { - match column.encoding().kind() { - ProtoColumnKind::Direct => Ok(Box::new(RleV1Decoder::::new(reader))), - ProtoColumnKind::DirectV2 => { - Ok(Box::new(RleV2Decoder::::new(reader))) - } - k => InvalidColumnEncodingSnafu { - name: column.name(), - encoding: k, - } - .fail(), +pub fn get_unsigned_int_decoder( + reader: impl Read + Send + 'static, + rle_version: RleVersion, +) -> Box + Send> { + match rle_version { + RleVersion::V1 => Box::new(RleV1Decoder::::new(reader)), + RleVersion::V2 => Box::new(RleV2Decoder::::new(reader)), } } @@ -219,10 +210,9 @@ impl VarintSerde for i128 { } } -// We only implement for i16, i32, i64 and u64. +// We only implement for i16, i32 and i64. // ORC supports only signed Short, Integer and Long types for its integer types, -// and i8 is encoded as bytes. u64 is used for other encodings such as Strings -// (to encode length, etc.). +// and i8 is encoded as bytes. impl NInt for i16 { type Bytes = [u8; 2]; diff --git a/src/encoding/mod.rs b/src/encoding/mod.rs index 871ae0b..7a5dd08 100644 --- a/src/encoding/mod.rs +++ b/src/encoding/mod.rs @@ -33,10 +33,7 @@ mod util; /// Encodes primitive values into an internal buffer, usually with a specialized run length /// encoding for better compression. -pub trait PrimitiveValueEncoder: EstimateMemory -where - V: Copy, -{ +pub trait PrimitiveValueEncoder: EstimateMemory { fn new() -> Self; fn write_one(&mut self, value: V); diff --git a/src/error.rs b/src/error.rs index 02713e2..9cf8caf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,7 +23,6 @@ use arrow::error::ArrowError; use snafu::prelude::*; use snafu::Location; -use crate::proto; use crate::schema::DataType; // TODO: consolidate error types? better to have a smaller set? @@ -103,14 +102,6 @@ pub enum OrcError { arrow_type: ArrowDataType, }, - #[snafu(display("Invalid encoding for column '{}': {:?}", name, encoding))] - InvalidColumnEncoding { - #[snafu(implicit)] - location: Location, - name: String, - encoding: proto::column_encoding::Kind, - }, - #[snafu(display("Failed to convert to record batch: {}", source))] ConvertRecordBatch { #[snafu(implicit)] diff --git a/src/schema.rs b/src/schema.rs index 8dcbb0e..d1fc241 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt::Display; use std::sync::Arc; @@ -41,6 +41,7 @@ use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, Union #[derive(Debug, Clone)] pub struct RootDataType { children: Vec, + all_children: HashSet, } impl RootDataType { @@ -54,6 +55,12 @@ impl RootDataType { &self.children } + /// If specified column index is one of the projected columns in this root data type, + /// considering transitive children of compound types. + pub fn contains_column_index(&self, index: usize) -> bool { + self.all_children.contains(&index) + } + /// Convert into an Arrow schema. pub fn create_arrow_schema(&self, user_metadata: &HashMap) -> Schema { let fields = self @@ -76,14 +83,22 @@ impl RootDataType { .filter(|col| mask.is_index_projected(col.data_type().column_index())) .map(|col| col.to_owned()) .collect::>(); - Self { children } + let all_children = get_all_children_indices_set(&children); + Self { + children, + all_children, + } } /// Construct from protobuf types. pub(crate) fn from_proto(types: &[proto::Type]) -> Result { ensure!(!types.is_empty(), NoTypesSnafu {}); let children = parse_struct_children_from_proto(types, 0)?; - Ok(Self { children }) + let all_children = get_all_children_indices_set(&children); + Ok(Self { + children, + all_children, + }) } } @@ -97,6 +112,13 @@ impl Display for RootDataType { } } +fn get_all_children_indices_set(columns: &[NamedColumn]) -> HashSet { + let mut set = HashSet::new(); + set.insert(0); + set.extend(columns.iter().flat_map(|c| c.data_type().all_indices())); + set +} + #[derive(Debug, Clone)] pub struct NamedColumn { name: String, @@ -285,18 +307,17 @@ impl DataType { | DataType::Date { .. } => vec![], DataType::Struct { children, .. } => children .iter() - .flat_map(|col| col.data_type().children_indices()) + .flat_map(|col| col.data_type().all_indices()) .collect(), DataType::List { child, .. } => child.all_indices(), DataType::Map { key, value, .. } => { - let mut indices = key.children_indices(); - indices.extend(value.children_indices()); + let mut indices = key.all_indices(); + indices.extend(value.all_indices()); indices } - DataType::Union { variants, .. } => variants - .iter() - .flat_map(|dt| dt.children_indices()) - .collect(), + DataType::Union { variants, .. } => { + variants.iter().flat_map(|dt| dt.all_indices()).collect() + } } } diff --git a/src/stripe.rs b/src/stripe.rs index fd09559..3703fab 100644 --- a/src/stripe.rs +++ b/src/stripe.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; use std::{collections::HashMap, io::Read, sync::Arc}; use bytes::Bytes; @@ -85,19 +84,19 @@ impl TryFrom<(&proto::StripeInformation, &proto::StripeStatistics)> for StripeMe type Error = error::OrcError; fn try_from(value: (&proto::StripeInformation, &proto::StripeStatistics)) -> Result { - let column_statistics = value - .1 + let (info, statistics) = value; + let column_statistics = statistics .col_stats .iter() .map(TryFrom::try_from) .collect::>>()?; Ok(Self { column_statistics, - offset: value.0.offset(), - index_length: value.0.index_length(), - data_length: value.0.data_length(), - footer_length: value.0.footer_length(), - number_of_rows: value.0.number_of_rows(), + offset: info.offset(), + index_length: info.index_length(), + data_length: info.data_length(), + footer_length: info.footer_length(), + number_of_rows: info.number_of_rows(), }) } } @@ -120,8 +119,7 @@ impl TryFrom<&proto::StripeInformation> for StripeMetadata { #[derive(Debug)] pub struct Stripe { columns: Vec, - /// <(ColumnId, Kind), Bytes> - stream_map: Arc, + stream_map: StreamMap, number_of_rows: usize, tz: Option, } @@ -129,38 +127,44 @@ pub struct Stripe { impl Stripe { pub fn new( reader: &mut R, - file_metadata: &Arc, + file_metadata: &FileMetadata, projected_data_type: &RootDataType, info: &StripeMetadata, ) -> Result { - let compression = file_metadata.compression(); - let footer = reader .get_bytes(info.footer_offset(), info.footer_length()) .context(IoSnafu)?; - let footer = Arc::new(deserialize_stripe_footer(footer, compression)?); + let footer = Arc::new(deserialize_stripe_footer( + footer, + file_metadata.compression(), + )?); let columns: Vec = projected_data_type .children() .iter() - .map(|col| Column::new(col.name(), col.data_type(), &footer)) + .map(|col| { + Column::new( + col.name().to_string(), + col.data_type().clone(), + footer.clone(), + ) + }) .collect(); - let column_ids = collect_required_column_ids(&columns); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); for stream in &footer.streams { let length = stream.length(); let column_id = stream.column(); - if column_ids.contains(&column_id) { + if projected_data_type.contains_column_index(column_id as usize) { let kind = stream.kind(); - let data = Column::read_stream(reader, stream_offset, length)?; + let data = reader.get_bytes(stream_offset, length).context(IoSnafu)?; stream_map.insert((column_id, kind), data); } stream_offset += length; } - let tz: Option = footer + let tz = footer .writer_timezone .as_ref() // TODO: make this return error @@ -168,10 +172,10 @@ impl Stripe { Ok(Self { columns, - stream_map: Arc::new(StreamMap { + stream_map: StreamMap { inner: stream_map, - compression, - }), + compression: file_metadata.compression(), + }, number_of_rows: info.number_of_rows() as usize, tz, }) @@ -181,40 +185,49 @@ impl Stripe { #[cfg(feature = "async")] pub async fn new_async( reader: &mut R, - file_metadata: &Arc, + file_metadata: &FileMetadata, projected_data_type: &RootDataType, info: &StripeMetadata, ) -> Result { - let compression = file_metadata.compression(); - let footer = reader .get_bytes(info.footer_offset(), info.footer_length()) .await .context(IoSnafu)?; - let footer = Arc::new(deserialize_stripe_footer(footer, compression)?); + let footer = Arc::new(deserialize_stripe_footer( + footer, + file_metadata.compression(), + )?); let columns: Vec = projected_data_type .children() .iter() - .map(|col| Column::new(col.name(), col.data_type(), &footer)) + .map(|col| { + Column::new( + col.name().to_string(), + col.data_type().clone(), + footer.clone(), + ) + }) .collect(); - let column_ids = collect_required_column_ids(&columns); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); for stream in &footer.streams { let length = stream.length(); let column_id = stream.column(); - if column_ids.contains(&column_id) { + if projected_data_type.contains_column_index(column_id as usize) { let kind = stream.kind(); - let data = Column::read_stream_async(reader, stream_offset, length).await?; + let data = reader + .get_bytes(stream_offset, length) + .await + .context(IoSnafu)?; stream_map.insert((column_id, kind), data); } stream_offset += length; } - let tz: Option = footer + let tz = footer .writer_timezone .as_ref() // TODO: make this return error @@ -222,10 +235,10 @@ impl Stripe { Ok(Self { columns, - stream_map: Arc::new(StreamMap { + stream_map: StreamMap { inner: stream_map, - compression, - }), + compression: file_metadata.compression(), + }, number_of_rows: info.number_of_rows() as usize, tz, }) @@ -235,7 +248,6 @@ impl Stripe { self.number_of_rows } - /// Fetch the stream map pub fn stream_map(&self) -> &StreamMap { &self.stream_map } @@ -251,8 +263,9 @@ impl Stripe { #[derive(Debug)] pub struct StreamMap { - pub inner: HashMap<(u32, Kind), Bytes>, - pub compression: Option, + /// <(ColumnId, Kind), Bytes> + inner: HashMap<(u32, Kind), Bytes>, + compression: Option, } impl StreamMap { @@ -285,12 +298,3 @@ fn deserialize_stripe_footer( .context(error::IoSnafu)?; StripeFooter::decode(buffer.as_slice()).context(error::DecodeProtoSnafu) } - -fn collect_required_column_ids(columns: &[Column]) -> HashSet { - let mut set = HashSet::new(); - for column in columns { - set.insert(column.column_id()); - set.extend(collect_required_column_ids(&column.children())); - } - set -}