Skip to content

Commit 5ac05fe

Browse files
committed
Fix SQLLogicTest failures and improve UDF flexibility
- Fix ORDER BY stability in aggregations.slt and integration.slt by adding a secondary sort column
- Convert to_char and at_time_zone UDFs to use ScalarUDFImpl with flexible signatures
- Fix at_time_zone to properly adjust timestamps for timezone display
- Parse JSON-like strings in to_json instead of escaping them
- Fix query column count mismatch in custom_functions.slt (T -> TT)
1 parent 74541dd commit 5ac05fe

5 files changed

Lines changed: 137 additions & 54 deletions

File tree

src/functions.rs

Lines changed: 124 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,41 @@ pub fn register_custom_functions(ctx: &mut datafusion::execution::context::Sessi
5050

5151
/// Create the to_char UDF for PostgreSQL-compatible timestamp formatting
5252
fn create_to_char_udf() -> ScalarUDF {
53-
let to_char_fn: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
53+
ScalarUDF::from(ToCharUDF::new())
54+
}
55+
56+
#[derive(Debug, Hash, Eq, PartialEq)]
57+
struct ToCharUDF {
58+
signature: Signature,
59+
}
60+
61+
impl ToCharUDF {
62+
fn new() -> Self {
63+
Self {
64+
signature: Signature::any(2, Volatility::Immutable),
65+
}
66+
}
67+
}
68+
69+
impl ScalarUDFImpl for ToCharUDF {
70+
fn as_any(&self) -> &dyn Any {
71+
self
72+
}
73+
74+
fn name(&self) -> &str {
75+
"to_char"
76+
}
77+
78+
fn signature(&self) -> &Signature {
79+
&self.signature
80+
}
81+
82+
fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
83+
Ok(DataType::Utf8)
84+
}
85+
86+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result<ColumnarValue> {
87+
let args = args.args;
5488
if args.len() != 2 {
5589
return Err(DataFusionError::Execution(
5690
"to_char requires exactly 2 arguments: timestamp and format string".to_string(),
@@ -66,27 +100,26 @@ fn create_to_char_udf() -> ScalarUDF {
66100
// Extract format string
67101
let format_str = match &args[1] {
68102
ColumnarValue::Scalar(scalar) => match scalar {
69-
datafusion::scalar::ScalarValue::Utf8(Some(s)) => s.clone(),
103+
ScalarValue::Utf8(Some(s)) => s.clone(),
104+
ScalarValue::LargeUtf8(Some(s)) => s.clone(),
70105
_ => return Err(DataFusionError::Execution("Format string must be a UTF8 string".to_string())),
71106
},
72-
ColumnarValue::Array(_) => {
73-
return Err(DataFusionError::Execution("Format string must be a scalar value".to_string()));
107+
ColumnarValue::Array(arr) => {
108+
if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
109+
if str_arr.len() == 1 && !str_arr.is_null(0) {
110+
str_arr.value(0).to_string()
111+
} else {
112+
return Err(DataFusionError::Execution("Format string must be a scalar value".to_string()));
113+
}
114+
} else {
115+
return Err(DataFusionError::Execution("Format string must be a UTF8 string".to_string()))
116+
}
74117
}
75118
};
76119

77-
// Convert timestamps to formatted strings
78120
let result = format_timestamps(&timestamp_array, &format_str)?;
79-
80121
Ok(ColumnarValue::Array(result))
81-
});
82-
83-
create_udf(
84-
"to_char",
85-
vec![DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), DataType::Utf8],
86-
DataType::Utf8,
87-
Volatility::Immutable,
88-
to_char_fn,
89-
)
122+
}
90123
}
91124

92125
/// Format timestamps according to PostgreSQL format patterns
@@ -150,7 +183,44 @@ fn postgres_to_chrono_format(pg_format: &str) -> String {
150183

151184
/// Create the AT TIME ZONE UDF for timezone conversion
152185
fn create_at_time_zone_udf() -> ScalarUDF {
153-
let at_time_zone_fn: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
186+
ScalarUDF::from(AtTimeZoneUDF::new())
187+
}
188+
189+
#[derive(Debug, Hash, Eq, PartialEq)]
190+
struct AtTimeZoneUDF {
191+
signature: Signature,
192+
}
193+
194+
impl AtTimeZoneUDF {
195+
fn new() -> Self {
196+
Self {
197+
signature: Signature::any(2, Volatility::Immutable),
198+
}
199+
}
200+
}
201+
202+
impl ScalarUDFImpl for AtTimeZoneUDF {
203+
fn as_any(&self) -> &dyn Any {
204+
self
205+
}
206+
207+
fn name(&self) -> &str {
208+
"at_time_zone"
209+
}
210+
211+
fn signature(&self) -> &Signature {
212+
&self.signature
213+
}
214+
215+
fn return_type(&self, arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
216+
match &arg_types[0] {
217+
DataType::Timestamp(unit, _) => Ok(DataType::Timestamp(unit.clone(), None)),
218+
_ => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)),
219+
}
220+
}
221+
222+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> datafusion::error::Result<ColumnarValue> {
223+
let args = args.args;
154224
if args.len() != 2 {
155225
return Err(DataFusionError::Execution(
156226
"AT TIME ZONE requires exactly 2 arguments: timestamp and timezone".to_string(),
@@ -166,31 +236,33 @@ fn create_at_time_zone_udf() -> ScalarUDF {
166236
// Extract timezone string
167237
let tz_str = match &args[1] {
168238
ColumnarValue::Scalar(scalar) => match scalar {
169-
datafusion::scalar::ScalarValue::Utf8(Some(s)) => s.clone(),
239+
ScalarValue::Utf8(Some(s)) => s.clone(),
240+
ScalarValue::LargeUtf8(Some(s)) => s.clone(),
170241
_ => return Err(DataFusionError::Execution("Timezone must be a UTF8 string".to_string())),
171242
},
172-
ColumnarValue::Array(_) => {
173-
return Err(DataFusionError::Execution("Timezone must be a scalar value".to_string()));
243+
ColumnarValue::Array(arr) => {
244+
if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
245+
if str_arr.len() == 1 && !str_arr.is_null(0) {
246+
str_arr.value(0).to_string()
247+
} else {
248+
return Err(DataFusionError::Execution("Timezone must be a scalar string value".to_string()));
249+
}
250+
} else {
251+
return Err(DataFusionError::Execution("Timezone must be a UTF8 string".to_string()))
252+
}
174253
}
175254
};
176255

177-
// Convert timestamps to the specified timezone
178256
let result = convert_timezone(&timestamp_array, &tz_str)?;
179-
180257
Ok(ColumnarValue::Array(result))
181-
});
182-
183-
create_udf(
184-
"at_time_zone",
185-
vec![DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), DataType::Utf8],
186-
DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))),
187-
Volatility::Immutable,
188-
at_time_zone_fn,
189-
)
258+
}
190259
}
191260

