Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 163 additions & 47 deletions datafusion/functions/src/string/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

use std::sync::Arc;

use crate::strings::{GenericStringArrayBuilder, StringViewArrayBuilder, append_view};
use crate::strings::{
GenericStringArrayBuilder, STRING_VIEW_INIT_BLOCK_SIZE, STRING_VIEW_MAX_BLOCK_SIZE,
StringViewArrayBuilder, append_view,
};
use arrow::array::{
Array, ArrayRef, GenericStringArray, NullBufferBuilder, OffsetSizeTrait,
StringViewArray, new_null_array,
Expand Down Expand Up @@ -323,32 +326,42 @@ fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<Arr
}

/// Lowercases the string value(s) held in `args[0]` by delegating to
/// `case_conversion`.
///
/// `name` is the calling function's name; `case_conversion` only uses it to
/// build the "Unsupported data type ... for function {name}" error.
pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_lowercase(), name)
case_conversion(args, true, name)
}

/// Uppercases the string value(s) held in `args[0]` by delegating to
/// `case_conversion`.
///
/// `name` is the calling function's name; `case_conversion` only uses it to
/// build the "Unsupported data type ... for function {name}" error.
pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_uppercase(), name)
case_conversion(args, false, name)
}

/// Unicode-aware case mapping of `s`, returned as a freshly allocated `String`.
///
/// `lower == true` selects lowercase; `lower == false` selects uppercase. The
/// result may differ in byte length from the input (e.g. `ß` uppercases to
/// `SS`), which is why this is kept separate from the ASCII fast paths.
#[inline]
fn unicode_case(s: &str, lower: bool) -> String {
    match lower {
        true => s.to_lowercase(),
        false => s.to_uppercase(),
    }
}

fn case_conversion<'a, F>(
args: &'a [ColumnarValue],
op: F,
fn case_conversion(
args: &[ColumnarValue],
lower: bool,
name: &str,
) -> Result<ColumnarValue>
where
F: Fn(&'a str) -> String,
{
) -> Result<ColumnarValue> {
match &args[0] {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
array, op,
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32>(
array, lower,
)?)),
DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
i64,
_,
>(array, op)?)),
DataType::LargeUtf8 => Ok(ColumnarValue::Array(
case_conversion_array::<i64>(array, lower)?,
)),
DataType::Utf8View => {
let string_array = as_string_view_array(array)?;
if string_array.is_ascii() {
return Ok(ColumnarValue::Array(Arc::new(
case_conversion_utf8view_ascii(string_array, lower),
)));
}
let item_len = string_array.len();
// Null-preserving: reuse the input null buffer as the output null buffer.
let nulls = string_array.nulls().cloned();
Expand All @@ -361,14 +374,14 @@ where
} else {
// SAFETY: `n.is_null(i)` was false in the branch above.
let s = unsafe { string_array.value_unchecked(i) };
builder.append_value(&op(s));
builder.append_value(&unicode_case(s, lower));
}
}
} else {
for i in 0..item_len {
// SAFETY: no null buffer means every index is valid.
let s = unsafe { string_array.value_unchecked(i) };
builder.append_value(&op(s));
builder.append_value(&unicode_case(s, lower));
}
}

Expand All @@ -378,32 +391,31 @@ where
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| op(x));
let result = a.as_ref().map(|x| unicode_case(x, lower));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| op(x));
let result = a.as_ref().map(|x| unicode_case(x, lower));
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
}
ScalarValue::Utf8View(a) => {
let result = a.as_ref().map(|x| op(x));
let result = a.as_ref().map(|x| unicode_case(x, lower));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result)))
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
}
}

fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
fn case_conversion_array<O: OffsetSizeTrait>(
array: &ArrayRef,
lower: bool,
) -> Result<ArrayRef> {
const PRE_ALLOC_BYTES: usize = 8;

let string_array = as_generic_string_array::<O>(array)?;
if string_array.is_ascii() {
return case_conversion_ascii_array::<O, _>(string_array, op);
return case_conversion_ascii_array::<O>(string_array, lower);
}

// Values contain non-ASCII.
Expand All @@ -423,43 +435,147 @@ where
} else {
// SAFETY: `n.is_null(i)` was false in the branch above.
let s = unsafe { string_array.value_unchecked(i) };
builder.append_value(&op(s));
builder.append_value(&unicode_case(s, lower));
}
}
} else {
for i in 0..item_len {
// SAFETY: no null buffer means every index is valid.
let s = unsafe { string_array.value_unchecked(i) };
builder.append_value(&op(s));
builder.append_value(&unicode_case(s, lower));
}
}
Ok(Arc::new(builder.finish(nulls)?))
}

