Skip to content

Commit f65310c

Browse files
committed
Add faster specialization for deserializing vector<float>
1 parent de99af7 commit f65310c

File tree

2 files changed

+65
-6
lines changed

2 files changed

+65
-6
lines changed

scylla-cql/src/frame/response/result.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::frame::{frame_errors::ParseError, types};
77
use crate::types::deserialize::result::{RowIterator, TypedRowIterator};
88
use crate::types::deserialize::value::{
99
mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator,
10+
VectorIterator,
1011
};
1112
use crate::types::deserialize::{DeserializationError, FrameSlice};
1213
use bytes::{Buf, Bytes};
@@ -829,9 +830,17 @@ pub fn deser_cql_value(
829830
.collect::<StdResult<_, _>>()?;
830831
CqlValue::Tuple(t)
831832
}
832-
Vector(_type_name, _) => {
833-
let l = Vec::<CqlValue>::deserialize(typ, v)?;
834-
CqlValue::Vector(l)
833+
// Specialization for faster deserialization of vectors of floats, which are currently
834+
// the only type of vector
835+
Vector(elem_type, _) if matches!(elem_type.as_ref(), Float) => {
836+
let v = VectorIterator::<CqlValue>::deserialize_vector_of_float_to_vec_of_cql_value(
837+
typ, v,
838+
)?;
839+
CqlValue::Vector(v)
840+
}
841+
Vector(_, _) => {
842+
let v = Vec::<CqlValue>::deserialize(typ, v)?;
843+
CqlValue::Vector(v)
835844
}
836845
})
837846
}

scylla-cql/src/types/deserialize/value.rs

+53-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
//! Provides types for dealing with CQL value deserialization.
22
3+
use bytes::Bytes;
34
use std::{
45
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
56
hash::{BuildHasher, Hash},
67
net::IpAddr,
78
};
8-
9-
use bytes::Bytes;
109
use uuid::Uuid;
1110

1211
use std::fmt::{Display, Pointer};
@@ -759,7 +758,7 @@ pub struct VectorIterator<'frame, T> {
759758
}
760759

761760
impl<'frame, T> VectorIterator<'frame, T> {
762-
fn new(
761+
pub fn new(
763762
coll_typ: &'frame ColumnType,
764763
elem_typ: &'frame ColumnType,
765764
count: usize,
@@ -774,6 +773,54 @@ impl<'frame, T> VectorIterator<'frame, T> {
774773
phantom_data: std::marker::PhantomData,
775774
}
776775
}
776+
777+
/// Faster specialization for deserializing a `vector<float>` into `Vec<CqlValue>`.
778+
/// The generic code `Vec<CqlValue>::deserialize(...)` is much slower because it has to
779+
/// match on the element type for every item in the vector.
780+
/// Here we just hardcode `f32` and we can shortcut a lot of code.
781+
///
782+
/// This could be nicer if Rust had generic type specialization in stable,
783+
/// but for now we need a separate method.
784+
pub fn deserialize_vector_of_float_to_vec_of_cql_value(
785+
typ: &'frame ColumnType,
786+
v: Option<FrameSlice<'frame>>,
787+
) -> Result<Vec<CqlValue>, DeserializationError> {
788+
789+
// Typecheck would make sure those never happen:
790+
let ColumnType::Vector(elem_type, elem_count) = typ else {
791+
panic!("Wrong column type: {:?}. Expected vector<>", typ);
792+
};
793+
if !matches!(elem_type.as_ref(), ColumnType::Float) {
794+
panic!("Wrong element type: {:?}. Expected float", typ);
795+
}
796+
797+
let elem_count = *elem_count as usize;
798+
let mut frame = v.map(|s| s.as_slice()).unwrap_or_default();
799+
800+
// Check length only once
801+
if frame.len() < 4 * elem_count {
802+
return Err(mk_deser_err::<Vec<CqlValue>>(
803+
typ,
804+
BuiltinDeserializationErrorKind::RawCqlBytesReadError(
805+
LowLevelDeserializationError::TooFewBytesReceived {
806+
expected: 4 * elem_count,
807+
received: frame.len(),
808+
},
809+
),
810+
));
811+
}
812+
813+
// We know we have enough elements in the buffer, so now we can skip the checks
814+
let mut result = Vec::with_capacity(elem_count);
815+
for _ in 0..elem_count {
816+
// we did check for frame length earlier so we can safely not check again
817+
let (elem, remaining) = unsafe { frame.split_at_unchecked(4) };
818+
let elem = f32::from_be_bytes(elem.try_into().unwrap());
819+
result.push(CqlValue::Float(elem));
820+
frame = remaining;
821+
}
822+
Ok(result)
823+
}
777824
}
778825

779826
impl<'frame, T> DeserializeValue<'frame> for VectorIterator<'frame, T>
@@ -828,6 +875,7 @@ where
828875
{
829876
type Item = Result<T, DeserializationError>;
830877

878+
#[inline]
831879
fn next(&mut self) -> Option<Self::Item> {
832880
let raw = self.raw_iter.next()?.map_err(|err| {
833881
mk_deser_err::<Self>(
@@ -883,6 +931,7 @@ where
883931
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
884932
.map_err(deser_error_replace_rust_name::<Self>)
885933
}
934+
886935
ColumnType::Vector(_, _) => VectorIterator::<'frame, T>::deserialize(typ, v)
887936
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
888937
.map_err(deser_error_replace_rust_name::<Self>),
@@ -1435,6 +1484,7 @@ impl<'frame> VectorBytesSequenceIterator<'frame> {
14351484
impl<'frame> Iterator for VectorBytesSequenceIterator<'frame> {
14361485
type Item = Result<Option<FrameSlice<'frame>>, LowLevelDeserializationError>;
14371486

1487+
#[inline]
14381488
fn next(&mut self) -> Option<Self::Item> {
14391489
self.remaining = self.remaining.checked_sub(1)?;
14401490
Some(self.slice.read_subslice(self.elem_len))

0 commit comments

Comments
 (0)