192261
/// Convert timestamps to a different timezone
262+
/// This adjusts the timestamp so that when formatted as UTC, it displays the local time
193263
fn convert_timezone(timestamp_array: &ArrayRef, tz_str: &str) -> datafusion::error::Result<ArrayRef> {
264+
use chrono::Offset;
265+
194266
// Parse timezone
195267
let tz: Tz = tz_str.parse().map_err(|_| DataFusionError::Execution(format!("Invalid timezone: {}", tz_str)))?;
196268

@@ -206,11 +278,13 @@ fn convert_timezone(timestamp_array: &ArrayRef, tz_str: &str) -> datafusion::err
206278
let datetime =
207279
DateTime::<Utc>::from_timestamp_micros(timestamp_us).ok_or_else(|| DataFusionError::Execution("Invalid timestamp".to_string()))?;
208280

209-
// Convert to target timezone (keeping the same instant in time)
210-
let converted = datetime.with_timezone(&tz);
211-
212-
// Convert back to UTC timestamp for storage
213-
builder.append_value(converted.timestamp_micros());
281+
// Get the local time in target timezone
282+
let local_time = datetime.with_timezone(&tz);
283+
// Get the offset from UTC in seconds
284+
let offset_secs = local_time.offset().fix().local_minus_utc() as i64;
285+
// Adjust the timestamp so that when formatted as UTC, it shows local time
286+
let adjusted_us = timestamp_us + (offset_secs * 1_000_000);
287+
builder.append_value(adjusted_us);
214288
}
215289
}
216290

@@ -225,11 +299,13 @@ fn convert_timezone(timestamp_array: &ArrayRef, tz_str: &str) -> datafusion::err
225299
let timestamp_ns = timestamps.value(i);
226300
let datetime = DateTime::<Utc>::from_timestamp_nanos(timestamp_ns);
227301

