Skip to content

Commit 285b317

Browse files
authored
Fixes for sort/limit pushdown + mongo timestamp (#619)
* Fix timestamp filter pushdown, schema inference, and sort/limit correctness - Unwrap Expr::Cast/TryCast in extract_literal_value/extract_column_name so timestamp filters push down through DataFusion's coercion casts - Handle (Timestamp, Date32) order in unify_types so mixed midnight and non-midnight DateTime collections infer as Timestamp, not Utf8 - Implement supports_limit_pushdown/with_fetch/fetch on MongoDBExec - Return SortOrderPushdownResult::Inexact so SortExec's fetch survives when ORDER BY ... LIMIT N is pushed down * Fix for sort+limit * Lint * Remove extra changes * Fix
1 parent 30d973a commit 285b317

9 files changed

Lines changed: 869 additions & 23 deletions

File tree

core/src/mongodb/table.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,13 @@ impl ExecutionPlan for MongoDBExec {
261261
);
262262
new_exec.properties = new_exec.properties.with_eq_properties(eq_properties);
263263

264-
Ok(SortOrderPushdownResult::Exact {
264+
// Return Inexact rather than Exact so DataFusion keeps the SortExec wrapper
265+
// above us. Exact would replace the SortExec with `inner`, which loses the
266+
// SortExec's embedded fetch (`ORDER BY ... LIMIT N` is represented as a
267+
// single SortExec with fetch=N in DF 52). Keeping the SortExec preserves the
268+
// fetch as a TopK applied to our already-sorted MongoDB output.
269+
// We can switch back to Exact once we upgrade to a DataFusion version that includes PR https://github.com/apache/datafusion/pull/21182
270+
Ok(SortOrderPushdownResult::Inexact {
265271
inner: Arc::new(new_exec),
266272
})
267273
}
@@ -659,7 +665,7 @@ mod tests {
659665

660666
let result = exec.try_pushdown_sort(&sort_exprs).unwrap();
661667
match result {
662-
SortOrderPushdownResult::Exact { inner } => {
668+
SortOrderPushdownResult::Inexact { inner } => {
663669
let mongo_exec = inner.as_any().downcast_ref::<MongoDBExec>().unwrap();
664670
assert_eq!(mongo_exec.sort_doc, doc! { "name": 1 });
665671
let display = format_exec(mongo_exec);
@@ -668,7 +674,7 @@ mod tests {
668674
"Display should show sort: {display}"
669675
);
670676
}
671-
other => panic!("Expected Exact, got: {other:?}"),
677+
other => panic!("Expected Inexact, got: {other:?}"),
672678
}
673679
}
674680

@@ -690,11 +696,11 @@ mod tests {
690696

691697
let result = exec.try_pushdown_sort(&sort_exprs).unwrap();
692698
match result {
693-
SortOrderPushdownResult::Exact { inner } => {
699+
SortOrderPushdownResult::Inexact { inner } => {
694700
let mongo_exec = inner.as_any().downcast_ref::<MongoDBExec>().unwrap();
695701
assert_eq!(mongo_exec.sort_doc, doc! { "age": -1 });
696702
}
697-
other => panic!("Expected Exact, got: {other:?}"),
703+
other => panic!("Expected Inexact, got: {other:?}"),
698704
}
699705
}
700706

@@ -725,11 +731,11 @@ mod tests {
725731

726732
let result = exec.try_pushdown_sort(&sort_exprs).unwrap();
727733
match result {
728-
SortOrderPushdownResult::Exact { inner } => {
734+
SortOrderPushdownResult::Inexact { inner } => {
729735
let mongo_exec = inner.as_any().downcast_ref::<MongoDBExec>().unwrap();
730736
assert_eq!(mongo_exec.sort_doc, doc! { "name": 1, "age": -1 });
731737
}
732-
other => panic!("Expected Exact, got: {other:?}"),
738+
other => panic!("Expected Inexact, got: {other:?}"),
733739
}
734740
}
735741

@@ -771,7 +777,7 @@ mod tests {
771777

772778
let result = exec.try_pushdown_sort(&sort_exprs).unwrap();
773779
match result {
774-
SortOrderPushdownResult::Exact { inner } => {
780+
SortOrderPushdownResult::Inexact { inner } => {
775781
let mongo_exec = inner.as_any().downcast_ref::<MongoDBExec>().unwrap();
776782
assert!(
777783
!mongo_exec.filters_doc.is_empty(),
@@ -780,7 +786,7 @@ mod tests {
780786
assert_eq!(mongo_exec.limit, Some(10), "Limit should be preserved");
781787
assert_eq!(mongo_exec.sort_doc, doc! { "name": 1 });
782788
}
783-
other => panic!("Expected Exact, got: {other:?}"),
789+
other => panic!("Expected Inexact, got: {other:?}"),
784790
}
785791
}
786792

@@ -792,14 +798,14 @@ mod tests {
792798

793799
let result = exec.try_pushdown_sort(&[]).unwrap();
794800
match result {
795-
SortOrderPushdownResult::Exact { inner } => {
801+
SortOrderPushdownResult::Inexact { inner } => {
796802
let mongo_exec = inner.as_any().downcast_ref::<MongoDBExec>().unwrap();
797803
assert!(
798804
mongo_exec.sort_doc.is_empty(),
799805
"Empty sort should produce empty doc"
800806
);
801807
}
802-
other => panic!("Expected Exact, got: {other:?}"),
808+
other => panic!("Expected Inexact, got: {other:?}"),
803809
}
804810
}
805811
}

core/src/mongodb/utils/expression.rs

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use datafusion::{
2+
logical_expr::expr::{Cast, TryCast},
23
logical_expr::{Expr, Operator},
34
scalar::ScalarValue,
45
};
@@ -243,13 +244,36 @@ fn extract_string_literal(expr: &Expr) -> Option<String> {
243244
fn extract_column_name(expr: &Expr) -> Option<String> {
244245
match expr {
245246
Expr::Column(col) => Some(col.name.clone()),
247+
// Unwrap casts around columns (e.g., type coercion in comparisons)
248+
Expr::Cast(cast) => extract_column_name(&cast.expr),
249+
Expr::TryCast(cast) => extract_column_name(&cast.expr),
246250
_ => None,
247251
}
248252
}
249253

250254
fn extract_literal_value(expr: &Expr) -> Option<Bson> {
251255
match expr {
252256
Expr::Literal(scalar, _) => scalar_to_bson(scalar),
257+
// Handle Cast(Literal(...), target_type) by evaluating the cast and converting the result.
258+
// Critical for timestamp/date filters: TIMESTAMP '2024-01-01' is parsed as
259+
// TimestampNanosecond but columns are often Timestamp(Millisecond, Some("UTC")),
260+
// and DataFusion wraps the literal in a Cast to reconcile.
261+
Expr::Cast(Cast { expr, data_type }) => {
262+
if let Expr::Literal(scalar, _) = expr.as_ref() {
263+
let casted = scalar.cast_to(data_type).ok()?;
264+
scalar_to_bson(&casted)
265+
} else {
266+
None
267+
}
268+
}
269+
Expr::TryCast(TryCast { expr, data_type }) => {
270+
if let Expr::Literal(scalar, _) = expr.as_ref() {
271+
let casted = scalar.cast_to(data_type).ok()?;
272+
scalar_to_bson(&casted)
273+
} else {
274+
None
275+
}
276+
}
253277
_ => None,
254278
}
255279
}
@@ -1486,4 +1510,174 @@ mod tests {
14861510
let regex = sql_like_to_regex("A___Z", None);
14871511
assert_eq!(regex, "^A...Z$");
14881512
}
1513+
1514+
// --- Cast / TryCast handling ---
1515+
1516+
#[test]
1517+
fn test_cast_timestamp_literal_pushdown() {
1518+
use datafusion::arrow::datatypes::{DataType, TimeUnit};
1519+
1520+
// Simulates DataFusion's output for:
1521+
// created_at >= TIMESTAMP '2024-06-01 00:00:00'
1522+
// when the column is Timestamp(Millisecond, Some("UTC")) but the SQL literal
1523+
// is parsed as Timestamp(Nanosecond, None) — DataFusion wraps the literal in a Cast.
1524+
let ns_value = 1_717_200_000_000_000_000_i64; // 2024-06-01T00:00:00Z
1525+
let cast_expr = Expr::Cast(Cast {
1526+
expr: Box::new(Expr::Literal(
1527+
ScalarValue::TimestampNanosecond(Some(ns_value), None),
1528+
None,
1529+
)),
1530+
data_type: DataType::Timestamp(
1531+
TimeUnit::Millisecond,
1532+
Some(std::sync::Arc::from("UTC")),
1533+
),
1534+
});
1535+
1536+
let expr = Expr::BinaryExpr(BinaryExpr {
1537+
left: Box::new(col("created_at")),
1538+
op: Operator::GtEq,
1539+
right: Box::new(cast_expr),
1540+
});
1541+
1542+
let filter = expr_to_mongo_filter(&expr).unwrap();
1543+
let expected = doc! { "created_at": { "$gte": mongodb::bson::DateTime::from_millis(1_717_200_000_000) } };
1544+
assert_eq!(filter, expected);
1545+
}
1546+
1547+
#[test]
1548+
fn test_cast_int_literal_pushdown() {
1549+
use datafusion::arrow::datatypes::DataType;
1550+
1551+
let cast_expr = Expr::Cast(Cast {
1552+
expr: Box::new(lit(42_i32)),
1553+
data_type: DataType::Int64,
1554+
});
1555+
1556+
let expr = Expr::BinaryExpr(BinaryExpr {
1557+
left: Box::new(col("count")),
1558+
op: Operator::Gt,
1559+
right: Box::new(cast_expr),
1560+
});
1561+
1562+
let filter = expr_to_mongo_filter(&expr).unwrap();
1563+
let expected = doc! { "count": { "$gt": 42_i64 } };
1564+
assert_eq!(filter, expected);
1565+
}
1566+
1567+
#[test]
1568+
fn test_try_cast_literal_pushdown() {
1569+
use datafusion::arrow::datatypes::DataType;
1570+
1571+
let try_cast_expr = Expr::TryCast(TryCast {
1572+
expr: Box::new(lit(100_i32)),
1573+
data_type: DataType::Int64,
1574+
});
1575+
1576+
let expr = Expr::BinaryExpr(BinaryExpr {
1577+
left: Box::new(col("value")),
1578+
op: Operator::Eq,
1579+
right: Box::new(try_cast_expr),
1580+
});
1581+
1582+
let filter = expr_to_mongo_filter(&expr).unwrap();
1583+
let expected = doc! { "value": 100_i64 };
1584+
assert_eq!(filter, expected);
1585+
}
1586+
1587+
#[test]
1588+
fn test_cast_column_with_literal_pushdown() {
1589+
use datafusion::arrow::datatypes::DataType;
1590+
1591+
// Cast(col("age"), Int64) >= lit(30_i64) — column side wrapped in Cast
1592+
let cast_col = Expr::Cast(Cast {
1593+
expr: Box::new(col("age")),
1594+
data_type: DataType::Int64,
1595+
});
1596+
1597+
let expr = Expr::BinaryExpr(BinaryExpr {
1598+
left: Box::new(cast_col),
1599+
op: Operator::GtEq,
1600+
right: Box::new(lit(30_i64)),
1601+
});
1602+
1603+
let filter = expr_to_mongo_filter(&expr).unwrap();
1604+
let expected = doc! { "age": { "$gte": 30_i64 } };
1605+
assert_eq!(filter, expected);
1606+
}
1607+
1608+
#[test]
1609+
fn test_cast_timestamp_between_pushdown() {
1610+
use datafusion::arrow::datatypes::{DataType, TimeUnit};
1611+
1612+
let target_type =
1613+
DataType::Timestamp(TimeUnit::Millisecond, Some(std::sync::Arc::from("UTC")));
1614+
let low = Expr::Cast(Cast {
1615+
expr: Box::new(Expr::Literal(
1616+
ScalarValue::TimestampNanosecond(Some(1_704_067_200_000_000_000), None),
1617+
None,
1618+
)),
1619+
data_type: target_type.clone(),
1620+
});
1621+
let high = Expr::Cast(Cast {
1622+
expr: Box::new(Expr::Literal(
1623+
ScalarValue::TimestampNanosecond(Some(1_735_689_600_000_000_000), None),
1624+
None,
1625+
)),
1626+
data_type: target_type,
1627+
});
1628+
1629+
let expr = Expr::Between(datafusion::logical_expr::expr::Between {
1630+
expr: Box::new(col("created_at")),
1631+
negated: false,
1632+
low: Box::new(low),
1633+
high: Box::new(high),
1634+
});
1635+
1636+
let filter = expr_to_mongo_filter(&expr).unwrap();
1637+
let expected = doc! {
1638+
"created_at": {
1639+
"$gte": mongodb::bson::DateTime::from_millis(1_704_067_200_000),
1640+
"$lte": mongodb::bson::DateTime::from_millis(1_735_689_600_000),
1641+
}
1642+
};
1643+
assert_eq!(filter, expected);
1644+
}
1645+
1646+
#[test]
1647+
fn test_cast_unsupported_target_type_returns_none() {
1648+
use datafusion::arrow::datatypes::DataType;
1649+
1650+
let cast_expr = Expr::Cast(Cast {
1651+
expr: Box::new(lit("hello")),
1652+
data_type: DataType::Binary,
1653+
});
1654+
assert!(extract_literal_value(&cast_expr).is_none());
1655+
}
1656+
1657+
#[test]
1658+
fn test_cast_in_reversed_operand_order() {
1659+
use datafusion::arrow::datatypes::{DataType, TimeUnit};
1660+
1661+
let cast_expr = Expr::Cast(Cast {
1662+
expr: Box::new(Expr::Literal(
1663+
ScalarValue::TimestampNanosecond(Some(1_717_200_000_000_000_000), None),
1664+
None,
1665+
)),
1666+
data_type: DataType::Timestamp(
1667+
TimeUnit::Millisecond,
1668+
Some(std::sync::Arc::from("UTC")),
1669+
),
1670+
});
1671+
1672+
// Cast(lit) > col → col < Cast(lit)
1673+
let expr = Expr::BinaryExpr(BinaryExpr {
1674+
left: Box::new(cast_expr),
1675+
op: Operator::Gt,
1676+
right: Box::new(col("created_at")),
1677+
});
1678+
1679+
let filter = expr_to_mongo_filter(&expr).unwrap();
1680+
let expected = doc! { "created_at": { "$lt": mongodb::bson::DateTime::from_millis(1_717_200_000_000) } };
1681+
assert_eq!(filter, expected);
1682+
}
14891683
}

core/src/mongodb/utils/schema.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ fn unify_types(type1: &DataType, type2: &DataType) -> DataType {
114114
(DataType::Int64, DataType::Float64) | (DataType::Float64, DataType::Int64) => {
115115
DataType::Float64
116116
}
117-
(DataType::Date32, DataType::Timestamp(tu, tz)) => DataType::Timestamp(*tu, tz.clone()),
117+
(DataType::Date32, DataType::Timestamp(tu, tz))
118+
| (DataType::Timestamp(tu, tz), DataType::Date32) => DataType::Timestamp(*tu, tz.clone()),
118119

119120
// Otherwise use string
120121
_ => DataType::Utf8,
@@ -228,6 +229,49 @@ mod tests {
228229
);
229230
}
230231

232+
#[test]
233+
fn test_unify_timestamp_and_date32_both_orders() {
234+
// Regression: when a collection has both midnight DateTimes (inferred as Date32)
235+
// and non-midnight DateTimes (inferred as Timestamp) for the same field, the
236+
// unified type must be Timestamp regardless of which one is seen first.
237+
// HashMap iteration order is non-deterministic, so both directions must work.
238+
let midnight = mongodb::bson::DateTime::builder()
239+
.year(2024)
240+
.month(1)
241+
.day(1)
242+
.build()
243+
.unwrap();
244+
let non_midnight = mongodb::bson::DateTime::builder()
245+
.year(2024)
246+
.month(6)
247+
.day(15)
248+
.hour(12)
249+
.build()
250+
.unwrap();
251+
252+
// Order A: midnight first, then non-midnight
253+
let docs = vec![
254+
doc! { "created_at": midnight },
255+
doc! { "created_at": non_midnight },
256+
];
257+
let schema = infer_arrow_schema_from_documents(&docs, None).unwrap();
258+
assert_eq!(
259+
schema.field_with_name("created_at").unwrap().data_type(),
260+
&DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into()))
261+
);
262+
263+
// Order B: non-midnight first, then midnight — previously fell through to Utf8
264+
let docs = vec![
265+
doc! { "created_at": non_midnight },
266+
doc! { "created_at": midnight },
267+
];
268+
let schema = infer_arrow_schema_from_documents(&docs, None).unwrap();
269+
assert_eq!(
270+
schema.field_with_name("created_at").unwrap().data_type(),
271+
&DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into()))
272+
);
273+
}
274+
231275
#[test]
232276
fn test_date32_detection() {
233277
let doc = doc! {

0 commit comments

Comments
 (0)