diff --git a/scylla-cql/src/frame/mod.rs b/scylla-cql/src/frame/mod.rs index 43a15b859f..1fae8ebe5c 100644 --- a/scylla-cql/src/frame/mod.rs +++ b/scylla-cql/src/frame/mod.rs @@ -14,6 +14,7 @@ use tokio::io::{AsyncRead, AsyncReadExt}; use uuid::Uuid; use std::fmt::Display; +use std::str::FromStr; use std::sync::Arc; use std::{collections::HashMap, convert::TryFrom}; @@ -23,10 +24,10 @@ use response::ResponseOpcode; const HEADER_SIZE: usize = 9; // Frame flags -const FLAG_COMPRESSION: u8 = 0x01; -const FLAG_TRACING: u8 = 0x02; -const FLAG_CUSTOM_PAYLOAD: u8 = 0x04; -const FLAG_WARNING: u8 = 0x08; +pub const FLAG_COMPRESSION: u8 = 0x01; +pub const FLAG_TRACING: u8 = 0x02; +pub const FLAG_CUSTOM_PAYLOAD: u8 = 0x04; +pub const FLAG_WARNING: u8 = 0x08; // All of the Authenticators supported by Scylla #[derive(Debug, PartialEq, Eq, Clone)] @@ -56,6 +57,27 @@ impl Compression { } } +/// Unknown compression. +#[derive(Error, Debug, Clone)] +#[error("Unknown compression: {name}")] +pub struct CompressionFromStrError { + name: String, +} + +impl FromStr for Compression { + type Err = CompressionFromStrError; + + fn from_str(s: &str) -> Result { + match s { + "lz4" => Ok(Self::Lz4), + "snappy" => Ok(Self::Snappy), + other => Err(Self::Err { + name: other.to_owned(), + }), + } + } +} + impl Display for Compression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(self.as_str()) @@ -238,7 +260,7 @@ pub fn parse_response_body_extensions( }) } -fn compress_append( +pub fn compress_append( uncomp_body: &[u8], compression: Compression, out: &mut Vec, @@ -264,7 +286,7 @@ fn compress_append( } } -fn decompress( +pub fn decompress( mut comp_body: &[u8], compression: Compression, ) -> Result, FrameBodyExtensionsParseError> { diff --git a/scylla-cql/src/frame/request/startup.rs b/scylla-cql/src/frame/request/startup.rs index cab84dc398..b702a8a25e 100644 --- a/scylla-cql/src/frame/request/startup.rs +++ b/scylla-cql/src/frame/request/startup.rs @@ -9,6 +9,8 @@ use crate::{ frame::types, }; +use super::DeserializableRequest; + pub struct Startup<'a> { pub options: HashMap, Cow<'a, str>>, } @@ -31,3 +33,15 @@ pub enum StartupSerializationError { #[error("Malformed startup options: {0}")] OptionsSerialization(TryFromIntError), } + +impl DeserializableRequest for Startup<'_> { + fn deserialize(buf: &mut &[u8]) -> Result { + // Note: this is inefficient, but it's only used for tests and it's not common + // to deserialize STARTUP frames anyway. + let options = types::read_string_map(buf)? + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(); + Ok(Self { options }) + } +} diff --git a/scylla-proxy/src/errors.rs b/scylla-proxy/src/errors.rs index fa1cc47d83..46bab27ee6 100644 --- a/scylla-proxy/src/errors.rs +++ b/scylla-proxy/src/errors.rs @@ -1,8 +1,18 @@ use std::net::SocketAddr; -use scylla_cql::frame::frame_errors::{FrameHeaderParseError, LowLevelDeserializationError}; +use scylla_cql::frame::frame_errors::{ + FrameBodyExtensionsParseError, FrameHeaderParseError, LowLevelDeserializationError, +}; use thiserror::Error; +#[derive(Debug, Error)] +pub enum ReadFrameError { + #[error("Failed to read frame header: {0}")] + Header(#[from] FrameHeaderParseError), + #[error("Failed to decompress frame: {0}")] + Compression(#[from] FrameBodyExtensionsParseError), +} + #[derive(Debug, Error)] pub enum DoorkeeperError { #[error("Listen on {0} failed with {1}")] @@ -20,7 +30,7 @@ pub enum DoorkeeperError { #[error("Could not send Options frame for obtaining shards number: {0}")] ObtainingShardNumber(std::io::Error), #[error("Could not send read Supported frame for obtaining shards number: {0}")] - ObtainingShardNumberFrame(FrameHeaderParseError), + ObtainingShardNumberFrame(ReadFrameError), #[error("Could not read Supported options: {0}")] ObtainingShardNumberParseOptions(LowLevelDeserializationError), #[error("ShardInfo parameters missing")] diff --git a/scylla-proxy/src/frame.rs b/scylla-proxy/src/frame.rs index 435a164863..88f782ee25 100644 --- a/scylla-proxy/src/frame.rs +++ b/scylla-proxy/src/frame.rs @@ -11,6 +11,9 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::warn; +use crate::errors::ReadFrameError; +use crate::proxy::CompressionReader; + const HEADER_SIZE: usize = 9; // Parts of the frame header which are not determined by the request/response type. @@ -22,13 +25,13 @@ pub struct FrameParams { } impl FrameParams { - pub fn for_request(&self) -> FrameParams { + pub const fn for_request(&self) -> FrameParams { Self { version: self.version & 0x7F, ..*self } } - pub fn for_response(&self) -> FrameParams { + pub const fn for_response(&self) -> FrameParams { Self { version: 0x80 | (self.version & 0x7F), ..*self @@ -48,7 +51,7 @@ pub(crate) enum FrameOpcode { Response(ResponseOpcode), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct RequestFrame { pub params: FrameParams, pub opcode: RequestOpcode, @@ -56,15 +59,17 @@ pub struct RequestFrame { } impl RequestFrame { - pub async fn write( + pub(crate) async fn write( &self, writer: &mut (impl AsyncWrite + Unpin), + compression: &CompressionReader, ) -> Result<(), tokio::io::Error> { write_frame( self.params, FrameOpcode::Request(self.opcode), &self.body, writer, + compression, ) .await } @@ -73,7 +78,7 @@ impl RequestFrame { Request::deserialize(&mut &self.body[..], self.opcode) } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct ResponseFrame { pub params: FrameParams, pub opcode: ResponseOpcode, @@ -133,12 +138,14 @@ impl ResponseFrame { pub(crate) async fn write( &self, writer: &mut (impl AsyncWrite + Unpin), + compression: &CompressionReader, ) -> Result<(), tokio::io::Error> { write_frame( self.params, FrameOpcode::Response(self.opcode), &self.body, writer, + compression, ) .await } @@ -230,9 +237,16 @@ fn serialize_error_specific_fields( pub(crate) async fn write_frame( params: FrameParams, opcode: FrameOpcode, - body: &Bytes, + body: &[u8], writer: &mut (impl AsyncWrite + Unpin), + compression: &CompressionReader, ) -> Result<(), tokio::io::Error> { + let compressed_body = compression + .maybe_compress_body(params.flags, body) + .map_err(|e| tokio::io::Error::new(std::io::ErrorKind::Other, e))?; + + let body = compressed_body.as_deref().unwrap_or(body); + let mut header = [0; HEADER_SIZE]; header[0] = params.version; @@ -253,7 +267,8 @@ pub(crate) async fn write_frame( pub(crate) async fn read_frame( reader: &mut (impl AsyncRead + Unpin), frame_type: FrameType, -) -> Result<(FrameParams, FrameOpcode, Bytes), FrameHeaderParseError> { + compression: &CompressionReader, +) -> Result<(FrameParams, FrameOpcode, Bytes), ReadFrameError> { let mut raw_header = [0u8; HEADER_SIZE]; reader .read_exact(&mut raw_header[..]) @@ -269,7 +284,7 @@ pub(crate) async fn read_frame( FrameType::Response => (FrameHeaderParseError::FrameFromClient, 0x80, "response"), }; if version & 0x80 != valid_direction { - return Err(err); + return Err(err.into()); } let protocol_version = version & 0x7F; if protocol_version != 0x04 { @@ -311,20 +326,22 @@ pub(crate) async fn read_frame( .map_err(|err| FrameHeaderParseError::BodyChunkIoError(body.remaining_mut(), err))?; if n == 0 { // EOF, too early - return Err(FrameHeaderParseError::ConnectionClosed( - body.remaining_mut(), - length, - )); + return Err( + FrameHeaderParseError::ConnectionClosed(body.remaining_mut(), length).into(), + ); } } - Ok((frame_params, opcode, body.into_inner().into())) + let body = compression.maybe_decompress_body(flags, body.into_inner().into())?; + + Ok((frame_params, opcode, body)) } pub(crate) async fn read_request_frame( reader: &mut (impl AsyncRead + Unpin), -) -> Result { - read_frame(reader, FrameType::Request) + compression: &CompressionReader, +) -> Result { + read_frame(reader, FrameType::Request, compression) .await .map(|(params, opcode, body)| RequestFrame { params, @@ -338,8 +355,9 @@ pub(crate) async fn read_request_frame( pub(crate) async fn read_response_frame( reader: &mut (impl AsyncRead + Unpin), -) -> Result { - read_frame(reader, FrameType::Response) + compression: &CompressionReader, +) -> Result { + read_frame(reader, FrameType::Response, compression) .await .map(|(params, opcode, body)| ResponseFrame { params, diff --git a/scylla-proxy/src/proxy.rs b/scylla-proxy/src/proxy.rs index 3f9788ea5c..d3f3969769 100644 --- a/scylla-proxy/src/proxy.rs +++ b/scylla-proxy/src/proxy.rs @@ -5,6 +5,7 @@ use crate::frame::{ }; use crate::{RequestOpcode, TargetShard}; use bytes::Bytes; +use compression::no_compression; use scylla_cql::frame::types::read_string_multimap; use std::collections::HashMap; use std::fmt::Display; @@ -633,12 +634,25 @@ impl Doorkeeper { let (tx_driver, rx_driver) = mpsc::unbounded_channel::(); let event_register_flag = Arc::new(AtomicBool::new(false)); - tokio::task::spawn(new_worker().receiver_from_driver(driver_read, tx_request)); + let ( + compression_writer_request_processor, + compression_reader_receiver_from_driver, + compression_reader_receiver_from_cluster, + compression_reader_sender_to_driver, + compression_reader_sender_to_cluster, + ) = compression::make_compression_infra(); + + tokio::task::spawn(new_worker().receiver_from_driver( + driver_read, + tx_request, + compression_reader_receiver_from_driver, + )); tokio::task::spawn(new_worker().sender_to_driver( driver_write, rx_driver, connection_close_tx.subscribe(), self.terminate_signaler.subscribe(), + compression_reader_sender_to_driver, )); tokio::task::spawn(new_worker().request_processor( rx_request, @@ -648,6 +662,7 @@ impl Doorkeeper { self.node.request_rules().clone(), connection_close_tx.clone(), event_register_flag.clone(), + compression_writer_request_processor, )); if let InternalNode::Real { ref response_rules, .. @@ -659,8 +674,13 @@ impl Doorkeeper { rx_cluster, connection_close_tx.subscribe(), self.terminate_signaler.subscribe(), + compression_reader_sender_to_cluster, + )); + tokio::task::spawn(new_worker().receiver_from_cluster( + cluster_read, + tx_response, + compression_reader_receiver_from_cluster, )); - tokio::task::spawn(new_worker().receiver_from_cluster(cluster_read, tx_response)); tokio::task::spawn(new_worker().response_processor( rx_response, tx_driver, @@ -796,11 +816,12 @@ impl Doorkeeper { FrameOpcode::Request(RequestOpcode::Options), &Bytes::new(), connection, + &no_compression(), ) .await .map_err(DoorkeeperError::ObtainingShardNumber)?; - let supported_frame = read_response_frame(connection) + let supported_frame = read_response_frame(connection, &compression::no_compression()) .await .map_err(DoorkeeperError::ObtainingShardNumberFrame)?; @@ -848,6 +869,153 @@ impl Doorkeeper { } } +mod compression { + use std::error::Error; + use std::sync::{Arc, OnceLock}; + + use bytes::Bytes; + use scylla_cql::frame::frame_errors::{ + CqlRequestSerializationError, FrameBodyExtensionsParseError, + }; + use scylla_cql::frame::request::{ + DeserializableRequest as _, RequestDeserializationError, Startup, + }; + use scylla_cql::frame::{compress_append, decompress, Compression, FLAG_COMPRESSION}; + use tracing::{error, warn}; + + #[derive(Debug, thiserror::Error)] + pub(crate) enum CompressionError { + /// Body Snap compression failed. + #[error("Snap compression error: {0}")] + SnapCompressError(Arc), + + /// Frame is to be compressed, but no compression was negotiated for the connection. + #[error("Frame is to be compressed, but no compression negotiated for connection.")] + NoCompressionNegotiated, + } + + type CompressionInfo = Arc>>; + + /// The write end of compression config for a connection. + /// + /// Used by the request processor upon STARTUP frame captured + /// and compression setting retrieved from it. + #[derive(Debug, Clone)] + pub(crate) struct CompressionWriter(CompressionInfo); + impl CompressionWriter { + pub(crate) fn set( + &self, + compression: Option, + ) -> Result<(), Option> { + self.0.set(compression) + } + + pub(crate) fn set_from_startup( + &self, + mut body: &[u8], + ) -> Result, RequestDeserializationError> { + let startup = Startup::deserialize(&mut body)?; + let maybe_compression = startup + .options + .get(scylla_cql::frame::request::options::COMPRESSION); + let maybe_compression = maybe_compression.and_then(|compression| { + compression + .parse::() + .inspect_err(|err| error!("STARTUP compression error: {}", err)) + .ok() + }); + let _ = self.set(maybe_compression).inspect_err(|_| { + warn!("Captured second or further STARTUP frame on the same connection") + }); + + Ok(maybe_compression) + } + } + + /// The read end of compression config for a connection. + /// + /// Used by frame (de)serializers. + #[derive(Debug, Clone)] + pub(crate) struct CompressionReader(CompressionInfo); + impl CompressionReader { + /// Return the compression negotiated for the connection. + /// + /// Outer Option signifies whether the negotiation took places, + /// inner Option is the compression (or lack of it) negotiated. + pub(crate) fn get(&self) -> Option> { + self.0.get().copied() + } + + pub(crate) fn maybe_compress_body( + &self, + flags: u8, + body: &[u8], + ) -> Result, CompressionError> { + match (flags & FLAG_COMPRESSION != 0, self.get().flatten()) { + (true, Some(compression)) => { + let mut buf = Vec::new(); + compress_append(body, compression, &mut buf).map_err(|err| { + let CqlRequestSerializationError::SnapCompressError(err) = err else {unreachable!("BUG: compress_append returned variant different than SnapCompressError")}; + CompressionError::SnapCompressError(err) + })?; + Ok(Some(Bytes::from(buf))) + } + (true, None) => Err(CompressionError::NoCompressionNegotiated), + (false, _) => Ok(None), + } + } + + pub(crate) fn maybe_decompress_body( + &self, + flags: u8, + body: Bytes, + ) -> Result { + match (flags & FLAG_COMPRESSION != 0, self.get().flatten()) { + (true, Some(compression)) => decompress(&body, compression).map(Into::into), + (true, None) => Err(FrameBodyExtensionsParseError::NoCompressionNegotiated), + (false, _) => Ok(body), + } + } + } + + pub(crate) fn make_compression_infra() -> ( + CompressionWriter, + CompressionReader, + CompressionReader, + CompressionReader, + CompressionReader, + ) { + let info = Arc::new(OnceLock::new()); + ( + CompressionWriter(info.clone()), + CompressionReader(info.clone()), + CompressionReader(info.clone()), + CompressionReader(info.clone()), + CompressionReader(info), + ) + } + + fn mock_compression_reader(compression: Option) -> CompressionReader { + CompressionReader(Arc::new({ + let once = OnceLock::new(); + once.set(compression).unwrap(); + once + })) + } + + // Compression explicitly turned off. + pub(crate) fn no_compression() -> CompressionReader { + mock_compression_reader(None) + } + + // Compression explicitly turned on. + #[cfg(test)] // Currently only used for tests. + pub(crate) fn with_compression(compression: Compression) -> CompressionReader { + mock_compression_reader(Some(compression)) + } +} +pub(crate) use compression::{CompressionReader, CompressionWriter}; + struct ProxyWorker { terminate_notifier: TerminateNotifier, finish_guard: FinishGuard, @@ -896,13 +1064,14 @@ impl ProxyWorker { self, mut read_half: (impl AsyncRead + Unpin), request_processor_tx: mpsc::UnboundedSender, + compression: CompressionReader, ) { let shard = self.shard; self.run_until_interrupted( "receiver_from_driver", |driver_addr, proxy_addr, _real_addr| async move { loop { - let frame = frame::read_request_frame(&mut read_half) + let frame = frame::read_request_frame(&mut read_half, &compression) .await .map_err(|err| { warn!("Request reception from {} error: {}", driver_addr, err); @@ -930,6 +1099,7 @@ impl ProxyWorker { self, mut read_half: (impl AsyncRead + Unpin), response_processor_tx: mpsc::UnboundedSender, + compression: CompressionReader, ) { let shard = self.shard; self.run_until_interrupted( @@ -937,19 +1107,18 @@ impl ProxyWorker { |driver_addr, _proxy_addr, real_addr| async move { let real_addr = real_addr.expect("BUG: no real_addr in cluster worker"); loop { - let frame = - frame::read_response_frame(&mut read_half) - .await - .map_err(|err| { - warn!("Response reception from {} error: {}", real_addr, err); - WorkerError::NodeDisconnected(real_addr) - })?; + let frame = frame::read_response_frame(&mut read_half, &compression) + .await + .map_err(|err| { + warn!("Response reception from {} error: {}", real_addr, err); + WorkerError::NodeDisconnected(real_addr) + })?; debug!( - "Intercepted Cluster ({}) -> Driver ({}) ({}) frame. opcode: {:?}.", + "Intercepted Cluster ({}) ({}) -> Driver ({}) frame. opcode: {:?}.", real_addr, - driver_addr, DisplayableShard(shard), + driver_addr, &frame.opcode ); @@ -969,6 +1138,7 @@ impl ProxyWorker { mut responses_rx: mpsc::UnboundedReceiver, mut connection_close_notifier: ConnectionCloseNotifier, mut terminate_notifier: TerminateNotifier, + compression: CompressionReader, ) { let shard = self.shard; self.run_until_interrupted( @@ -988,13 +1158,13 @@ impl ProxyWorker { }; debug!( - "Sending Proxy ({}) -> Driver ({}) ({}) frame. opcode: {:?}.", + "Sending Proxy ({}) ({}) -> Driver ({}) frame. opcode: {:?}.", proxy_addr, - driver_addr, DisplayableShard(shard), + driver_addr, &response.opcode ); - if response.write(&mut write_half).await.is_err() { + if response.write(&mut write_half, &compression).await.is_err() { if terminate_notifier.try_recv().is_err() && connection_close_notifier.try_recv().is_err() { @@ -1015,6 +1185,7 @@ impl ProxyWorker { mut requests_rx: mpsc::UnboundedReceiver, mut connection_close_notifier: ConnectionCloseNotifier, mut terminate_notifier: TerminateNotifier, + compression: CompressionReader, ) { let shard = self.shard; self.run_until_interrupted( @@ -1042,7 +1213,7 @@ impl ProxyWorker { &request.opcode ); - if request.write(&mut write_half).await.is_err() { + if request.write(&mut write_half, &compression).await.is_err() { if terminate_notifier.try_recv().is_err() && connection_close_notifier.try_recv().is_err() { @@ -1067,6 +1238,7 @@ impl ProxyWorker { request_rules: Arc>>, connection_close_signaler: ConnectionCloseSignaler, event_registered_flag: Arc, + compression: CompressionWriter, ) { let shard = self.shard; self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move { @@ -1075,7 +1247,19 @@ impl ProxyWorker { Some(request) => { if request.opcode == RequestOpcode::Register { event_registered_flag.store(true, Ordering::Relaxed); + } else if request.opcode == RequestOpcode::Startup { + match compression.set_from_startup(&request.body) { + Err(err) => error!("Failed to deserialize STARTUP frame: {}", err), + Ok(read_compression) => info!( + "Intercepted STARTUP frame ({} -> {} ({})), so set compression accordingly to {:?}.", + driver_addr, + DisplayableRealAddrOption(real_addr), + DisplayableShard(shard), + read_compression + ) + }; } + let ctx = EvaluationContext { connection_seq_no: connection_no, opcode: FrameOpcode::Request(request.opcode), @@ -1293,8 +1477,11 @@ pub fn get_exclusive_local_address() -> IpAddr { #[cfg(test)] mod tests { + use super::compression::no_compression; use super::*; + use crate::errors::ReadFrameError; use crate::frame::{read_frame, read_request_frame, FrameType}; + use crate::proxy::compression::with_compression; use crate::{ setup_tracing, Condition, Reaction as _, RequestReaction, ResponseOpcode, ResponseReaction, }; @@ -1302,8 +1489,10 @@ mod tests { use bytes::{BufMut, BytesMut}; use futures::future::{join, join3}; use rand::RngCore; - use scylla_cql::frame::frame_errors::FrameHeaderParseError; + use scylla_cql::frame::request::options::COMPRESSION; + use scylla_cql::frame::request::{SerializableRequest as _, Startup}; use scylla_cql::frame::types::write_string_multimap; + use scylla_cql::frame::{Compression, FLAG_COMPRESSION}; use std::collections::HashMap; use std::mem; use std::str::FromStr; @@ -1321,12 +1510,13 @@ mod tests { async fn respond_with_supported( conn: &mut TcpStream, supported_options: &HashMap>, + compression: &CompressionReader, ) { let RequestFrame { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_request_frame(conn).await.unwrap(); + } = read_request_frame(conn, compression).await.unwrap(); assert_eq!(recvd_params, HARDCODED_OPTIONS_PARAMS); assert_eq!(recvd_opcode, RequestOpcode::Options); assert_eq!(recvd_body, Bytes::new()); // body should be empty @@ -1341,6 +1531,7 @@ mod tests { FrameOpcode::Response(ResponseOpcode::Supported), &body, conn, + &no_compression(), ) .await .unwrap(); @@ -1361,12 +1552,20 @@ mod tests { sharded_info } - async fn respond_with_shards_count(conn: &mut TcpStream, shards_count: u16) { - respond_with_supported(conn, &supported_shards_count(shards_count)).await; + async fn respond_with_shards_count( + conn: &mut TcpStream, + shards_count: u16, + compression: &CompressionReader, + ) { + respond_with_supported(conn, &supported_shards_count(shards_count), compression).await; } - async fn respond_with_shard_num(conn: &mut TcpStream, shard_num: TargetShard) { - respond_with_supported(conn, &supported_shard_number(shard_num)).await; + async fn respond_with_shard_num( + conn: &mut TcpStream, + shard_num: TargetShard, + compression: &CompressionReader, + ) { + respond_with_supported(conn, &supported_shard_number(shard_num), compression).await; } fn next_local_address_with_port(port: u16) -> SocketAddr { @@ -1399,24 +1598,32 @@ mod tests { let send_frame_to_shard = async { let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap(); - write_frame(params, opcode, &body, &mut conn).await.unwrap(); + write_frame(params, opcode, &body, &mut conn, &no_compression()) + .await + .unwrap(); conn }; let mock_node_action = async { if let ShardAwareness::QueryNode = shard_awareness { - respond_with_shards_count(&mut mock_node_listener.accept().await.unwrap().0, 1) - .await; + respond_with_shards_count( + &mut mock_node_listener.accept().await.unwrap().0, + 1, + &no_compression(), + ) + .await; } let (mut conn, _) = mock_node_listener.accept().await.unwrap(); if shard_awareness.is_aware() { - respond_with_shard_num(&mut conn, 1).await; + respond_with_shard_num(&mut conn, 1, &no_compression()).await; } let RequestFrame { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_request_frame(&mut conn).await.unwrap(); + } = read_request_frame(&mut conn, &no_compression()) + .await + .unwrap(); assert_eq!(recvd_params, params); assert_eq!(FrameOpcode::Request(recvd_opcode), opcode); assert_eq!(recvd_body, body); @@ -1527,6 +1734,7 @@ mod tests { respond_with_shards_count( &mut mock_node_listener.accept().await.unwrap().0, shards_num, + &no_compression(), ) .await; let (conn, remote_addr) = mock_node_listener.accept().await.unwrap(); @@ -1595,7 +1803,7 @@ mod tests { let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap(); let params1 = FrameParams { - flags: 3, + flags: 2, version: 0x42, stream: 42, }; @@ -1627,13 +1835,13 @@ mod tests { let send_frame_to_shard = async { let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap(); - write_frame(params1, opcode1, &body1, &mut conn) + write_frame(params1, opcode1, &body1, &mut conn, &no_compression()) .await .unwrap(); - write_frame(params2, opcode2, &body2, &mut conn) + write_frame(params2, opcode2, &body2, &mut conn, &no_compression()) .await .unwrap(); - write_frame(params3, opcode3, &body3, &mut conn) + write_frame(params3, opcode3, &body3, &mut conn, &no_compression()) .await .unwrap(); @@ -1641,7 +1849,9 @@ mod tests { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_response_frame(&mut conn).await.unwrap(); + } = read_response_frame(&mut conn, &no_compression()) + .await + .unwrap(); assert_eq!(recvd_params, params1.for_response()); assert_eq!(recvd_opcode, ResponseOpcode::Ready); assert_eq!(recvd_body, Bytes::new()); @@ -1650,7 +1860,9 @@ mod tests { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_response_frame(&mut conn).await.unwrap(); + } = read_response_frame(&mut conn, &no_compression()) + .await + .unwrap(); assert_eq!(recvd_params, params2.for_response()); assert_eq!(recvd_opcode, ResponseOpcode::Event); assert_eq!(recvd_body, Bytes::from_static(test_msg)); @@ -1664,7 +1876,9 @@ mod tests { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_request_frame(&mut conn).await.unwrap(); + } = read_request_frame(&mut conn, &no_compression()) + .await + .unwrap(); assert_eq!(recvd_params, params3); assert_eq!(FrameOpcode::Request(recvd_opcode), opcode3); assert_eq!(recvd_body, body3); @@ -1721,10 +1935,10 @@ mod tests { params: FrameParams, opcode: FrameOpcode, body: &Bytes, - ) -> Result { + ) -> Result { let (send_res, recv_res) = join( - write_frame(params, opcode, &body.clone(), driver), - read_request_frame(node), + write_frame(params, opcode, &body.clone(), driver, &no_compression()), + read_request_frame(node, &no_compression()), ) .await; send_res.unwrap(); @@ -1836,10 +2050,10 @@ mod tests { params: FrameParams, opcode: FrameOpcode, body: &Bytes, - ) -> Result { + ) -> Result { let (send_res, recv_res) = join( - write_frame(params, opcode, &body.clone(), driver), - read_request_frame(node), + write_frame(params, opcode, &body.clone(), driver, &no_compression()), + read_request_frame(node, &no_compression()), ) .await; send_res.unwrap(); @@ -1915,7 +2129,7 @@ mod tests { let send_frame_to_shard = async { let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap(); - write_frame(params, request_opcode, &body, &mut conn) + write_frame(params, request_opcode, &body, &mut conn, &no_compression()) .await .unwrap(); conn @@ -1923,9 +2137,15 @@ mod tests { let mock_node_action = async { let (mut conn, _) = mock_node_listener.accept().await.unwrap(); - write_frame(params.for_response(), response_opcode, &body, &mut conn) - .await - .unwrap(); + write_frame( + params.for_response(), + response_opcode, + &body, + &mut conn, + &no_compression(), + ) + .await + .unwrap(); conn }; @@ -2008,7 +2228,7 @@ mod tests { let node1_real_addr = next_local_address_with_port(9876); let node1_proxy_addr = next_local_address_with_port(9876); - let delay = Duration::from_millis(30); + let delay = Duration::from_millis(60); let proxy = Proxy::new([Node::new( node1_real_addr, @@ -2045,10 +2265,10 @@ mod tests { let send_frame_to_shard = async { let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap(); - write_frame(params1, opcode1, &body1, &mut conn) + write_frame(params1, opcode1, &body1, &mut conn, &no_compression()) .await .unwrap(); - write_frame(params2, opcode2, &body2, &mut conn) + write_frame(params2, opcode2, &body2, &mut conn, &no_compression()) .await .unwrap(); conn @@ -2060,7 +2280,9 @@ mod tests { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_request_frame(&mut conn).await.unwrap(); + } = read_request_frame(&mut conn, &no_compression()) + .await + .unwrap(); assert_eq!(recvd_params, params2); assert_eq!(FrameOpcode::Request(recvd_opcode), opcode2); assert_eq!(recvd_body, body2); @@ -2094,7 +2316,9 @@ mod tests { let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap(); - write_frame(params, opcode, &body, &mut conn).await.unwrap(); + write_frame(params, opcode, &body, &mut conn, &no_compression()) + .await + .unwrap(); // We assert that after sufficiently long time, no error happens inside proxy. tokio::time::sleep(Duration::from_millis(3)).await; running_proxy.finish().await.unwrap(); @@ -2141,7 +2365,7 @@ mod tests { let running_proxy = proxy.run().await.unwrap(); let params1 = FrameParams { - flags: 3, + flags: 2, version: 0x42, stream: 42, }; @@ -2172,13 +2396,13 @@ mod tests { let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap(); - write_frame(params1, opcode1, &body1, &mut conn) + write_frame(params1, opcode1, &body1, &mut conn, &no_compression()) .await .unwrap(); - write_frame(params2, opcode2, &body2, &mut conn) + write_frame(params2, opcode2, &body2, &mut conn, &no_compression()) .await .unwrap(); - write_frame(params3, opcode3, &body3, &mut conn) + write_frame(params3, opcode3, &body3, &mut conn, &no_compression()) .await .unwrap(); @@ -2186,7 +2410,9 @@ mod tests { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_response_frame(&mut conn).await.unwrap(); + } = read_response_frame(&mut conn, &no_compression()) + .await + .unwrap(); assert_eq!(recvd_params, params1.for_response()); assert_eq!(recvd_opcode, ResponseOpcode::Ready); assert_eq!(recvd_body, Bytes::new()); @@ -2195,7 +2421,9 @@ mod tests { params: recvd_params, opcode: recvd_opcode, body: recvd_body, - } = read_response_frame(&mut conn).await.unwrap(); + } = read_response_frame(&mut conn, &no_compression()) + .await + .unwrap(); assert_eq!(recvd_params, params2.for_response()); assert_eq!(recvd_opcode, ResponseOpcode::Event); assert_eq!(recvd_body, Bytes::from_static(test_msg)); @@ -2281,9 +2509,15 @@ mod tests { let socket = bind_socket_for_shard(shards_count, driver_shard).await; let mut conn = socket.connect(node_proxy_addr).await.unwrap(); - write_frame(params, request_opcode, body_ref, &mut conn) - .await - .unwrap(); + write_frame( + params, + request_opcode, + body_ref, + &mut conn, + &no_compression(), + ) + .await + .unwrap(); conn }; @@ -2295,10 +2529,21 @@ mod tests { let mut conns_futs = (0..2) .map(|_| async { let (mut conn, driver_addr) = mock_node_listener.accept().await.unwrap(); - respond_with_shard_num(&mut conn, driver_addr.port() % shards_count).await; - write_frame(params.for_response(), response_opcode, body_ref, &mut conn) - .await - .unwrap(); + respond_with_shard_num( + &mut conn, + driver_addr.port() % shards_count, + &no_compression(), + ) + .await; + write_frame( + params.for_response(), + response_opcode, + body_ref, + &mut conn, + &no_compression(), + ) + .await + .unwrap(); conn }) .collect::>(); @@ -2404,29 +2649,33 @@ mod tests { write_frame( params, FrameOpcode::Request(req_opcode), - &(body_base.to_string() + "|request|").into(), + (body_base.to_string() + "|request|").as_bytes(), client_socket_ref, + &no_compression(), ) .await .unwrap(); - let received_request = read_frame(server_socket_ref, FrameType::Request) - .await - .unwrap(); + let received_request = + read_frame(server_socket_ref, FrameType::Request, &no_compression()) + .await + .unwrap(); assert_eq!(received_request.1, FrameOpcode::Request(req_opcode)); write_frame( params.for_response(), FrameOpcode::Response(resp_opcode), - &(body_base.to_string() + "|response|").into(), + (body_base.to_string() + "|response|").as_bytes(), server_socket_ref, + &no_compression(), ) .await .unwrap(); - let received_response = read_frame(client_socket_ref, FrameType::Response) - .await - .unwrap(); + let received_response = + read_frame(client_socket_ref, FrameType::Response, &no_compression()) + .await + .unwrap(); assert_eq!(received_response.1, FrameOpcode::Response(resp_opcode)); } @@ -2476,4 +2725,193 @@ mod tests { let _ = request_feedback_rx.try_recv().unwrap_err(); let _ = response_feedback_rx.try_recv().unwrap_err(); } + + #[tokio::test] + #[ntest::timeout(1000)] + async fn proxy_compresses_and_decompresses_frames_iff_compression_negociated() { + setup_tracing(); + let node1_real_addr = next_local_address_with_port(9876); + let node1_proxy_addr = next_local_address_with_port(9876); + + let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel(); + let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel(); + let proxy = Proxy::builder() + .with_node( + Node::builder() + .real_address(node1_real_addr) + .proxy_address(node1_proxy_addr) + .shard_awareness(ShardAwareness::Unaware) + .request_rules(vec![RequestRule( + Condition::True, + RequestReaction::noop().with_feedback_when_performed(request_feedback_tx), + )]) + .response_rules(vec![ResponseRule( + Condition::True, + ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx), + )]) + .build(), + ) + .build(); + let running_proxy = proxy.run().await.unwrap(); + + let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap(); + + const PARAMS_REQUEST_NO_COMPRESSION: FrameParams = FrameParams { + flags: 0, + version: 0x04, + stream: 0, + }; + const PARAMS_REQUEST_COMPRESSION: FrameParams = FrameParams { + flags: FLAG_COMPRESSION, + ..PARAMS_REQUEST_NO_COMPRESSION + }; + const PARAMS_RESPONSE_NO_COMPRESSION: FrameParams = + PARAMS_REQUEST_NO_COMPRESSION.for_response(); + const PARAMS_RESPONSE_COMPRESSION: FrameParams = + PARAMS_REQUEST_NO_COMPRESSION.for_response(); + + let make_driver_conn = async { TcpStream::connect(node1_proxy_addr).await.unwrap() }; + let make_node_conn = async { mock_node_listener.accept().await.unwrap() }; + + let (mut driver_conn, (mut node_conn, _)) = join(make_driver_conn, make_node_conn).await; + + /* Outline of the test: + * 1. "driver" sends an, uncompressed, e.g., QUERY frame, feedback returns its uncompressed body, + * and "node" receives the uncompressed frame. + * 2. "node" responds with an uncompressed RESULT frame, feedback returns its uncompressed body, + * and "driver" receives the uncompressed frame. + * 3. "driver" sends an uncompressed STARTUP frame, feedback returns its uncompressed body, + * and "node" receives the uncompressed frame. + * 4. "driver" sends a compressed, e.g., QUERY frame, feedback returns its uncompressed body, + * and "node" receives the compressed frame. + * 5. "node" responds with a compressed RESULT frame, feedback returns its uncompressed body, + * and "driver" receives the compressed frame. + */ + + // 1. "driver" sends an, uncompressed, e.g., QUERY frame, feedback returns its uncompressed body, + // and "node" receives the uncompressed frame. + { + let sent_frame = RequestFrame { + params: PARAMS_REQUEST_NO_COMPRESSION, + opcode: RequestOpcode::Query, + body: random_body(), + }; + + sent_frame + .write(&mut driver_conn, &no_compression()) + .await + .unwrap(); + + let (captured_frame, _) = request_feedback_rx.recv().await.unwrap(); + assert_eq!(captured_frame, sent_frame); + + let received_frame = read_request_frame(&mut node_conn, &no_compression()) + .await + .unwrap(); + assert_eq!(received_frame, sent_frame); + } + + // 2. "node" responds with an uncompressed RESULT frame, feedback returns its uncompressed body, + // and "driver" receives the uncompressed frame. + { + let sent_frame = ResponseFrame { + params: PARAMS_RESPONSE_NO_COMPRESSION, + opcode: ResponseOpcode::Result, + body: random_body(), + }; + + sent_frame + .write(&mut node_conn, &no_compression()) + .await + .unwrap(); + + let (captured_frame, _) = response_feedback_rx.recv().await.unwrap(); + assert_eq!(captured_frame, sent_frame); + + let received_frame = read_response_frame(&mut driver_conn, &no_compression()) + .await + .unwrap(); + assert_eq!(received_frame, sent_frame); + } + + // 3. "driver" sends an uncompressed STARTUP frame, feedback returns its uncompressed body, + // and "node" receives the uncompressed frame. + { + let startup_body = Startup { + options: std::iter::once((COMPRESSION.into(), Compression::Lz4.as_str().into())) + .collect(), + } + .to_bytes() + .unwrap(); + + let sent_frame = RequestFrame { + params: PARAMS_REQUEST_NO_COMPRESSION, + opcode: RequestOpcode::Startup, + body: startup_body, + }; + + sent_frame + .write(&mut driver_conn, &no_compression()) + .await + .unwrap(); + + let (captured_frame, _) = request_feedback_rx.recv().await.unwrap(); + assert_eq!(captured_frame, sent_frame); + + let received_frame = read_request_frame(&mut node_conn, &no_compression()) + .await + .unwrap(); + assert_eq!(received_frame, sent_frame); + } + + // 4. "driver" sends a compressed, e.g., QUERY frame, feedback returns its uncompressed body, + // and "node" receives the compressed frame. + { + let sent_frame = RequestFrame { + params: PARAMS_REQUEST_COMPRESSION, + opcode: RequestOpcode::Query, + body: random_body(), + }; + + sent_frame + .write(&mut driver_conn, &with_compression(Compression::Lz4)) + .await + .unwrap(); + + let (captured_frame, _) = request_feedback_rx.recv().await.unwrap(); + assert_eq!(captured_frame, sent_frame); + + let received_frame = + read_request_frame(&mut node_conn, &with_compression(Compression::Lz4)) + .await + .unwrap(); + assert_eq!(received_frame, sent_frame); + } + + // 5. "node" responds with a compressed RESULT frame, feedback returns its uncompressed body, + // and "driver" receives the compressed frame. + { + let sent_frame = ResponseFrame { + params: PARAMS_RESPONSE_COMPRESSION, + opcode: ResponseOpcode::Result, + body: random_body(), + }; + + sent_frame + .write(&mut node_conn, &with_compression(Compression::Lz4)) + .await + .unwrap(); + + let (captured_frame, _) = response_feedback_rx.recv().await.unwrap(); + assert_eq!(captured_frame, sent_frame); + + let received_frame = + read_response_frame(&mut driver_conn, &with_compression(Compression::Lz4)) + .await + .unwrap(); + assert_eq!(received_frame, sent_frame); + } + + running_proxy.finish().await.unwrap(); + } }