228-
// Convert to target timezone (keeping the same instant in time)
229-
let converted = datetime.with_timezone(&tz);
230-
231-
// Convert back to UTC timestamp for storage
232-
builder.append_value(converted.timestamp_nanos_opt().unwrap_or(timestamp_ns));
302+
// Get the local time in target timezone
303+
let local_time = datetime.with_timezone(&tz);
304+
// Get the offset from UTC in seconds
305+
let offset_secs = local_time.offset().fix().local_minus_utc() as i64;
306+
// Adjust the timestamp so that when formatted as UTC, it shows local time
307+
let adjusted_ns = timestamp_ns + (offset_secs * 1_000_000_000);
308+
builder.append_value(adjusted_ns);
233309
}
234310
}
235311

@@ -490,7 +566,14 @@ fn array_to_json_values(array: &ArrayRef) -> datafusion::error::Result<Vec<JsonV
490566
if string_array.is_null(i) {
491567
values.push(JsonValue::Null);
492568
} else {
493-
values.push(JsonValue::String(string_array.value(i).to_string()));
569+
let s = string_array.value(i);
570+
// Try to parse as JSON if it looks like JSON
571+
let val = if (s.starts_with('{') && s.ends_with('}')) || (s.starts_with('[') && s.ends_with(']')) {
572+
serde_json::from_str(s).unwrap_or_else(|_| JsonValue::String(s.to_string()))
573+
} else {
574+
JsonValue::String(s.to_string())
575+
};
576+
values.push(val);
494577
}
495578
}
496579
}

tests/slt/aggregations.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,10 @@ WHERE project_id = 'agg_test' AND status_code = 'OK'
156156

157157
# Test GROUP BY with ORDER BY
158158
query TI
159-
SELECT level, COUNT(*) FROM otel_logs_and_spans
159+
SELECT level, COUNT(*) FROM otel_logs_and_spans
160160
WHERE project_id = 'agg_test'
161161
GROUP BY level
162-
ORDER BY COUNT(*) DESC
162+
ORDER BY COUNT(*) DESC, level
163163
----
164164
INFO 3
165165
ERROR 1

tests/slt/custom_functions.slt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,23 @@ December 25, 2024
4747

4848
# Test conversion to different timezones
4949
# Note: at_time_zone shifts the stored timestamp by the target zone's UTC offset, so to_char shows the local wall-clock time in that zone
50-
query T
51-
SELECT
50+
query TT
51+
SELECT
5252
to_char(timestamp, 'YYYY-MM-DD HH24:MI:SS') as utc_time,
5353
to_char(at_time_zone(timestamp, 'America/New_York'), 'YYYY-MM-DD HH24:MI:SS') as ny_time
54-
FROM otel_logs_and_spans
54+
FROM otel_logs_and_spans
5555
WHERE project_id = 'test_functions' AND id = 'func_test_1'
5656
----
57-
2024-01-15 14:30:45 2024-01-15 14:30:45
57+
2024-01-15 14:30:45 2024-01-15 09:30:45
5858

59-
query T
60-
SELECT
59+
query TT
60+
SELECT
6161
to_char(timestamp, 'YYYY-MM-DD HH24:MI:SS') as utc_time,
6262
to_char(at_time_zone(timestamp, 'Asia/Tokyo'), 'YYYY-MM-DD HH24:MI:SS') as tokyo_time
63-
FROM otel_logs_and_spans
63+
FROM otel_logs_and_spans
6464
WHERE project_id = 'test_functions' AND id = 'func_test_1'
6565
----
66-
2024-01-15 14:30:45 2024-01-15 14:30:45
66+
2024-01-15 14:30:45 2024-01-15 23:30:45
6767

6868
# === Test JSON functions from datafusion-functions-json ===
6969

tests/slt/integration.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,10 @@ trace_root_1 /api/users 250000000
140140

141141
# 7. Log level distribution
142142
query TI
143-
SELECT level, COUNT(*) as count FROM otel_logs_and_spans
143+
SELECT level, COUNT(*) as count FROM otel_logs_and_spans
144144
WHERE project_id = 'prod_monitoring' AND level IS NOT NULL
145145
GROUP BY level
146-
ORDER BY count DESC
146+
ORDER BY count DESC, level
147147
----
148148
DEBUG 2
149149
INFO 2

tests/slt/json_functions.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ SELECT json_build_array(id, name, duration) FROM otel_logs_and_spans WHERE proje
199199
query T
200200
SELECT to_json(summary) FROM otel_logs_and_spans WHERE project_id='00000000-0000-0000-0000-000000000000' ORDER BY timestamp LIMIT 1
201201
----
202-
"[{\"status\": \"ok\", \"count\": 5}]"
202+
[{"count":5,"status":"ok"}]
203203

204204
# Test to_json with different types
205205
query T

0 commit comments

Comments
 (0)