@@ -816,7 +816,7 @@ impl Doorkeeper {
816
816
FrameOpcode :: Request ( RequestOpcode :: Options ) ,
817
817
& Bytes :: new ( ) ,
818
818
connection,
819
- & no_compression ( )
819
+ & no_compression ( ) ,
820
820
)
821
821
. await
822
822
. map_err ( DoorkeeperError :: ObtainingShardNumber ) ?;
@@ -870,16 +870,29 @@ impl Doorkeeper {
870
870
}
871
871
872
872
mod compression {
873
+ use std:: error:: Error ;
874
+ use std:: sync:: { Arc , OnceLock } ;
875
+
873
876
use bytes:: Bytes ;
877
+ use scylla_cql:: frame:: frame_errors:: {
878
+ CqlRequestSerializationError , FrameBodyExtensionsParseError ,
879
+ } ;
874
880
use scylla_cql:: frame:: request:: {
875
881
DeserializableRequest as _, RequestDeserializationError , Startup ,
876
882
} ;
877
- use scylla_cql:: frame:: {
878
- decompress, frame_errors:: FrameBodyExtensionsParseError , Compression , FLAG_COMPRESSION ,
879
- } ;
883
+ use scylla_cql:: frame:: { compress_append, decompress, Compression , FLAG_COMPRESSION } ;
880
884
use tracing:: { error, warn} ;
881
885
882
- use std:: sync:: { Arc , OnceLock } ;
886
+ #[ derive( Debug , thiserror:: Error ) ]
887
+ pub ( crate ) enum CompressionError {
888
+ /// Body Snap compression failed.
889
+ #[ error( "Snap compression error: {0}" ) ]
890
+ SnapCompressError ( Arc < dyn Error + Sync + Send > ) ,
891
+
892
+ /// Frame is to be compressed, but no compression was negotiated for the connection.
893
+ #[ error( "Frame is to be compressed, but no compression negotiated for connection." ) ]
894
+ NoCompressionNegotiated ,
895
+ }
883
896
884
897
type CompressionInfo = Arc < OnceLock < Option < Compression > > > ;
885
898
@@ -932,6 +945,26 @@ mod compression {
932
945
pub ( crate ) fn get ( & self ) -> Option < Option < Compression > > {
933
946
self . 0 . get ( ) . copied ( )
934
947
}
948
+
949
+ pub ( crate ) fn maybe_compress_body (
950
+ & self ,
951
+ flags : u8 ,
952
+ body : & [ u8 ] ,
953
+ ) -> Result < Option < Bytes > , CompressionError > {
954
+ match ( flags & FLAG_COMPRESSION != 0 , self . get ( ) . flatten ( ) ) {
955
+ ( true , Some ( compression) ) => {
956
+ let mut buf = Vec :: new ( ) ;
957
+ compress_append ( body, compression, & mut buf) . map_err ( |err| {
958
+ let CqlRequestSerializationError :: SnapCompressError ( err) = err else { unreachable ! ( "BUG: compress_append returned variant different than SnapCompressError" ) } ;
959
+ CompressionError :: SnapCompressError ( err)
960
+ } ) ?;
961
+ Ok ( Some ( Bytes :: from ( buf) ) )
962
+ }
963
+ ( true , None ) => Err ( CompressionError :: NoCompressionNegotiated ) ,
964
+ ( false , _) => Ok ( None ) ,
965
+ }
966
+ }
967
+
935
968
pub ( crate ) fn maybe_decompress_body (
936
969
& self ,
937
970
flags : u8 ,
@@ -1555,7 +1588,9 @@ mod tests {
1555
1588
let send_frame_to_shard = async {
1556
1589
let mut conn = TcpStream :: connect ( node1_proxy_addr) . await . unwrap ( ) ;
1557
1590
1558
- write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) ) . await . unwrap ( ) ;
1591
+ write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) )
1592
+ . await
1593
+ . unwrap ( ) ;
1559
1594
conn
1560
1595
} ;
1561
1596
@@ -2092,9 +2127,15 @@ mod tests {
2092
2127
2093
2128
let mock_node_action = async {
2094
2129
let ( mut conn, _) = mock_node_listener. accept ( ) . await . unwrap ( ) ;
2095
- write_frame ( params. for_response ( ) , response_opcode, & body, & mut conn, & no_compression ( ) )
2096
- . await
2097
- . unwrap ( ) ;
2130
+ write_frame (
2131
+ params. for_response ( ) ,
2132
+ response_opcode,
2133
+ & body,
2134
+ & mut conn,
2135
+ & no_compression ( ) ,
2136
+ )
2137
+ . await
2138
+ . unwrap ( ) ;
2098
2139
conn
2099
2140
} ;
2100
2141
@@ -2265,7 +2306,9 @@ mod tests {
2265
2306
2266
2307
let mut conn = TcpStream :: connect ( node1_proxy_addr) . await . unwrap ( ) ;
2267
2308
2268
- write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) ) . await . unwrap ( ) ;
2309
+ write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) )
2310
+ . await
2311
+ . unwrap ( ) ;
2269
2312
// We assert that after sufficiently long time, no error happens inside proxy.
2270
2313
tokio:: time:: sleep ( Duration :: from_millis ( 3 ) ) . await ;
2271
2314
running_proxy. finish ( ) . await . unwrap ( ) ;
@@ -2456,9 +2499,15 @@ mod tests {
2456
2499
let socket = bind_socket_for_shard ( shards_count, driver_shard) . await ;
2457
2500
let mut conn = socket. connect ( node_proxy_addr) . await . unwrap ( ) ;
2458
2501
2459
- write_frame ( params, request_opcode, body_ref, & mut conn, & no_compression ( ) )
2460
- . await
2461
- . unwrap ( ) ;
2502
+ write_frame (
2503
+ params,
2504
+ request_opcode,
2505
+ body_ref,
2506
+ & mut conn,
2507
+ & no_compression ( ) ,
2508
+ )
2509
+ . await
2510
+ . unwrap ( ) ;
2462
2511
conn
2463
2512
} ;
2464
2513
@@ -2476,9 +2525,15 @@ mod tests {
2476
2525
& no_compression ( ) ,
2477
2526
)
2478
2527
. await ;
2479
- write_frame ( params. for_response ( ) , response_opcode, body_ref, & mut conn, & no_compression ( ) )
2480
- . await
2481
- . unwrap ( ) ;
2528
+ write_frame (
2529
+ params. for_response ( ) ,
2530
+ response_opcode,
2531
+ body_ref,
2532
+ & mut conn,
2533
+ & no_compression ( ) ,
2534
+ )
2535
+ . await
2536
+ . unwrap ( ) ;
2482
2537
conn
2483
2538
} )
2484
2539
. collect :: < Vec < _ > > ( ) ;
@@ -2586,7 +2641,7 @@ mod tests {
2586
2641
FrameOpcode :: Request ( req_opcode) ,
2587
2642
( body_base. to_string ( ) + "|request|" ) . as_bytes ( ) ,
2588
2643
client_socket_ref,
2589
- & no_compression ( )
2644
+ & no_compression ( ) ,
2590
2645
)
2591
2646
. await
2592
2647
. unwrap ( ) ;
@@ -2602,7 +2657,7 @@ mod tests {
2602
2657
FrameOpcode :: Response ( resp_opcode) ,
2603
2658
( body_base. to_string ( ) + "|response|" ) . as_bytes ( ) ,
2604
2659
server_socket_ref,
2605
- & no_compression ( )
2660
+ & no_compression ( ) ,
2606
2661
)
2607
2662
. await
2608
2663
. unwrap ( ) ;
0 commit comments