Skip to content

Proxy: compression support #1246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
May 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions scylla-cql/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use tokio::io::{AsyncRead, AsyncReadExt};
use uuid::Uuid;

use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use std::{collections::HashMap, convert::TryFrom};

Expand All @@ -22,11 +23,13 @@ use response::ResponseOpcode;

const HEADER_SIZE: usize = 9;

// Frame flags
const FLAG_COMPRESSION: u8 = 0x01;
const FLAG_TRACING: u8 = 0x02;
const FLAG_CUSTOM_PAYLOAD: u8 = 0x04;
const FLAG_WARNING: u8 = 0x08;
pub mod flag {
//! Frame flags
//!
//! Bit masks for the `flags` byte of a frame. Each constant occupies a
//! distinct bit, so they can be combined with `|` and tested with `&`.

/// The frame body is compressed. Set on requests that are serialized with a
/// compression algorithm; on responses it means the body must be decompressed
/// before any further parsing.
pub const COMPRESSION: u8 = 0x01;
/// Tracing is involved. Set on requests when tracing is requested; on
/// responses it means a trace id (UUID) is read from the body.
pub const TRACING: u8 = 0x02;
/// The response body carries a custom payload, read as a bytes map.
pub const CUSTOM_PAYLOAD: u8 = 0x04;
/// The response body carries a list of warning strings.
pub const WARNING: u8 = 0x08;
}

// All of the Authenticators supported by Scylla
#[derive(Debug, PartialEq, Eq, Clone)]
Expand Down Expand Up @@ -56,6 +59,27 @@ impl Compression {
}
}

/// Unknown compression.
///
/// Returned by the [`FromStr`] implementation of [`Compression`] when the
/// input string does not name a supported algorithm. The `Display` output is
/// `Unknown compression: <name>`.
#[derive(Error, Debug, Clone)]
#[error("Unknown compression: {name}")]
pub struct CompressionFromStrError {
/// The unrecognized input string, kept verbatim for the error message.
name: String,
}

impl FromStr for Compression {
type Err = CompressionFromStrError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"lz4" => Ok(Self::Lz4),
"snappy" => Ok(Self::Snappy),
other => Err(Self::Err {
name: other.to_owned(),
}),
}
}
}