/// ASCII-only fast path for lower/upper-casing a `StringViewArray`.
///
/// `lower` picks the direction: `true` lowercases, `false` uppercases. The
/// heavy lifting is done by `case_conversion_utf8view_ascii_inner`.
fn case_conversion_utf8view_ascii(
    array: &StringViewArray,
    lower: bool,
) -> StringViewArray {
    // Each direction gets its own call site with a concrete function item, so
    // the inner routine is specialized per conversion and the per-byte call
    // can be inlined into its hot loops.
    if !lower {
        return case_conversion_utf8view_ascii_inner(array, u8::to_ascii_uppercase);
    }
    case_conversion_utf8view_ascii_inner(array, u8::to_ascii_lowercase)
}

/// Builds a new `StringViewArray` by applying `convert` to every data byte of
/// `array`, walking the views exactly once.
///
/// Each view is a `u128` read here in little-endian byte order:
/// - bytes `0..4`: the value length;
/// - length ≤ 12 (inline row): bytes `4..4 + len` hold the value itself;
/// - length > 12 (long row): bytes `4..8` hold a 4-byte prefix, bytes `8..12`
///   the data-buffer index, and bytes `12..16` the offset into that buffer.
///
/// Inline rows have their inline bytes converted in place. Long rows have
/// their referenced bytes copied (and converted) into freshly allocated output
/// blocks; the view is then rewritten with the new prefix, the index of the
/// output block, and the offset inside it. Output blocks are flushed and
/// regrown as they fill up, so the result may reference more than one buffer.
///
/// Null rows and empty rows are emitted as all-zero views. The input null
/// buffer is cloned into the output unchanged.
///
/// `convert` is applied to each byte independently; the callers visible in
/// this file pass `u8::to_ascii_lowercase` / `u8::to_ascii_uppercase`.
///
/// # Panics
///
/// Panics if the number of output blocks or an offset within a block does not
/// fit in an `i32`.
fn case_conversion_utf8view_ascii_inner<F: Fn(&u8) -> u8>(
array: &StringViewArray,
convert: F,
) -> StringViewArray {
let item_len = array.len();
let views = array.views();
let data_buffers = array.data_buffers();
let nulls = array.nulls();

// One output view per input row.
let mut new_views: Vec<u128> = Vec::with_capacity(item_len);
// Data block currently being filled with converted long-row bytes.
let mut in_progress: Vec<u8> = Vec::new();
// Blocks that are full and already frozen into `Buffer`s.
let mut completed: Vec<Buffer> = Vec::new();
// Target size for the next block; doubled before each allocation (see below).
let mut block_size: u32 = STRING_VIEW_INIT_BLOCK_SIZE;

for i in 0..item_len {
// Null row: emit a zero view and never look at the (possibly arbitrary)
// input view for this slot.
if nulls.is_some_and(|n| n.is_null(i)) {
new_views.push(0);
continue;
}
let view = views[i];
// The low 32 bits of the view are the value length.
let len = view as u32 as usize;
// Empty string: a zero view (length 0, no data) is all that is needed.
if len == 0 {
new_views.push(0);
continue;
}
let mut bytes = view.to_le_bytes();
// A view is 16 bytes and the first 4 are the length, which leaves 12 bytes:
// values of up to 12 bytes are stored directly inside the view and need no
// data buffer. Longer values go through the `else` branch.
if len <= 12 {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add some comments, I assume 12 len is the fast path for german strings?

// Inline row: convert the inline data bytes; layout unchanged.
for b in &mut bytes[4..4 + len] {
*b = convert(b);
}
new_views.push(u128::from_le_bytes(bytes));
} else {
// Make sure the current data block has room for this value;
// otherwise flush and start a new, larger block.
let required_cap = in_progress.len() + len;
if in_progress.capacity() < required_cap {
// `mem::take` leaves an empty, zero-capacity Vec behind, so the
// `reserve` below allocates a brand-new block.
if !in_progress.is_empty() {
completed.push(Buffer::from_vec(std::mem::take(&mut in_progress)));
}
// Grow the block size geometrically until it reaches the maximum.
if block_size < STRING_VIEW_MAX_BLOCK_SIZE {
block_size = block_size.saturating_mul(2);
}
// A single value larger than the block size gets a block of its own size.
let to_reserve = len.max(block_size as usize);
in_progress.reserve(to_reserve);
}

// The in-progress block will land at index `completed.len()` once it is
// flushed, so that is the index this view must reference.
// NOTE(review): the round-trip through `i32` makes values above
// `i32::MAX` panic; presumably this mirrors a limit on view buffer
// indices/offsets — confirm against the Arrow view layout.
let buffer_index: u32 = i32::try_from(completed.len())
.expect("buffer count exceeds i32::MAX")
as u32;
let new_offset: u32 =
i32::try_from(in_progress.len()).expect("offset exceeds i32::MAX") as u32;

// Locate the source bytes through the input view's buffer index and offset.
let src_buffer_index =
u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
let src_offset =
u32::from_le_bytes(bytes[12..16].try_into().unwrap()) as usize;
let src =
&data_buffers[src_buffer_index].as_slice()[src_offset..src_offset + len];

// Copy and convert in a single pass.
let prefix_start = in_progress.len();
in_progress.extend(src.iter().map(&convert));

// Prefix is the first 4 bytes of the converted data we just wrote.
// (`len > 12` in this branch, so at least 4 bytes were appended.)
let prefix: [u8; 4] = in_progress[prefix_start..prefix_start + 4]
.try_into()
.unwrap();
// Rewrite prefix, buffer index and offset; the length bytes stay as they were.
bytes[4..8].copy_from_slice(&prefix);
bytes[8..12].copy_from_slice(&buffer_index.to_le_bytes());
bytes[12..16].copy_from_slice(&new_offset.to_le_bytes());
new_views.push(u128::from_le_bytes(bytes));
}
}

// Flush the final, partially filled block (if any long rows were written).
if !in_progress.is_empty() {
completed.push(Buffer::from_vec(in_progress));
}

// SAFETY: each long view's buffer_index addresses a buffer we wrote, and
// its offset addresses bytes within that buffer; prefixes were copied from
// those same bytes; inline views were rewritten from valid inline bytes;
// null/empty rows are zero views with no buffer reference; row count is
// unchanged.
unsafe {
StringViewArray::new_unchecked(
ScalarBuffer::from(new_views),
completed,
array.nulls().cloned(),
)
}
}

