Skip to content

Commit 11cca67

Browse files
committed
proxy: compress outgoing frames if came compressed
This commit finally makes proxy compress frames, if they have the compression flag set AND the proxy already intercepted a STARTUP frame. This effectively makes the proxy re-compress frames that it decompressed when it intercepted them.
1 parent 623b5f4 commit 11cca67

File tree

2 files changed

+79
-18
lines changed

2 files changed

+79
-18
lines changed

scylla-proxy/src/frame.rs

+6
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ pub(crate) async fn write_frame(
241241
writer: &mut (impl AsyncWrite + Unpin),
242242
compression: &CompressionReader,
243243
) -> Result<(), tokio::io::Error> {
244+
let compressed_body = compression
245+
.maybe_compress_body(params.flags, body)
246+
.map_err(|e| tokio::io::Error::new(std::io::ErrorKind::Other, e))?;
247+
248+
let body = compressed_body.as_deref().unwrap_or(body);
249+
244250
let mut header = [0; HEADER_SIZE];
245251

246252
header[0] = params.version;

scylla-proxy/src/proxy.rs

+73-18
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ impl Doorkeeper {
816816
FrameOpcode::Request(RequestOpcode::Options),
817817
&Bytes::new(),
818818
connection,
819-
&no_compression()
819+
&no_compression(),
820820
)
821821
.await
822822
.map_err(DoorkeeperError::ObtainingShardNumber)?;
@@ -870,16 +870,29 @@ impl Doorkeeper {
870870
}
871871

872872
mod compression {
873+
use std::error::Error;
874+
use std::sync::{Arc, OnceLock};
875+
873876
use bytes::Bytes;
877+
use scylla_cql::frame::frame_errors::{
878+
CqlRequestSerializationError, FrameBodyExtensionsParseError,
879+
};
874880
use scylla_cql::frame::request::{
875881
DeserializableRequest as _, RequestDeserializationError, Startup,
876882
};
877-
use scylla_cql::frame::{
878-
decompress, frame_errors::FrameBodyExtensionsParseError, Compression, FLAG_COMPRESSION,
879-
};
883+
use scylla_cql::frame::{compress_append, decompress, Compression, FLAG_COMPRESSION};
880884
use tracing::{error, warn};
881885

882-
use std::sync::{Arc, OnceLock};
886+
#[derive(Debug, thiserror::Error)]
887+
pub(crate) enum CompressionError {
888+
/// Body Snap compression failed.
889+
#[error("Snap compression error: {0}")]
890+
SnapCompressError(Arc<dyn Error + Sync + Send>),
891+
892+
/// Frame is to be compressed, but no compression was negotiated for the connection.
893+
#[error("Frame is to be compressed, but no compression negotiated for connection.")]
894+
NoCompressionNegotiated,
895+
}
883896

884897
type CompressionInfo = Arc<OnceLock<Option<Compression>>>;
885898

@@ -932,6 +945,26 @@ mod compression {
932945
pub(crate) fn get(&self) -> Option<Option<Compression>> {
933946
self.0.get().copied()
934947
}
948+
949+
pub(crate) fn maybe_compress_body(
950+
&self,
951+
flags: u8,
952+
body: &[u8],
953+
) -> Result<Option<Bytes>, CompressionError> {
954+
match (flags & FLAG_COMPRESSION != 0, self.get().flatten()) {
955+
(true, Some(compression)) => {
956+
let mut buf = Vec::new();
957+
compress_append(body, compression, &mut buf).map_err(|err| {
958+
let CqlRequestSerializationError::SnapCompressError(err) = err else {unreachable!("BUG: compress_append returned variant different than SnapCompressError")};
959+
CompressionError::SnapCompressError(err)
960+
})?;
961+
Ok(Some(Bytes::from(buf)))
962+
}
963+
(true, None) => Err(CompressionError::NoCompressionNegotiated),
964+
(false, _) => Ok(None),
965+
}
966+
}
967+
935968
pub(crate) fn maybe_decompress_body(
936969
&self,
937970
flags: u8,
@@ -1556,7 +1589,9 @@ mod tests {
15561589
let send_frame_to_shard = async {
15571590
let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
15581591

1559-
write_frame(params, opcode, &body, &mut conn, &no_compression()).await.unwrap();
1592+
write_frame(params, opcode, &body, &mut conn, &no_compression())
1593+
.await
1594+
.unwrap();
15601595
conn
15611596
};
15621597

@@ -2093,9 +2128,15 @@ mod tests {
20932128

20942129
let mock_node_action = async {
20952130
let (mut conn, _) = mock_node_listener.accept().await.unwrap();
2096-
write_frame(params.for_response(), response_opcode, &body, &mut conn, &no_compression())
2097-
.await
2098-
.unwrap();
2131+
write_frame(
2132+
params.for_response(),
2133+
response_opcode,
2134+
&body,
2135+
&mut conn,
2136+
&no_compression(),
2137+
)
2138+
.await
2139+
.unwrap();
20992140
conn
21002141
};
21012142

@@ -2266,7 +2307,9 @@ mod tests {
22662307

22672308
let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
22682309

2269-
write_frame(params, opcode, &body, &mut conn, &no_compression()).await.unwrap();
2310+
write_frame(params, opcode, &body, &mut conn, &no_compression())
2311+
.await
2312+
.unwrap();
22702313
// We assert that after sufficiently long time, no error happens inside proxy.
22712314
tokio::time::sleep(Duration::from_millis(3)).await;
22722315
running_proxy.finish().await.unwrap();
@@ -2457,9 +2500,15 @@ mod tests {
24572500
let socket = bind_socket_for_shard(shards_count, driver_shard).await;
24582501
let mut conn = socket.connect(node_proxy_addr).await.unwrap();
24592502

2460-
write_frame(params, request_opcode, body_ref, &mut conn, &no_compression())
2461-
.await
2462-
.unwrap();
2503+
write_frame(
2504+
params,
2505+
request_opcode,
2506+
body_ref,
2507+
&mut conn,
2508+
&no_compression(),
2509+
)
2510+
.await
2511+
.unwrap();
24632512
conn
24642513
};
24652514

@@ -2477,9 +2526,15 @@ mod tests {
24772526
&no_compression(),
24782527
)
24792528
.await;
2480-
write_frame(params.for_response(), response_opcode, body_ref, &mut conn, &no_compression())
2481-
.await
2482-
.unwrap();
2529+
write_frame(
2530+
params.for_response(),
2531+
response_opcode,
2532+
body_ref,
2533+
&mut conn,
2534+
&no_compression(),
2535+
)
2536+
.await
2537+
.unwrap();
24832538
conn
24842539
})
24852540
.collect::<Vec<_>>();
@@ -2587,7 +2642,7 @@ mod tests {
25872642
FrameOpcode::Request(req_opcode),
25882643
(body_base.to_string() + "|request|").as_bytes(),
25892644
client_socket_ref,
2590-
&no_compression()
2645+
&no_compression(),
25912646
)
25922647
.await
25932648
.unwrap();
@@ -2603,7 +2658,7 @@ mod tests {
26032658
FrameOpcode::Response(resp_opcode),
26042659
(body_base.to_string() + "|response|").as_bytes(),
26052660
server_socket_ref,
2606-
&no_compression()
2661+
&no_compression(),
26072662
)
26082663
.await
26092664
.unwrap();

0 commit comments

Comments
 (0)