Skip to content

Commit 8379eef

Browse files
committed
proxy: compress outgoing frames if came compressed
This commit finally makes proxy compress frames, if they have the compression flag set AND the proxy already intercepted a STARTUP frame. This effectively makes the proxy re-compress frames that it decompressed when it intercepted them.
1 parent ec03223 commit 8379eef

File tree

2 files changed

+79
-18
lines changed

2 files changed

+79
-18
lines changed

scylla-proxy/src/frame.rs

+6
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ pub(crate) async fn write_frame(
241241
writer: &mut (impl AsyncWrite + Unpin),
242242
compression: &CompressionReader,
243243
) -> Result<(), tokio::io::Error> {
244+
let compressed_body = compression
245+
.maybe_compress_body(params.flags, body)
246+
.map_err(|e| tokio::io::Error::new(std::io::ErrorKind::Other, e))?;
247+
248+
let body = compressed_body.as_deref().unwrap_or(body);
249+
244250
let mut header = [0; HEADER_SIZE];
245251

246252
header[0] = params.version;

scylla-proxy/src/proxy.rs

+73-18
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ impl Doorkeeper {
816816
FrameOpcode::Request(RequestOpcode::Options),
817817
&Bytes::new(),
818818
connection,
819-
&no_compression()
819+
&no_compression(),
820820
)
821821
.await
822822
.map_err(DoorkeeperError::ObtainingShardNumber)?;
@@ -870,16 +870,29 @@ impl Doorkeeper {
870870
}
871871

872872
mod compression {
873+
use std::error::Error;
874+
use std::sync::{Arc, OnceLock};
875+
873876
use bytes::Bytes;
877+
use scylla_cql::frame::frame_errors::{
878+
CqlRequestSerializationError, FrameBodyExtensionsParseError,
879+
};
874880
use scylla_cql::frame::request::{
875881
DeserializableRequest as _, RequestDeserializationError, Startup,
876882
};
877-
use scylla_cql::frame::{
878-
decompress, frame_errors::FrameBodyExtensionsParseError, Compression, FLAG_COMPRESSION,
879-
};
883+
use scylla_cql::frame::{compress_append, decompress, Compression, FLAG_COMPRESSION};
880884
use tracing::{error, warn};
881885

882-
use std::sync::{Arc, OnceLock};
886+
#[derive(Debug, thiserror::Error)]
887+
pub(crate) enum CompressionError {
888+
/// Body Snap compression failed.
889+
#[error("Snap compression error: {0}")]
890+
SnapCompressError(Arc<dyn Error + Sync + Send>),
891+
892+
/// Frame is to be compressed, but no compression was negotiated for the connection.
893+
#[error("Frame is to be compressed, but no compression negotiated for connection.")]
894+
NoCompressionNegotiated,
895+
}
883896

884897
type CompressionInfo = Arc<OnceLock<Option<Compression>>>;
885898

@@ -932,6 +945,26 @@ mod compression {
932945
pub(crate) fn get(&self) -> Option<Option<Compression>> {
933946
self.0.get().copied()
934947
}
948+
949+
pub(crate) fn maybe_compress_body(
950+
&self,
951+
flags: u8,
952+
body: &[u8],
953+
) -> Result<Option<Bytes>, CompressionError> {
954+
match (flags & FLAG_COMPRESSION != 0, self.get().flatten()) {
955+
(true, Some(compression)) => {
956+
let mut buf = Vec::new();
957+
compress_append(body, compression, &mut buf).map_err(|err| {
958+
let CqlRequestSerializationError::SnapCompressError(err) = err else {unreachable!("BUG: compress_append returned variant different than SnapCompressError")};
959+
CompressionError::SnapCompressError(err)
960+
})?;
961+
Ok(Some(Bytes::from(buf)))
962+
}
963+
(true, None) => Err(CompressionError::NoCompressionNegotiated),
964+
(false, _) => Ok(None),
965+
}
966+
}
967+
935968
pub(crate) fn maybe_decompress_body(
936969
&self,
937970
flags: u8,
@@ -1555,7 +1588,9 @@ mod tests {
15551588
let send_frame_to_shard = async {
15561589
let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
15571590

1558-
write_frame(params, opcode, &body, &mut conn, &no_compression()).await.unwrap();
1591+
write_frame(params, opcode, &body, &mut conn, &no_compression())
1592+
.await
1593+
.unwrap();
15591594
conn
15601595
};
15611596

@@ -2092,9 +2127,15 @@ mod tests {
20922127

20932128
let mock_node_action = async {
20942129
let (mut conn, _) = mock_node_listener.accept().await.unwrap();
2095-
write_frame(params.for_response(), response_opcode, &body, &mut conn, &no_compression())
2096-
.await
2097-
.unwrap();
2130+
write_frame(
2131+
params.for_response(),
2132+
response_opcode,
2133+
&body,
2134+
&mut conn,
2135+
&no_compression(),
2136+
)
2137+
.await
2138+
.unwrap();
20982139
conn
20992140
};
21002141

@@ -2265,7 +2306,9 @@ mod tests {
22652306

22662307
let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
22672308

2268-
write_frame(params, opcode, &body, &mut conn, &no_compression()).await.unwrap();
2309+
write_frame(params, opcode, &body, &mut conn, &no_compression())
2310+
.await
2311+
.unwrap();
22692312
// We assert that after sufficiently long time, no error happens inside proxy.
22702313
tokio::time::sleep(Duration::from_millis(3)).await;
22712314
running_proxy.finish().await.unwrap();
@@ -2456,9 +2499,15 @@ mod tests {
24562499
let socket = bind_socket_for_shard(shards_count, driver_shard).await;
24572500
let mut conn = socket.connect(node_proxy_addr).await.unwrap();
24582501

2459-
write_frame(params, request_opcode, body_ref, &mut conn, &no_compression())
2460-
.await
2461-
.unwrap();
2502+
write_frame(
2503+
params,
2504+
request_opcode,
2505+
body_ref,
2506+
&mut conn,
2507+
&no_compression(),
2508+
)
2509+
.await
2510+
.unwrap();
24622511
conn
24632512
};
24642513

@@ -2476,9 +2525,15 @@ mod tests {
24762525
&no_compression(),
24772526
)
24782527
.await;
2479-
write_frame(params.for_response(), response_opcode, body_ref, &mut conn, &no_compression())
2480-
.await
2481-
.unwrap();
2528+
write_frame(
2529+
params.for_response(),
2530+
response_opcode,
2531+
body_ref,
2532+
&mut conn,
2533+
&no_compression(),
2534+
)
2535+
.await
2536+
.unwrap();
24822537
conn
24832538
})
24842539
.collect::<Vec<_>>();
@@ -2586,7 +2641,7 @@ mod tests {
25862641
FrameOpcode::Request(req_opcode),
25872642
(body_base.to_string() + "|request|").as_bytes(),
25882643
client_socket_ref,
2589-
&no_compression()
2644+
&no_compression(),
25902645
)
25912646
.await
25922647
.unwrap();
@@ -2602,7 +2657,7 @@ mod tests {
26022657
FrameOpcode::Response(resp_opcode),
26032658
(body_base.to_string() + "|response|").as_bytes(),
26042659
server_socket_ref,
2605-
&no_compression()
2660+
&no_compression(),
26062661
)
26072662
.await
26082663
.unwrap();

0 commit comments

Comments
 (0)