Skip to content

Commit d99bcbb

Browse files
desmondcheongzx, claude, Varun Madan, madvart, Varun Madan
authored
fix(scan): use parquet metadata for scan task size estimates (#6542)
During schema inference, `GlobScanOperator::try_new` already reads the full parquet footer via `read_parquet_schema_and_metadata` but only preserves `num_rows` in `TableMetadata` - row group size information is discarded. When `estimate_in_memory_size_bytes` later needs a size estimate and no column-level `TableStatistics` are available, it falls back to `approx_num_rows * schema.estimate_row_size_bytes()`, which uses a fixed 20 bytes for Utf8 columns. For data with dictionary encoding or low-cardinality columns, this produces wildly inflated estimates.

To fix this, we extend `TableMetadata` with an optional `size_bytes` field populated from the sum of uncompressed (`total_byte_size`) row group sizes during schema inference, and use it as a middle-tier fallback in `estimate_in_memory_size_bytes` between table statistics (most accurate) and the schema-based guess (least accurate).

For a 7.4 MB parquet file with 40M rows of 7 repeated URLs (from the Rivian repro at `s3://cgrinstead/daft-rivian-repro/data.parquet`), the estimate drops from **1.66 GiB to 44 MiB**.

## Changes

- Expose `total_byte_size()` on `DaftRowGroupMetaData` (wraps arrow-rs `RowGroupMetaData::total_byte_size()`)
- Add `size_bytes: Option<usize>` to `TableMetadata` with `#[serde(default)]` for backwards compatibility
- Populate `size_bytes` from parquet row group metadata during schema inference in `GlobScanOperator::try_new`
- Aggregate `size_bytes` across sources in `ScanTask::new`
- Add `metadata.size_bytes` as a fallback tier in `estimate_in_memory_size_bytes`

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Varun Madan <varun@Varuns-MacBook-Pro.local>
Co-authored-by: Varun Madan <varun.madan@gmail.com>
Co-authored-by: Varun Madan <varun@Mac.localdomain>
1 parent 04bc5c9 commit d99bcbb

8 files changed

Lines changed: 264 additions & 13 deletions

File tree

src/daft-micropartition/src/micropartition.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,10 @@ impl MicroPartition {
9393
Self {
9494
schema,
9595
chunks: record_batches,
96-
metadata: TableMetadata { length },
96+
metadata: TableMetadata {
97+
length,
98+
column_sizes: None,
99+
},
97100
statistics,
98101
}
99102
}

src/daft-micropartition/src/ops/concat.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ impl MicroPartition {
3232
Ok(Self {
3333
schema,
3434
chunks: Arc::new(all_tables),
35-
metadata: TableMetadata { length: new_len },
35+
metadata: TableMetadata {
36+
length: new_len,
37+
column_sizes: None,
38+
},
3639
statistics: all_stats,
3740
})
3841
}

src/daft-parquet/src/metadata_adapter.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,31 @@ impl DaftRowGroupMetaData {
201201
pub fn compressed_size(&self) -> usize {
202202
self.inner.compressed_size() as usize
203203
}
204+
205+
/// Uncompressed (encoded) size in bytes for each column chunk in this row group,
206+
/// keyed by the top-level (root) column name.
207+
///
208+
/// Nested columns are flattened to their root: every leaf chunk (e.g.
209+
/// `position_ids.list.element`) is attributed to its top-level field
210+
/// (`position_ids`). The caller is responsible for summing across row groups.
211+
pub fn column_uncompressed_sizes(&self) -> Vec<(String, u64)> {
212+
self.inner
213+
.columns()
214+
.iter()
215+
.map(|col| {
216+
let root = col
217+
.column_descr()
218+
.path()
219+
.parts()
220+
.first()
221+
.cloned()
222+
.unwrap_or_default();
223+
// `uncompressed_size()` is an i64 and is always non-negative in practice;
224+
// clamp defensively so a malformed (negative) value can't wrap to a huge u64.
225+
(root, col.uncompressed_size().max(0) as u64)
226+
})
227+
.collect()
228+
}
204229
}
205230

206231
#[cfg(test)]

src/daft-scan/src/glob.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{sync::Arc, vec};
1+
use std::{collections::BTreeMap, sync::Arc, vec};
22

33
use common_error::{DaftError, DaftResult};
44
use common_file_formats::FileFormat;
@@ -288,10 +288,20 @@ impl GlobScanOperator {
288288
Err(e) => return Err(e),
289289
};
290290

291+
// Sum the per-column uncompressed sizes across all row groups, keyed by
292+
// top-level column name. Storing per-column (rather than a single total)
293+
// lets size estimates respect column-projection pushdown.
294+
let mut column_sizes: BTreeMap<String, u64> = BTreeMap::new();
295+
for (_, rg) in metadata.row_groups() {
296+
for (name, bytes) in rg.column_uncompressed_sizes() {
297+
*column_sizes.entry(name).or_insert(0) += bytes;
298+
}
299+
}
291300
let first_metadata = Some((
292301
filepath.clone(),
293302
TableMetadata {
294303
length: metadata.num_rows(),
304+
column_sizes: (!column_sizes.is_empty()).then_some(column_sizes),
295305
},
296306
));
297307
(schema, first_metadata, filepath)

src/daft-scan/src/lib.rs

Lines changed: 194 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -411,19 +411,31 @@ impl ScanTask {
411411
.all(|s| s.partition_spec == sources.first().unwrap().partition_spec),
412412
"ScanTask sources must all have the same PartitionSpec at construction",
413413
);
414-
let (length, size_bytes_on_disk, statistics) = sources
414+
let (length, column_sizes, size_bytes_on_disk, statistics) = sources
415415
.iter()
416416
.map(|s| {
417417
(
418418
s.metadata.as_ref().map(|m| m.length),
419+
s.metadata.as_ref().and_then(|m| m.column_sizes.clone()),
419420
s.size_bytes,
420421
s.statistics.clone(),
421422
)
422423
})
423424
.reduce(
424-
|(acc_len, acc_size, acc_stats), (curr_len, curr_size, curr_stats)| {
425+
|(acc_len, acc_col_sizes, acc_size, acc_stats),
426+
(curr_len, curr_col_sizes, curr_size, curr_stats)| {
425427
(
426428
acc_len.and_then(|acc_len| curr_len.map(|curr_len| acc_len + curr_len)),
429+
// All-or-nothing: only retain per-column sizes if every source has them,
430+
// summing each column across sources.
431+
acc_col_sizes.and_then(|mut acc| {
432+
curr_col_sizes.map(|curr| {
433+
for (name, bytes) in curr {
434+
*acc.entry(name).or_insert(0) += bytes;
435+
}
436+
acc
437+
})
438+
}),
427439
acc_size
428440
.and_then(|acc_size| curr_size.map(|curr_size| acc_size + curr_size)),
429441
acc_stats.and_then(|acc_stats| {
@@ -444,7 +456,10 @@ impl ScanTask {
444456
},
445457
)
446458
.unwrap();
447-
let metadata = length.map(|l| TableMetadata { length: l });
459+
let metadata = length.map(|l| TableMetadata {
460+
length: l,
461+
column_sizes,
462+
});
448463
Self {
449464
sources,
450465
schema,
@@ -756,7 +771,45 @@ impl ScanTask {
756771
})
757772
})
758773
.or_else(|| {
759-
// use approximate number of rows multiplied by an approximate bytes-per-row
774+
// Use per-column uncompressed sizes from file metadata (e.g. Parquet
775+
// column-chunk totals) when available. This is more accurate than the
776+
// schema-based estimate for data with dictionary encoding, low-cardinality
777+
// columns, or variable-length nested types (e.g. List) where the schema
778+
// heuristic assumes a fixed element count.
779+
//
780+
// We restrict the sum to the materialized (projected) columns so the
781+
// estimate respects column-projection pushdown, and convert to a per-row
782+
// size scaled by `approx_num_rows`, so limit/filter pushdowns are honored
783+
// the same way as the schema-based fallback below.
784+
let metadata = self.metadata.as_ref()?;
785+
let column_sizes = metadata.column_sizes.as_ref()?;
786+
if metadata.length == 0 {
787+
return None;
788+
}
789+
let projected_bytes: u64 = mat_schema
790+
.field_names()
791+
.filter_map(|name| column_sizes.get(name).copied())
792+
.sum();
793+
// No projected column was found in the metadata (e.g. the scan only reads
794+
// generated/partition columns); defer to the schema-based estimate.
795+
if projected_bytes == 0 {
796+
return None;
797+
}
798+
let row_size = (projected_bytes as f64) / (metadata.length as f64);
799+
self.approx_num_rows(config).map(|approx_num_rows| {
800+
let estimate_f64 = approx_num_rows * row_size;
801+
if estimate_f64.is_nan()
802+
|| estimate_f64.is_infinite()
803+
|| estimate_f64 > REASONABLE_SIZE_BYTES as f64
804+
{
805+
REASONABLE_SIZE_BYTES
806+
} else {
807+
estimate_f64 as usize
808+
}
809+
})
810+
})
811+
.or_else(|| {
812+
// Fall back to approximate number of rows multiplied by an approximate bytes-per-row
760813
self.approx_num_rows(config).map(|approx_num_rows| {
761814
let row_size = mat_schema.estimate_row_size_bytes();
762815

@@ -874,7 +927,7 @@ Pushdowns = {pushdowns}
874927

875928
#[cfg(test)]
876929
mod test {
877-
use std::sync::Arc;
930+
use std::{collections::BTreeMap, sync::Arc};
878931

879932
use common_display::{DisplayAs, DisplayLevel};
880933
use common_error::DaftResult;
@@ -1045,6 +1098,7 @@ mod test {
10451098
size_bytes: Some(1_000_000),
10461099
metadata: Some(TableMetadata {
10471100
length: usize::MAX, // Extremely large row count
1101+
column_sizes: None,
10481102
}),
10491103
statistics: None,
10501104
partition_spec: None,
@@ -1132,6 +1186,7 @@ mod test {
11321186
size_bytes: Some(10_000_000), // 10MB
11331187
metadata: Some(TableMetadata {
11341188
length: 1000, // 1000 rows
1189+
column_sizes: None,
11351190
}),
11361191
statistics: None,
11371192
partition_spec: None,
@@ -1175,6 +1230,7 @@ mod test {
11751230
size_bytes: Some(1_000_000),
11761231
metadata: Some(TableMetadata {
11771232
length: usize::MAX, // Extremely large row count
1233+
column_sizes: None,
11781234
}),
11791235
statistics: None,
11801236
partition_spec: None,
@@ -1321,7 +1377,10 @@ mod test {
13211377
fn test_schema_row_size_estimation_valid_case() {
13221378
let sources = vec![ScanSource {
13231379
size_bytes: Some(1_000_000),
1324-
metadata: Some(TableMetadata { length: 10_000 }),
1380+
metadata: Some(TableMetadata {
1381+
length: 10_000,
1382+
column_sizes: None,
1383+
}),
13251384
statistics: None,
13261385
partition_spec: None,
13271386
kind: ScanSourceKind::File {
@@ -1364,6 +1423,135 @@ mod test {
13641423
assert!(estimate_val < 1_000_000_000); // Less than 1GB is reasonable
13651424
}
13661425

1426+
/// Builds a parquet scan task modeled on the customer's tokenized-sequence dataset:
1427+
/// four `List(Int64)` columns where `position_ids` dominates the on-disk bytes. The
1428+
/// per-column uncompressed sizes are taken from row group 0 of `rank_0_train.parquet`.
1429+
fn make_list_column_scan_task(pushdowns: Pushdowns) -> ScanTask {
1430+
// Uncompressed byte sizes per column, from the customer's parquet metadata.
1431+
let column_sizes = BTreeMap::from([
1432+
("input_ids".to_string(), 10_760_083u64),
1433+
("attention_mask".to_string(), 3_339u64),
1434+
("labels".to_string(), 10_760_083u64),
1435+
("position_ids".to_string(), 122_969_239u64),
1436+
]);
1437+
let num_rows = 200; // row group 0 num_rows
1438+
1439+
let sources = vec![ScanSource {
1440+
size_bytes: Some(32_418_149), // compressed row group 0 bytes
1441+
metadata: Some(TableMetadata {
1442+
length: num_rows,
1443+
column_sizes: Some(column_sizes),
1444+
}),
1445+
statistics: None,
1446+
partition_spec: None,
1447+
kind: ScanSourceKind::File {
1448+
path: "rank_0_train.parquet".to_string(),
1449+
chunk_spec: None,
1450+
iceberg_delete_files: None,
1451+
parquet_metadata: None,
1452+
},
1453+
}];
1454+
1455+
let schema = Arc::new(Schema::new(vec![
1456+
Field::new("input_ids", DataType::List(Box::new(DataType::Int64))),
1457+
Field::new("attention_mask", DataType::List(Box::new(DataType::Int64))),
1458+
Field::new("labels", DataType::List(Box::new(DataType::Int64))),
1459+
Field::new("position_ids", DataType::List(Box::new(DataType::Int64))),
1460+
]));
1461+
1462+
ScanTask::new(
1463+
sources,
1464+
Arc::new(SourceConfig::File(FileFormatConfig::Parquet(
1465+
ParquetSourceConfig {
1466+
coerce_int96_timestamp_unit: TimeUnit::Seconds,
1467+
field_id_mapping: None,
1468+
row_groups: None,
1469+
chunk_size: None,
1470+
ignore_corrupt_files: false,
1471+
},
1472+
))),
1473+
schema,
1474+
Arc::new(StorageConfig::new_internal(false, None)),
1475+
pushdowns,
1476+
None,
1477+
)
1478+
}
1479+
1480+
/// Regression test for OOMs caused by under-estimating `List`-typed columns.
1481+
///
1482+
/// The schema-based fallback assumes a fixed list length (`DEFAULT_LIST_LEN = 4`),
1483+
/// estimating ~130 bytes/row (~26 KB for 200 rows) when the real uncompressed size
1484+
/// is ~144 MB. Per-column metadata sizes must be used instead.
1485+
#[test]
1486+
fn test_list_column_estimation_uses_metadata_not_schema() {
1487+
let scan_task = make_list_column_scan_task(Pushdowns::default());
1488+
let estimate = scan_task.estimate_in_memory_size_bytes(None).unwrap();
1489+
1490+
// Expect the sum of all four columns' uncompressed sizes (~144 MB).
1491+
// (Allow a few bytes of slack for the per-row f64 roundtrip.)
1492+
let expected: i64 = 10_760_083 + 3_339 + 10_760_083 + 122_969_239;
1493+
assert!((estimate as i64 - expected).abs() <= 4);
1494+
1495+
// Sanity check: this must be vastly larger than the schema-based guess, which is
1496+
// what caused the OOM. DEFAULT_LIST_LEN=4 yields ~130 bytes/row.
1497+
let schema_based = 200.0 * scan_task.materialized_schema().estimate_row_size_bytes();
1498+
assert!(schema_based < 30_000.0);
1499+
assert!((estimate as f64) > 1000.0 * schema_based);
1500+
}
1501+
1502+
/// The metadata-based estimate must respect column-projection pushdown: selecting only
1503+
/// the small `attention_mask` column should not estimate the whole (position_ids-heavy)
1504+
/// row group.
1505+
#[test]
1506+
fn test_list_column_estimation_respects_projection() {
1507+
// Project only the small column.
1508+
let small_pushdowns = Pushdowns::new(
1509+
None,
1510+
None,
1511+
Some(Arc::new(vec!["attention_mask".to_string()])),
1512+
None,
1513+
None,
1514+
None,
1515+
);
1516+
let small = make_list_column_scan_task(small_pushdowns)
1517+
.estimate_in_memory_size_bytes(None)
1518+
.unwrap();
1519+
assert!((small as i64 - 3_339).abs() <= 4);
1520+
1521+
// Project only the dominant column.
1522+
let large_pushdowns = Pushdowns::new(
1523+
None,
1524+
None,
1525+
Some(Arc::new(vec!["position_ids".to_string()])),
1526+
None,
1527+
None,
1528+
None,
1529+
);
1530+
let large = make_list_column_scan_task(large_pushdowns)
1531+
.estimate_in_memory_size_bytes(None)
1532+
.unwrap();
1533+
assert!((large as i64 - 122_969_239).abs() <= 4);
1534+
1535+
// The projected small column must be orders of magnitude smaller than the full scan.
1536+
assert!(large > 1000 * small);
1537+
}
1538+
1539+
/// A limit pushdown should scale the metadata-based estimate down proportionally,
1540+
/// rather than returning the full-file size.
1541+
#[test]
1542+
fn test_list_column_estimation_respects_limit() {
1543+
let limit_pushdowns = Pushdowns::new(None, None, None, Some(50), None, None);
1544+
let estimate = make_list_column_scan_task(limit_pushdowns)
1545+
.estimate_in_memory_size_bytes(None)
1546+
.unwrap();
1547+
1548+
// 50 of 200 rows => roughly a quarter of the full ~144 MB.
1549+
let full = 10_760_083 + 3_339 + 10_760_083 + 122_969_239;
1550+
let expected = full / 4;
1551+
// Allow for f64 rounding.
1552+
assert!((estimate as i64 - expected as i64).abs() <= 4);
1553+
}
1554+
13671555
#[test]
13681556
fn test_overflow_protection_with_infinity() {
13691557
let sources = vec![ScanSource {

0 commit comments

Comments
 (0)