Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proxy: compresssion support #1246

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions scylla-cql/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use tokio::io::{AsyncRead, AsyncReadExt};
use uuid::Uuid;

use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use std::{collections::HashMap, convert::TryFrom};

Expand All @@ -23,10 +24,10 @@ use response::ResponseOpcode;
const HEADER_SIZE: usize = 9;

// Frame flags
const FLAG_COMPRESSION: u8 = 0x01;
const FLAG_TRACING: u8 = 0x02;
const FLAG_CUSTOM_PAYLOAD: u8 = 0x04;
const FLAG_WARNING: u8 = 0x08;
pub const FLAG_COMPRESSION: u8 = 0x01;
pub const FLAG_TRACING: u8 = 0x02;
pub const FLAG_CUSTOM_PAYLOAD: u8 = 0x04;
pub const FLAG_WARNING: u8 = 0x08;

// All of the Authenticators supported by Scylla
#[derive(Debug, PartialEq, Eq, Clone)]
Expand Down Expand Up @@ -56,6 +57,27 @@ impl Compression {
}
}

/// Unknown compression.
#[derive(Error, Debug, Clone)]
#[error("Unknown compression: {name}")]
pub struct CompressionFromStrError {
name: String,
}

impl FromStr for Compression {
type Err = CompressionFromStrError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"lz4" => Ok(Self::Lz4),
"snappy" => Ok(Self::Snappy),
other => Err(Self::Err {
name: other.to_owned(),
}),
}
}
}

impl Display for Compression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
Expand Down Expand Up @@ -238,7 +260,7 @@ pub fn parse_response_body_extensions(
})
}

