Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 182 additions & 71 deletions crates/sail-function/src/scalar/datetime/spark_time.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
use std::sync::Arc;

use chrono::{NaiveTime, Timelike};
use datafusion::arrow::array::Time64MicrosecondArray;
use datafusion::arrow::array::{Array, ArrayRef, Time64MicrosecondArray};
use datafusion::arrow::compute::{cast_with_options, CastOptions};
use datafusion::arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::cast::{as_large_string_array, as_string_array, as_string_view_array};
use datafusion_common::types::logical_string;
use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
use datafusion_common::{exec_datafusion_err, exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
use sail_common_datafusion::utils::items::ItemTaker;
use sail_sql_analyzer::parser::parse_time;
use datafusion_functions::utils::make_scalar_function;

use crate::error::{invalid_arg_count_exec_err, unsupported_data_type_exec_err};

/// chrono format patterns tried, in order, when no explicit format is supplied.
/// The first pattern that parses the input wins.
// NOTE(review): the last three patterns combine `%H` (24-hour clock) with `%p`
// (AM/PM). chrono's 12-hour specifier is `%I`; confirm that inputs such as
// `01:30 PM` parse as intended with `%H`.
const DEFAULT_TIME_FORMATS: &[&str] = &[
"%H:%M:%S%.f",
"%H:%M:%S",
"%H:%M",
"%H:%M:%S%.f %p",
"%H:%M:%S %p",
"%H:%M %p",
];

/// Spark-compatible `to_time` / `try_to_time` function.
/// <https://spark.apache.org/docs/latest/api/sql/index.html#to_time>
///
/// Accepts 1 or 2 arguments:
/// - `(expr)` — parses strings with default formats, or casts other types to Time64.
/// - `(expr, format)` — parses strings with the given chrono format. The format
/// may be a scalar string (broadcast) or a string column (per-row).
///
/// `to_time` always errors on invalid input (Spark's `ToTime` does not honor ANSI);
/// `try_to_time` (`is_try = true`) returns NULL on parse/cast failure — mirroring
/// Spark's `try_to_time = TryEval(ToTime(...))`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkTime {
signature: Signature,
Expand All @@ -20,12 +40,7 @@ pub struct SparkTime {
impl SparkTime {
pub fn new(is_try: bool) -> Self {
Self {
signature: Signature::coercible(
vec![Coercion::new_exact(TypeSignatureClass::Native(
logical_string(),
))],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
is_try,
}
}
Expand All @@ -34,26 +49,128 @@ impl SparkTime {
self.is_try
}

fn string_to_time_micros(value: &str, is_try: bool) -> Result<Option<i64>> {
let result = parse_time(value)
.map_err(|e| exec_datafusion_err!("{e}"))
.and_then(|t| NaiveTime::try_from(t).map_err(|e| exec_datafusion_err!("{e}")))
.map(|naive_time| {
let seconds_from_midnight = naive_time.num_seconds_from_midnight() as i64;
let nanoseconds = naive_time.nanosecond() as i64;
seconds_from_midnight * 1_000_000 + nanoseconds / 1_000
});
match result {
Ok(v) => Ok(Some(v)),
Err(_e) if is_try => Ok(None),
Err(e) => Err(e),
/// Converts a `NaiveTime` into microseconds elapsed since midnight.
///
/// Digits below microsecond precision are dropped by integer division.
fn naive_time_to_us(t: NaiveTime) -> i64 {
    let whole_second_micros = i64::from(t.num_seconds_from_midnight()) * 1_000_000;
    let fractional_micros = i64::from(t.nanosecond()) / 1_000;
    whole_second_micros + fractional_micros
}

/// Parses `value` using each entry of `DEFAULT_TIME_FORMATS` in turn and
/// returns the microseconds since midnight for the first pattern that matches.
///
/// When no pattern matches, returns `Ok(None)` if `is_try` is set and an
/// execution error otherwise.
fn string_to_time_us_default(value: &str, is_try: bool) -> Result<Option<i64>> {
    let first_match = DEFAULT_TIME_FORMATS
        .iter()
        .find_map(|pattern| NaiveTime::parse_from_str(value, pattern).ok());
    match first_match {
        Some(time) => Ok(Some(Self::naive_time_to_us(time))),
        None if is_try => Ok(None),
        None => Err(exec_datafusion_err!(
            "cannot parse '{value}' as time with default formats"
        )),
    }
}

/// Parses `value` with the chrono pattern `format` and returns the
/// microseconds since midnight.
///
/// On a parse failure, returns `Ok(None)` when `is_try` is set. Otherwise it
/// returns an execution error that names the offending value and format in
/// addition to chrono's own message, so the failing row can be identified.
fn string_to_time_us_with_format(
    value: &str,
    format: &str,
    is_try: bool,
) -> Result<Option<i64>> {
    match NaiveTime::parse_from_str(value, format) {
        Ok(t) => Ok(Some(Self::naive_time_to_us(t))),
        Err(_) if is_try => Ok(None),
        Err(e) => Err(exec_datafusion_err!(
            "cannot parse '{value}' as time with format '{format}': {e}"
        )),
    }
}

/// Returns a boxed iterator over the optional string values of `array`.
///
/// Supports `Utf8`, `LargeUtf8` and `Utf8View` arrays; any other data type
/// yields an execution error.
fn string_array_iter(array: &ArrayRef) -> Result<Box<dyn Iterator<Item = Option<&str>> + '_>> {
    let items: Box<dyn Iterator<Item = Option<&str>> + '_> = match array.data_type() {
        DataType::Utf8 => Box::new(as_string_array(array)?.iter()),
        DataType::LargeUtf8 => Box::new(as_large_string_array(array)?.iter()),
        DataType::Utf8View => Box::new(as_string_view_array(array)?.iter()),
        other => return exec_err!("expected string array, got {other}"),
    };
    Ok(items)
}

/// Parses every string in `value_arr` with the default formats and collects
/// the results into a `Time64MicrosecondArray`.
///
/// Null inputs stay null. A parse error aborts the whole array unless
/// `is_try` is set, in which case the failing row becomes null.
fn parse_value_array(value_arr: &ArrayRef, is_try: bool) -> Result<ArrayRef> {
    let parsed = Self::string_array_iter(value_arr)?
        .map(|item| item.map_or(Ok(None), |text| Self::string_to_time_us_default(text, is_try)))
        .collect::<Result<Time64MicrosecondArray>>()?;
    let result: ArrayRef = Arc::new(parsed);
    Ok(result)
}

/// Parses each string in `value_arr` using the format found at the same row
/// of `format_arr`, producing a `Time64MicrosecondArray`.
///
/// The two arrays must have equal length. A row is null when either its value
/// or its format is null. Parse failures become null when `is_try` is set and
/// abort with an error otherwise.
fn parse_value_with_format_array(
    value_arr: &ArrayRef,
    format_arr: &ArrayRef,
    is_try: bool,
) -> Result<ArrayRef> {
    let (value_len, format_len) = (value_arr.len(), format_arr.len());
    if value_len != format_len {
        let name = if is_try { "try_to_time" } else { "to_time" };
        return exec_err!(
            "{}: value array length ({}) does not match format array length ({})",
            name,
            value_len,
            format_len
        );
    }
    let rows = Self::string_array_iter(value_arr)?.zip(Self::string_array_iter(format_arr)?);
    let parsed = rows
        .map(|row| match row {
            (Some(text), Some(pattern)) => {
                Self::string_to_time_us_with_format(text, pattern, is_try)
            }
            (None, _) | (_, None) => Ok(None),
        })
        .collect::<Result<Time64MicrosecondArray>>()?;
    let result: ArrayRef = Arc::new(parsed);
    Ok(result)
}

/// Casts a non-string array to `Time64(Microsecond)` with Arrow's cast kernel.
///
/// `is_try` is forwarded as the cast's `safe` option; every other option keeps
/// its default value.
fn cast_nonstring_to_time(array: &ArrayRef, is_try: bool) -> Result<ArrayRef> {
    let target_type = DataType::Time64(TimeUnit::Microsecond);
    let options = CastOptions {
        safe: is_try,
        ..Default::default()
    };
    let converted = cast_with_options(array, &target_type, &options)?;
    Ok(converted)
}

fn kernel(is_try: bool, args: &[ArrayRef]) -> Result<ArrayRef> {
let value_arr = &args[0];
let format_arr = args.get(1);
let is_string = matches!(
value_arr.data_type(),
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
);
match (is_string, format_arr) {
(true, Some(fmt)) => {
if !matches!(
fmt.data_type(),
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
) {
return Err(unsupported_data_type_exec_err(
if is_try { "try_to_time" } else { "to_time" },
"STRING",
fmt.data_type(),
));
}
Self::parse_value_with_format_array(value_arr, fmt, is_try)
}
(true, None) => Self::parse_value_array(value_arr, is_try),
(false, _) => Self::cast_nonstring_to_time(value_arr, is_try),
}
}
}

impl ScalarUDFImpl for SparkTime {
fn name(&self) -> &str {
"spark_time"
if self.is_try {
"try_to_time"
} else {
"to_time"
}
}

fn signature(&self) -> &Signature {
Expand All @@ -64,53 +181,47 @@ impl ScalarUDFImpl for SparkTime {
Ok(DataType::Time64(TimeUnit::Microsecond))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs { args, .. } = args;
let arg = args.one()?;
let is_try = self.is_try;
match arg {
ColumnarValue::Array(array) => {
let array = match array.data_type() {
DataType::Utf8 => as_string_array(&array)?
.iter()
.map(|x| {
x.map(|v| Self::string_to_time_micros(v, is_try))
.transpose()
.map(|opt| opt.flatten())
})
.collect::<Result<Time64MicrosecondArray>>()?,
DataType::LargeUtf8 => as_large_string_array(&array)?
.iter()
.map(|x| {
x.map(|v| Self::string_to_time_micros(v, is_try))
.transpose()
.map(|opt| opt.flatten())
})
.collect::<Result<Time64MicrosecondArray>>()?,
DataType::Utf8View => as_string_view_array(&array)?
.iter()
.map(|x| {
x.map(|v| Self::string_to_time_micros(v, is_try))
.transpose()
.map(|opt| opt.flatten())
})
.collect::<Result<Time64MicrosecondArray>>()?,
_ => return exec_err!("expected string array for `time`"),
};
Ok(ColumnarValue::Array(Arc::new(array)))
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if !matches!(arg_types.len(), 1 | 2) {
return Err(invalid_arg_count_exec_err(
self.name(),
(1, 2),
arg_types.len(),
));
}
match &arg_types[0] {
DataType::Utf8
| DataType::LargeUtf8
| DataType::Utf8View
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Null => {}
other => {
return Err(unsupported_data_type_exec_err(
self.name(),
"STRING, TIME, TIMESTAMP or NULL",
other,
));
}
ColumnarValue::Scalar(scalar) => {
let value = match scalar.try_as_str() {
Some(x) => x
.map(|v| Self::string_to_time_micros(v, is_try))
.transpose()?
.flatten(),
_ => {
return exec_err!("expected string scalar for `time`");
}
};
Ok(ColumnarValue::Scalar(ScalarValue::Time64Microsecond(value)))
}
let mut coerced = arg_types.to_vec();
if let Some(format) = arg_types.get(1) {
match format {
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {}
// A NULL format yields a NULL result; coerce it to a Utf8 null so
// the kernel's per-row None handling produces NULL instead of
// erroring on the `Null` type.
DataType::Null => coerced[1] = DataType::Utf8,
other => {
return Err(unsupported_data_type_exec_err(self.name(), "STRING", other));
}
}
}
Ok(coerced)
}

/// Evaluates the function by delegating to `Self::kernel` through
/// `make_scalar_function`, passing an empty hint list.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    // Copy the flag so the closure does not borrow `self`.
    let is_try = self.is_try;
    let evaluate = make_scalar_function(
        move |arrays: &[ArrayRef]| Self::kernel(is_try, arrays),
        vec![],
    );
    evaluate(&args.args)
}
}
34 changes: 32 additions & 2 deletions crates/sail-plan/src/function/scalar/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use datafusion_common::{DFSchemaRef, ScalarValue};
use datafusion_expr::expr::{self, Expr};
use datafusion_expr::{cast, lit, try_cast, when, BinaryExpr, ExprSchemable, Operator, ScalarUDF};
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions::expr_fn::to_time;
use datafusion_spark::function::datetime::make_dt_interval::SparkMakeDtInterval;
use datafusion_spark::function::datetime::make_interval::SparkMakeInterval;
use sail_common::datetime::time_unit_to_multiplier;
Expand All @@ -21,6 +20,7 @@ use sail_function::scalar::datetime::spark_make_time::SparkMakeTime;
use sail_function::scalar::datetime::spark_make_timestamp_ntz::SparkMakeTimestampNtz;
use sail_function::scalar::datetime::spark_make_ym_interval::SparkMakeYmInterval;
use sail_function::scalar::datetime::spark_next_day::SparkNextDay;
use sail_function::scalar::datetime::spark_time::SparkTime;
use sail_function::scalar::datetime::spark_time_diff::SparkTimeDiff;
use sail_function::scalar::datetime::spark_time_trunc::SparkTimeTrunc;
use sail_function::scalar::datetime::spark_timestamp::SparkTimestamp;
Expand Down Expand Up @@ -386,6 +386,35 @@ fn to_timestamp(input: ScalarFunctionInput, timestamp_ntz: bool) -> PlanResult<E
}
}

/// Plans `to_time` by delegating to `time_with_try` with `is_try = false`.
fn to_time(input: ScalarFunctionInput) -> PlanResult<Expr> {
time_with_try(input, false)
}

/// Plans `try_to_time` by delegating to `time_with_try` with `is_try = true`.
fn try_to_time(input: ScalarFunctionInput) -> PlanResult<Expr> {
time_with_try(input, true)
}

/// Builds the plan expression for `to_time` (`is_try = false`) and
/// `try_to_time` (`is_try = true`) as a call to the `SparkTime` UDF.
///
/// - One argument: passed to the UDF unchanged.
/// - Two arguments: the value is cast to `Utf8` and the format is run through
///   `to_chrono_fmt` before the call.
/// - Any other count: an invalid-plan error naming the function.
fn time_with_try(input: ScalarFunctionInput, is_try: bool) -> PlanResult<Expr> {
    let udf = ScalarUDF::from(SparkTime::new(is_try));
    match input.arguments.len() {
        1 => Ok(udf.call(input.arguments)),
        2 => {
            let (value, pattern) = input.arguments.two()?;
            let value = cast(value, DataType::Utf8);
            let pattern = to_chrono_fmt(pattern);
            Ok(udf.call(vec![value, pattern]))
        }
        _ => {
            let name = if is_try { "try_to_time" } else { "to_time" };
            Err(PlanError::invalid(format!(
                "{name} requires 1 or 2 arguments"
            )))
        }
    }
}

fn try_to_timestamp(input: ScalarFunctionInput, timestamp_ntz: bool) -> PlanResult<Expr> {
let data_type = timestamp_data_type(&input, timestamp_ntz);
if input.arguments.len() == 1 {
Expand Down Expand Up @@ -1013,7 +1042,8 @@ pub(super) fn list_built_in_datetime_functions() -> Vec<(&'static str, ScalarFun
),
("timestampdiff", F::custom(datediff)),
("to_date", F::custom(to_date)),
("to_time", F::var_arg(to_time)),
("to_time", F::custom(to_time)),
("try_to_time", F::custom(try_to_time)),
(
"to_timestamp",
F::custom(|input| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2856,7 +2856,7 @@
}
},
"output": {
"failure": "error in DataFusion: Execution error: Error parsing '12.10.05' as time. Tried formats: [\"HH.mm.ss\"]"
"success": "ok"
}
},
{
Expand Down Expand Up @@ -3758,7 +3758,7 @@
}
},
"output": {
"failure": "not supported: unknown function: try_to_time"
"success": "ok"
}
},
{
Expand All @@ -3780,7 +3780,7 @@
}
},
"output": {
"failure": "not supported: unknown function: try_to_time"
"success": "ok"
}
},
{
Expand All @@ -3802,7 +3802,7 @@
}
},
"output": {
"failure": "not supported: unknown function: try_to_time"
"success": "ok"
}
},
{
Expand Down
Loading
Loading