impl Display for Compression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
Expand All @@ -76,15 +100,15 @@ impl SerializedRequest {
let mut data = vec![0; HEADER_SIZE];

if let Some(compression) = compression {
flags |= FLAG_COMPRESSION;
flags |= flag::COMPRESSION;
let body = req.to_bytes()?;
compress_append(&body, compression, &mut data)?;
} else {
req.serialize(&mut data)?;
}

if tracing {
flags |= FLAG_TRACING;
flags |= flag::TRACING;
}

data[0] = 4; // We only support version 4 for now
Expand Down Expand Up @@ -188,15 +212,15 @@ pub fn parse_response_body_extensions(
compression: Option<Compression>,
mut body: Bytes,
) -> Result<ResponseBodyWithExtensions, FrameBodyExtensionsParseError> {
if flags & FLAG_COMPRESSION != 0 {
if flags & flag::COMPRESSION != 0 {
if let Some(compression) = compression {
body = decompress(&body, compression)?.into();
} else {
return Err(FrameBodyExtensionsParseError::NoCompressionNegotiated);
}
}

let trace_id = if flags & FLAG_TRACING != 0 {
let trace_id = if flags & flag::TRACING != 0 {
let buf = &mut &*body;
let trace_id =
types::read_uuid(buf).map_err(FrameBodyExtensionsParseError::TraceIdParse)?;
Expand All @@ -206,7 +230,7 @@ pub fn parse_response_body_extensions(
None
};

let warnings = if flags & FLAG_WARNING != 0 {
let warnings = if flags & flag::WARNING != 0 {
let body_len = body.len();
let buf = &mut &*body;
let warnings = types::read_string_list(buf)
Expand All @@ -218,7 +242,7 @@ pub fn parse_response_body_extensions(
Vec::new()
};

let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 {
let custom_payload = if flags & flag::CUSTOM_PAYLOAD != 0 {
let body_len = body.len();
let buf = &mut &*body;
let payload_map = types::read_bytes_map(buf)
Expand All @@ -238,7 +262,7 @@ pub fn parse_response_body_extensions(
})
}

fn compress_append(
pub fn compress_append(
uncomp_body: &[u8],
compression: Compression,
out: &mut Vec<u8>,
Expand All @@ -264,7 +288,7 @@ fn compress_append(
}
}

fn decompress(
pub fn decompress(
mut comp_body: &[u8],
compression: Compression,
) -> Result<Vec<u8>, FrameBodyExtensionsParseError> {
Expand Down
14 changes: 14 additions & 0 deletions scylla-cql/src/frame/request/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::{
frame::types,
};

use super::DeserializableRequest;

/// A STARTUP request, described by its set of options.
pub struct Startup<'a> {
/// Option names mapped to option values. Both sides are `Cow<str>`, so an
/// entry may either borrow its text for `'a` or own it.
pub options: HashMap<Cow<'a, str>, Cow<'a, str>>,
}
Expand All @@ -31,3 +33,15 @@ pub enum StartupSerializationError {
#[error("Malformed startup options: {0}")]
OptionsSerialization(TryFromIntError),
}

impl DeserializableRequest for Startup<'_> {
    /// Reads a string map from `buf` and turns it into the STARTUP options.
    ///
    /// Any error from reading the map is propagated to the caller.
    fn deserialize(buf: &mut &[u8]) -> Result<Self, super::RequestDeserializationError> {
        // Note: converting every entry is inefficient, but this is only used for
        // tests and deserializing STARTUP frames is uncommon anyway.
        let raw_options = types::read_string_map(buf)?;
        let options = raw_options
            .into_iter()
            .map(|(name, value)| (name.into(), value.into()))
            .collect();

        Ok(Self { options })
    }
}
14 changes: 12 additions & 2 deletions scylla-proxy/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
use std::net::SocketAddr;

use scylla_cql::frame::frame_errors::{FrameHeaderParseError, LowLevelDeserializationError};
use scylla_cql::frame::frame_errors::{
FrameBodyExtensionsParseError, FrameHeaderParseError, LowLevelDeserializationError,
};
use thiserror::Error;

/// Error returned when reading a frame fails.
///
/// Both variants are built from the wrapped error via `#[from]`, so `?` and
/// `.into()` convert to this type automatically.
#[derive(Debug, Error)]
pub enum ReadFrameError {
/// Reading or validating the frame failed with a [`FrameHeaderParseError`].
#[error("Failed to read frame header: {0}")]
Header(#[from] FrameHeaderParseError),
/// The frame was read, but decompressing its body failed.
#[error("Failed to decompress frame: {0}")]
Compression(#[from] FrameBodyExtensionsParseError),
}

#[derive(Debug, Error)]
pub enum DoorkeeperError {
#[error("Listen on {0} failed with {1}")]
Expand All @@ -20,7 +30,7 @@ pub enum DoorkeeperError {
#[error("Could not send Options frame for obtaining shards number: {0}")]
ObtainingShardNumber(std::io::Error),
#[error("Could not send read Supported frame for obtaining shards number: {0}")]
ObtainingShardNumberFrame(FrameHeaderParseError),
ObtainingShardNumberFrame(ReadFrameError),
#[error("Could not read Supported options: {0}")]
ObtainingShardNumberParseOptions(LowLevelDeserializationError),
#[error("ShardInfo parameters missing")]
Expand Down
52 changes: 35 additions & 17 deletions scylla-proxy/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use tracing::warn;

use crate::errors::ReadFrameError;
use crate::proxy::CompressionReader;

const HEADER_SIZE: usize = 9;

// Parts of the frame header which are not determined by the request/response type.
Expand All @@ -22,13 +25,13 @@ pub struct FrameParams {
}

impl FrameParams {
pub fn for_request(&self) -> FrameParams {
pub const fn for_request(&self) -> FrameParams {
Self {
version: self.version & 0x7F,
..*self
}
}
pub fn for_response(&self) -> FrameParams {
pub const fn for_response(&self) -> FrameParams {
Self {
version: 0x80 | (self.version & 0x7F),
..*self
Expand All @@ -48,23 +51,25 @@ pub(crate) enum FrameOpcode {
Response(ResponseOpcode),
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RequestFrame {
pub params: FrameParams,
pub opcode: RequestOpcode,
pub body: Bytes,
}

impl RequestFrame {
pub async fn write(
pub(crate) async fn write(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commit: proxy: unpub RequestFrame::write

Isn't this a breaking change for scylla-proxy? The current version is 0.0.4, so the next time we release it, it should be bumped to 0.1.0, or am I missing something?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that scylla-proxy used to comply with semver. Anyway, releasing a new major makes perfect sense.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why would we need 0.1.0? 0.0.5 is also a new major.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I wasn't that familiar with the semantics pre-1.0. So what's the difference between bumping the minor vs. the patch version in a pre-1.0 crate?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

semver says:

Major version zero (0.y.z) is for initial development. Anything MAY change at any time. The public API SHOULD NOT be considered stable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems that it was just us that had the convention that in 0.y.z, y was the major and z was the minor version number.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems that it was just us that had the convention that in 0.y.z, y was the major and z was the minor version number.

This is also incorrect. Official semver spec differs from the semver used in Cargo. See https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#default-requirements

Versions are considered compatible if their left-most non-zero major/minor/patch component is the same. This is different from SemVer which considers all pre-1.0.0 packages to be incompatible.

&self,
writer: &mut (impl AsyncWrite + Unpin),
compression: &CompressionReader,
) -> Result<(), tokio::io::Error> {
write_frame(
self.params,
FrameOpcode::Request(self.opcode),
&self.body,
writer,
compression,
)
.await
}
Expand All @@ -73,7 +78,7 @@ impl RequestFrame {
Request::deserialize(&mut &self.body[..], self.opcode)
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResponseFrame {
pub params: FrameParams,
pub opcode: ResponseOpcode,
Expand Down Expand Up @@ -133,12 +138,14 @@ impl ResponseFrame {
pub(crate) async fn write(
&self,
writer: &mut (impl AsyncWrite + Unpin),
compression: &CompressionReader,
) -> Result<(), tokio::io::Error> {
write_frame(
self.params,
FrameOpcode::Response(self.opcode),
&self.body,
writer,
compression,
)
.await
}
Expand Down Expand Up @@ -230,9 +237,16 @@ fn serialize_error_specific_fields(
pub(crate) async fn write_frame(
params: FrameParams,
opcode: FrameOpcode,
body: &Bytes,
body: &[u8],
writer: &mut (impl AsyncWrite + Unpin),
compression: &CompressionReader,
) -> Result<(), tokio::io::Error> {
let compressed_body = compression
.maybe_compress_body(params.flags, body)
.map_err(|e| tokio::io::Error::new(std::io::ErrorKind::Other, e))?;

let body = compressed_body.as_deref().unwrap_or(body);

let mut header = [0; HEADER_SIZE];

header[0] = params.version;
Expand All @@ -253,7 +267,8 @@ pub(crate) async fn write_frame(
pub(crate) async fn read_frame(
reader: &mut (impl AsyncRead + Unpin),
frame_type: FrameType,
) -> Result<(FrameParams, FrameOpcode, Bytes), FrameHeaderParseError> {
compression: &CompressionReader,
) -> Result<(FrameParams, FrameOpcode, Bytes), ReadFrameError> {
let mut raw_header = [0u8; HEADER_SIZE];
reader
.read_exact(&mut raw_header[..])
Expand All @@ -269,7 +284,7 @@ pub(crate) async fn read_frame(
FrameType::Response => (FrameHeaderParseError::FrameFromClient, 0x80, "response"),
};
if version & 0x80 != valid_direction {
return Err(err);
return Err(err.into());
}
let protocol_version = version & 0x7F;
if protocol_version != 0x04 {
Expand Down Expand Up @@ -311,20 +326,22 @@ pub(crate) async fn read_frame(
.map_err(|err| FrameHeaderParseError::BodyChunkIoError(body.remaining_mut(), err))?;
if n == 0 {
// EOF, too early
return Err(FrameHeaderParseError::ConnectionClosed(
body.remaining_mut(),
length,
));
return Err(
FrameHeaderParseError::ConnectionClosed(body.remaining_mut(), length).into(),
);
}
}

Ok((frame_params, opcode, body.into_inner().into()))
let body = compression.maybe_decompress_body(flags, body.into_inner().into())?;

Ok((frame_params, opcode, body))
}

pub(crate) async fn read_request_frame(
reader: &mut (impl AsyncRead + Unpin),
) -> Result<RequestFrame, FrameHeaderParseError> {
read_frame(reader, FrameType::Request)
compression: &CompressionReader,
) -> Result<RequestFrame, ReadFrameError> {
read_frame(reader, FrameType::Request, compression)
.await
.map(|(params, opcode, body)| RequestFrame {
params,
Expand All @@ -338,8 +355,9 @@ pub(crate) async fn read_request_frame(

pub(crate) async fn read_response_frame(
reader: &mut (impl AsyncRead + Unpin),
) -> Result<ResponseFrame, FrameHeaderParseError> {
read_frame(reader, FrameType::Response)
compression: &CompressionReader,
) -> Result<ResponseFrame, ReadFrameError> {
read_frame(reader, FrameType::Response, compression)
.await
.map(|(params, opcode, body)| ResponseFrame {
params,
Expand Down
Loading