fn compress_append(
pub fn compress_append(
uncomp_body: &[u8],
compression: Compression,
out: &mut Vec<u8>,
Expand All @@ -264,7 +286,7 @@ fn compress_append(
}
}

fn decompress(
pub fn decompress(
mut comp_body: &[u8],
compression: Compression,
) -> Result<Vec<u8>, FrameBodyExtensionsParseError> {
Expand Down
14 changes: 14 additions & 0 deletions scylla-cql/src/frame/request/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::{
frame::types,
};

use super::DeserializableRequest;

pub struct Startup<'a> {
pub options: HashMap<Cow<'a, str>, Cow<'a, str>>,
}
Expand All @@ -31,3 +33,15 @@ pub enum StartupSerializationError {
#[error("Malformed startup options: {0}")]
OptionsSerialization(TryFromIntError),
}

impl DeserializableRequest for Startup<'_> {
fn deserialize(buf: &mut &[u8]) -> Result<Self, super::RequestDeserializationError> {
// Note: this is inefficient, but it's only used for tests and it's not common
// to deserialize STARTUP frames anyway.
let options = types::read_string_map(buf)?
.into_iter()
.map(|(k, v)| (k.into(), v.into()))
.collect();
Ok(Self { options })
}
}
14 changes: 12 additions & 2 deletions scylla-proxy/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
use std::net::SocketAddr;

use scylla_cql::frame::frame_errors::{FrameHeaderParseError, LowLevelDeserializationError};
use scylla_cql::frame::frame_errors::{
FrameBodyExtensionsParseError, FrameHeaderParseError, LowLevelDeserializationError,
};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ReadFrameError {
#[error("Failed to read frame header: {0}")]
Header(#[from] FrameHeaderParseError),
#[error("Failed to decompress frame: {0}")]
Compression(#[from] FrameBodyExtensionsParseError),
}

#[derive(Debug, Error)]
pub enum DoorkeeperError {
#[error("Listen on {0} failed with {1}")]
Expand All @@ -20,7 +30,7 @@ pub enum DoorkeeperError {
#[error("Could not send Options frame for obtaining shards number: {0}")]
ObtainingShardNumber(std::io::Error),
#[error("Could not send read Supported frame for obtaining shards number: {0}")]
ObtainingShardNumberFrame(FrameHeaderParseError),
ObtainingShardNumberFrame(ReadFrameError),
#[error("Could not read Supported options: {0}")]
ObtainingShardNumberParseOptions(LowLevelDeserializationError),
#[error("ShardInfo parameters missing")]
Expand Down
52 changes: 35 additions & 17 deletions scylla-proxy/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use tracing::warn;

use crate::errors::ReadFrameError;
use crate::proxy::CompressionReader;

const HEADER_SIZE: usize = 9;

// Parts of the frame header which are not determined by the request/response type.
Expand All @@ -22,13 +25,13 @@ pub struct FrameParams {
}

impl FrameParams {
pub fn for_request(&self) -> FrameParams {
pub const fn for_request(&self) -> FrameParams {
Self {
version: self.version & 0x7F,
..*self
}
}
pub fn for_response(&self) -> FrameParams {
pub const fn for_response(&self) -> FrameParams {
Self {
version: 0x80 | (self.version & 0x7F),
..*self
Expand All @@ -48,23 +51,25 @@ pub(crate) enum FrameOpcode {
Response(ResponseOpcode),
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RequestFrame {
pub params: FrameParams,
pub opcode: RequestOpcode,
pub body: Bytes,
}

impl RequestFrame {
pub async fn write(
pub(crate) async fn write(
&self,
writer: &mut (impl AsyncWrite + Unpin),
compression: &CompressionReader,
) -> Result<(), tokio::io::Error> {
write_frame(
self.params,
FrameOpcode::Request(self.opcode),
&self.body,
writer,
compression,
)
.await
}
Expand All @@ -73,7 +78,7 @@ impl RequestFrame {
Request::deserialize(&mut &self.body[..], self.opcode)
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResponseFrame {
pub params: FrameParams,
pub opcode: ResponseOpcode,
Expand Down Expand Up @@ -133,12 +138,14 @@ impl ResponseFrame {
pub(crate) async fn write(
&self,
writer: &mut (impl AsyncWrite + Unpin),
compression: &CompressionReader,
) -> Result<(), tokio::io::Error> {
write_frame(
self.params,
FrameOpcode::Response(self.opcode),
&self.body,
writer,
compression,
)
.await
}
Expand Down Expand Up @@ -230,9 +237,16 @@ fn serialize_error_specific_fields(
pub(crate) async fn write_frame(
params: FrameParams,
opcode: FrameOpcode,
body: &Bytes,
body: &[u8],
writer: &mut (impl AsyncWrite + Unpin),
compression: &CompressionReader,
) -> Result<(), tokio::io::Error> {
let compressed_body = compression
.maybe_compress_body(params.flags, body)
.map_err(|e| tokio::io::Error::new(std::io::ErrorKind::Other, e))?;

let body = compressed_body.as_deref().unwrap_or(body);

let mut header = [0; HEADER_SIZE];

header[0] = params.version;
Expand All @@ -253,7 +267,8 @@ pub(crate) async fn write_frame(
pub(crate) async fn read_frame(
reader: &mut (impl AsyncRead + Unpin),
frame_type: FrameType,
) -> Result<(FrameParams, FrameOpcode, Bytes), FrameHeaderParseError> {
compression: &CompressionReader,
) -> Result<(FrameParams, FrameOpcode, Bytes), ReadFrameError> {
let mut raw_header = [0u8; HEADER_SIZE];
reader
.read_exact(&mut raw_header[..])
Expand All @@ -269,7 +284,7 @@ pub(crate) async fn read_frame(
FrameType::Response => (FrameHeaderParseError::FrameFromClient, 0x80, "response"),
};
if version & 0x80 != valid_direction {
return Err(err);
return Err(err.into());
}
let protocol_version = version & 0x7F;
if protocol_version != 0x04 {
Expand Down Expand Up @@ -311,20 +326,22 @@ pub(crate) async fn read_frame(
.map_err(|err| FrameHeaderParseError::BodyChunkIoError(body.remaining_mut(), err))?;
if n == 0 {
// EOF, too early
return Err(FrameHeaderParseError::ConnectionClosed(
body.remaining_mut(),
length,
));
return Err(
FrameHeaderParseError::ConnectionClosed(body.remaining_mut(), length).into(),
);
}
}

Ok((frame_params, opcode, body.into_inner().into()))
let body = compression.maybe_decompress_body(flags, body.into_inner().into())?;

Ok((frame_params, opcode, body))
}

pub(crate) async fn read_request_frame(
reader: &mut (impl AsyncRead + Unpin),
) -> Result<RequestFrame, FrameHeaderParseError> {
read_frame(reader, FrameType::Request)
compression: &CompressionReader,
) -> Result<RequestFrame, ReadFrameError> {
read_frame(reader, FrameType::Request, compression)
.await
.map(|(params, opcode, body)| RequestFrame {
params,
Expand All @@ -338,8 +355,9 @@ pub(crate) async fn read_request_frame(

pub(crate) async fn read_response_frame(
reader: &mut (impl AsyncRead + Unpin),
) -> Result<ResponseFrame, FrameHeaderParseError> {
read_frame(reader, FrameType::Response)
compression: &CompressionReader,
) -> Result<ResponseFrame, ReadFrameError> {
read_frame(reader, FrameType::Response, compression)
.await
.map(|(params, opcode, body)| ResponseFrame {
params,
Expand Down
Loading