use std::{path::Path, sync::Arc};

use arrow_array::{Array, RecordBatch};
use arrow_ord::sort::{lexsort_to_indices, SortColumn};
use arrow_schema::{DataType, Schema};
use arrow_select::{concat::concat_batches, take::take};
use delta_kernel::DeltaResult;
use futures::{stream::TryStreamExt, StreamExt};
use object_store::{local::LocalFileSystem, ObjectStore};
use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder};

use super::TestCaseInfo;
use crate::TestResult;

/// Read the "golden" expected data for a test case by concatenating every parquet file
/// under its `expected/latest/table_content` directory into a single record batch.
pub async fn read_golden(path: &Path, _version: Option<&str>) -> DeltaResult<RecordBatch> {
    let expected_root = path.join("expected").join("latest").join("table_content");
    let store = Arc::new(LocalFileSystem::new_with_prefix(&expected_root)?);
    let files: Vec<_> = store.list(None).try_collect().await?;
    let mut batches = vec![];
    let mut schema = None;
    for meta in files.into_iter() {
        if let Some(ext) = meta.location.extension() {
            if ext == "parquet" {
                let reader = ParquetObjectReader::new(store.clone(), meta.location);
                let builder = ParquetRecordBatchStreamBuilder::new(reader).await?;
                if schema.is_none() {
                    schema = Some(builder.schema().clone());
                }
                let mut stream = builder.build()?;
                while let Some(batch) = stream.next().await {
                    batches.push(batch?);
                }
            }
        }
    }
    let all_data = concat_batches(&schema.unwrap(), &batches)?;
    Ok(all_data)
}

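// For reference, the golden-table layout that `read_golden` (above) expects. This is
// illustrative only, reconstructed from the path handling in that function; the actual
// parquet file names will differ per test case:
//
//   <test_case_root>/
//     expected/
//       latest/
//         table_content/
//           *.parquet
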
/// Sort a record batch lexicographically by all of its sortable columns.
pub fn sort_record_batch(batch: RecordBatch) -> DeltaResult<RecordBatch> {
    // Use every column with a defined sort order as a sort key
    let mut sort_columns = vec![];
    for col in batch.columns() {
        match col.data_type() {
            DataType::Struct(_) | DataType::List(_) | DataType::Map(_, _) => {
                // can't sort structs, lists, or maps, so skip them as sort keys
            }
            _ => sort_columns.push(SortColumn {
                values: col.clone(),
                options: None,
            }),
        }
    }
    // Reorder every column by the same indices so rows stay aligned
    let indices = lexsort_to_indices(&sort_columns, None)?;
    let columns = batch
        .columns()
        .iter()
        .map(|c| take(c, &indices, None).unwrap())
        .collect();
    Ok(RecordBatch::try_new(batch.schema(), columns)?)
}
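
// A minimal sketch (not part of the original change) of how `sort_record_batch` above
// behaves on a tiny two-column batch: rows are reordered by the lexicographic order of
// the sortable columns, and all columns are permuted together so rows stay intact.
// Column names and values here are illustrative only.
#[cfg(test)]
mod sort_record_batch_example {
    use super::*;
    use arrow_array::{ArrayRef, Int32Array, StringArray};

    #[test]
    fn sorts_rows_by_all_sortable_columns() {
        let batch = RecordBatch::try_from_iter(vec![
            ("id", Arc::new(Int32Array::from(vec![3, 1, 2])) as ArrayRef),
            ("name", Arc::new(StringArray::from(vec!["c", "a", "b"])) as ArrayRef),
        ])
        .expect("failed to build batch");

        let sorted = sort_record_batch(batch).expect("failed to sort batch");

        // Rows are now ordered by (id, name); both columns were permuted together.
        let expected_ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let expected_names: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        assert_eq!(sorted.column(0), &expected_ids);
        assert_eq!(sorted.column(1), &expected_names);
    }
}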

// Ensure that two schemas have the same field names and the same dict_id/ordering.
// We ignore:
// - data type: this is checked already in `assert_columns_match`
// - nullability: parquet marks many things as nullable that we don't in our schema
// - metadata: because it diverges between the real data and the golden table data
fn assert_schema_fields_match(schema: &Schema, golden: &Schema) {
    for (schema_field, golden_field) in schema.fields.iter().zip(golden.fields.iter()) {
        assert!(
            schema_field.name() == golden_field.name(),
            "Field names don't match"
        );
        assert!(
            schema_field.dict_id() == golden_field.dict_id(),
            "Field dict_id doesn't match"
        );
        assert!(
            schema_field.dict_is_ordered() == golden_field.dict_is_ordered(),
            "Field dict_is_ordered doesn't match"
        );
    }
}

// Some things are equivalent but don't show up as equivalent for `==`, so normalize them
// here. Currently this only rewrites the "+00:00" timestamp timezone into the equivalent
// "UTC" spelling.
fn normalize_col(col: Arc<dyn Array>) -> Arc<dyn Array> {
    if let DataType::Timestamp(unit, Some(zone)) = col.data_type() {
        if **zone == *"+00:00" {
            arrow_cast::cast::cast(&col, &DataType::Timestamp(*unit, Some("UTC".into())))
                .expect("Could not cast to UTC")
        } else {
            col
        }
    } else {
        col
    }
}
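
// A small illustration (not part of the original change) of why the normalization above
// matters: the same instant tagged with the "+00:00" offset and with the "UTC" timezone
// only compares equal once the timezone spelling has been normalized.
#[cfg(test)]
mod normalize_col_example {
    use super::*;
    use arrow_array::TimestampMicrosecondArray;

    #[test]
    fn plus_zero_offset_normalizes_to_utc() {
        let offset: Arc<dyn Array> =
            Arc::new(TimestampMicrosecondArray::from(vec![0i64]).with_timezone("+00:00"));
        let utc: Arc<dyn Array> =
            Arc::new(TimestampMicrosecondArray::from(vec![0i64]).with_timezone("UTC"));

        // The raw arrays differ because their timezone strings (and hence data types) differ...
        assert_ne!(&offset, &utc);
        // ...but after normalization both carry "UTC" and compare equal.
        assert_eq!(&normalize_col(offset), &utc);
    }
}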

fn assert_columns_match(actual: &[Arc<dyn Array>], expected: &[Arc<dyn Array>]) {
    for (actual, expected) in actual.iter().zip(expected) {
        let actual = normalize_col(actual.clone());
        let expected = normalize_col(expected.clone());
        // note that array equality includes data_type equality
        // See: https://arrow.apache.org/rust/arrow_data/equal/fn.equal.html
        assert_eq!(
            &actual, &expected,
            "Column data didn't match. Got {actual:?}, expected {expected:?}"
        );
    }
}

/// Assert that the scanned data matches the golden data for the given test case:
/// same column contents, matching schema fields, and the same number of rows.
pub async fn assert_scan_data(
    all_data: Vec<RecordBatch>,
    test_case: &TestCaseInfo,
) -> TestResult<()> {
    let all_data = concat_batches(&all_data[0].schema(), all_data.iter()).unwrap();
    let all_data = sort_record_batch(all_data)?;

    let golden = read_golden(test_case.root_dir(), None).await?;
    let golden = sort_record_batch(golden)?;

    assert_columns_match(all_data.columns(), golden.columns());
    assert_schema_fields_match(all_data.schema().as_ref(), golden.schema().as_ref());
    assert!(
        all_data.num_rows() == golden.num_rows(),
        "Didn't have same number of rows"
    );

    Ok(())
}