/// Fast path for case conversion on an all-ASCII string array. ASCII case
/// conversion is byte-length-preserving, so we can convert the entire addressed
/// range in one call and reuse the offsets and nulls buffers — rebasing the
/// offsets when the input is a sliced array.
fn case_conversion_ascii_array<'a, O, F>(
string_array: &'a GenericStringArray<O>,
op: F,
) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
/// byte range in one pass over the value buffer and reuse the offsets and nulls
/// buffers — rebasing the offsets when the input is a sliced array.
fn case_conversion_ascii_array<O: OffsetSizeTrait>(
string_array: &GenericStringArray<O>,
lower: bool,
) -> Result<ArrayRef> {
let value_offsets = string_array.value_offsets();
let start = value_offsets.first().unwrap().as_usize();
let end = value_offsets.last().unwrap().as_usize();
let relevant = &string_array.value_data()[start..end];

// SAFETY: `relevant` is a subslice of the string array's value buffer,
// which is valid UTF-8.
let str_values = unsafe { std::str::from_utf8_unchecked(relevant) };

let converted_values = op(str_values);
debug_assert_eq!(converted_values.len(), str_values.len());
let values = Buffer::from_vec(converted_values.into_bytes());
let converted: Vec<u8> = if lower {
relevant.iter().map(u8::to_ascii_lowercase).collect()
} else {
relevant.iter().map(u8::to_ascii_uppercase).collect()
};
let values = Buffer::from_vec(converted);

// Shift offsets from `start`-based to 0-based so they index into `values`.
let offsets = if start == 0 {
Expand All @@ -468,7 +584,7 @@ where
let s = O::usize_as(start);
let rebased: Vec<O> = value_offsets.iter().map(|&o| o - s).collect();
// SAFETY: subtracting a constant from monotonic offsets preserves
// monotonicity, and `start` is the minimum offset so no underflow.
// monotonicity, and `start` is the minimum offset, so no underflow.
unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(rebased)) }
};

Expand Down
Loading
Loading