Skip to content

Commit f3eb389

Browse files
committed
Add faster specialization for deserializing vector<float>
1 parent de99af7 commit f3eb389

File tree

2 files changed

+68
-10
lines changed

2 files changed

+68
-10
lines changed

scylla-cql/src/frame/response/result.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::frame::{frame_errors::ParseError, types};
77
use crate::types::deserialize::result::{RowIterator, TypedRowIterator};
88
use crate::types::deserialize::value::{
99
mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator,
10+
VectorIterator,
1011
};
1112
use crate::types::deserialize::{DeserializationError, FrameSlice};
1213
use bytes::{Buf, Bytes};
@@ -829,9 +830,17 @@ pub fn deser_cql_value(
829830
.collect::<StdResult<_, _>>()?;
830831
CqlValue::Tuple(t)
831832
}
832-
Vector(_type_name, _) => {
833-
let l = Vec::<CqlValue>::deserialize(typ, v)?;
834-
CqlValue::Vector(l)
833+
// Specialization for faster deserialization of vectors of floats, which are currently
834+
// the only type of vector
835+
Vector(elem_type, _) if matches!(elem_type.as_ref(), Float) => {
836+
let v = VectorIterator::<CqlValue>::deserialize_vector_of_float_to_vec_of_cql_value(
837+
typ, v,
838+
)?;
839+
CqlValue::Vector(v)
840+
}
841+
Vector(_, _) => {
842+
let v = Vec::<CqlValue>::deserialize(typ, v)?;
843+
CqlValue::Vector(v)
835844
}
836845
})
837846
}

scylla-cql/src/types/deserialize/value.rs

+56-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
//! Provides types for dealing with CQL value deserialization.
22
3-
use std::{
4-
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
5-
hash::{BuildHasher, Hash},
6-
net::IpAddr,
7-
};
8-
93
use bytes::Bytes;
4+
use std::{collections::{BTreeMap, BTreeSet, HashMap, HashSet}, hash::{BuildHasher, Hash}, mem, net::IpAddr};
105
use uuid::Uuid;
116

127
use std::fmt::{Display, Pointer};
@@ -759,7 +754,7 @@ pub struct VectorIterator<'frame, T> {
759754
}
760755

761756
impl<'frame, T> VectorIterator<'frame, T> {
762-
fn new(
757+
pub fn new(
763758
coll_typ: &'frame ColumnType,
764759
elem_typ: &'frame ColumnType,
765760
count: usize,
@@ -774,6 +769,57 @@ impl<'frame, T> VectorIterator<'frame, T> {
774769
phantom_data: std::marker::PhantomData,
775770
}
776771
}
772+
773+
/// Faster specialization for deserializing a `vector<float>` into `Vec<CqlValue>`.
774+
/// The generic code `Vec<CqlValue>::deserialize(...)` is much slower because it has to
775+
/// match on the element type for every item in the vector.
776+
/// Here we just hardcode `f32` and we can shortcut a lot of code.
777+
///
778+
/// This could be nicer if Rust had generic type specialization in stable,
779+
/// but for now we need a separate method.
780+
pub fn deserialize_vector_of_float_to_vec_of_cql_value(
781+
typ: &'frame ColumnType,
782+
v: Option<FrameSlice<'frame>>,
783+
) -> Result<Vec<CqlValue>, DeserializationError> {
784+
785+
// Typecheck would make sure those never happen:
786+
let ColumnType::Vector(elem_type, elem_count) = typ else {
787+
panic!("Wrong column type: {:?}. Expected vector<>", typ);
788+
};
789+
if !matches!(elem_type.as_ref(), ColumnType::Float) {
790+
panic!("Wrong element type: {:?}. Expected float", typ);
791+
}
792+
793+
let elem_count = *elem_count as usize;
794+
let mut frame = v.map(|s| s.as_slice()).unwrap_or_default();
795+
let mut result = Vec::with_capacity(elem_count);
796+
797+
unsafe {
798+
type Float = f32;
799+
800+
// Check length only once
801+
if frame.len() < size_of::<Float>() * elem_count {
802+
return Err(mk_deser_err::<Vec<CqlValue>>(
803+
typ,
804+
BuiltinDeserializationErrorKind::RawCqlBytesReadError(
805+
LowLevelDeserializationError::TooFewBytesReceived {
806+
expected: size_of::<Float>() * elem_count,
807+
received: frame.len(),
808+
},
809+
),
810+
));
811+
}
812+
// We know we have enough elements in the buffer, so now we can skip the checks
813+
for _ in 0..elem_count {
814+
// we did check for frame length earlier, so we can safely not check again
815+
let (elem, remaining) = frame.split_first_chunk().unwrap_unchecked();
816+
let elem = Float::from_be_bytes(*elem);
817+
result.push(CqlValue::Float(elem));
818+
frame = remaining;
819+
}
820+
}
821+
Ok(result)
822+
}
777823
}
778824

779825
impl<'frame, T> DeserializeValue<'frame> for VectorIterator<'frame, T>
@@ -828,6 +874,7 @@ where
828874
{
829875
type Item = Result<T, DeserializationError>;
830876

877+
#[inline]
831878
fn next(&mut self) -> Option<Self::Item> {
832879
let raw = self.raw_iter.next()?.map_err(|err| {
833880
mk_deser_err::<Self>(
@@ -883,6 +930,7 @@ where
883930
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
884931
.map_err(deser_error_replace_rust_name::<Self>)
885932
}
933+
886934
ColumnType::Vector(_, _) => VectorIterator::<'frame, T>::deserialize(typ, v)
887935
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
888936
.map_err(deser_error_replace_rust_name::<Self>),
@@ -1435,6 +1483,7 @@ impl<'frame> VectorBytesSequenceIterator<'frame> {
14351483
impl<'frame> Iterator for VectorBytesSequenceIterator<'frame> {
14361484
type Item = Result<Option<FrameSlice<'frame>>, LowLevelDeserializationError>;
14371485

1486+
#[inline]
14381487
fn next(&mut self) -> Option<Self::Item> {
14391488
self.remaining = self.remaining.checked_sub(1)?;
14401489
Some(self.slice.read_subslice(self.elem_len))

0 commit comments

Comments
 (0)