@@ -816,7 +816,7 @@ impl Doorkeeper {
816
816
FrameOpcode :: Request ( RequestOpcode :: Options ) ,
817
817
& Bytes :: new ( ) ,
818
818
connection,
819
- & no_compression ( )
819
+ & no_compression ( ) ,
820
820
)
821
821
. await
822
822
. map_err ( DoorkeeperError :: ObtainingShardNumber ) ?;
@@ -870,16 +870,29 @@ impl Doorkeeper {
870
870
}
871
871
872
872
mod compression {
873
+ use std:: error:: Error ;
874
+ use std:: sync:: { Arc , OnceLock } ;
875
+
873
876
use bytes:: Bytes ;
877
+ use scylla_cql:: frame:: frame_errors:: {
878
+ CqlRequestSerializationError , FrameBodyExtensionsParseError ,
879
+ } ;
874
880
use scylla_cql:: frame:: request:: {
875
881
DeserializableRequest as _, RequestDeserializationError , Startup ,
876
882
} ;
877
- use scylla_cql:: frame:: {
878
- decompress, frame_errors:: FrameBodyExtensionsParseError , Compression , FLAG_COMPRESSION ,
879
- } ;
883
+ use scylla_cql:: frame:: { compress_append, decompress, Compression , FLAG_COMPRESSION } ;
880
884
use tracing:: { error, warn} ;
881
885
882
- use std:: sync:: { Arc , OnceLock } ;
886
+ #[ derive( Debug , thiserror:: Error ) ]
887
+ pub ( crate ) enum CompressionError {
888
+ /// Body Snap compression failed.
889
+ #[ error( "Snap compression error: {0}" ) ]
890
+ SnapCompressError ( Arc < dyn Error + Sync + Send > ) ,
891
+
892
+ /// Frame is to be compressed, but no compression was negotiated for the connection.
893
+ #[ error( "Frame is to be compressed, but no compression negotiated for connection." ) ]
894
+ NoCompressionNegotiated ,
895
+ }
883
896
884
897
type CompressionInfo = Arc < OnceLock < Option < Compression > > > ;
885
898
@@ -932,6 +945,26 @@ mod compression {
932
945
pub ( crate ) fn get ( & self ) -> Option < Option < Compression > > {
933
946
self . 0 . get ( ) . copied ( )
934
947
}
948
+
949
+ pub ( crate ) fn maybe_compress_body (
950
+ & self ,
951
+ flags : u8 ,
952
+ body : & [ u8 ] ,
953
+ ) -> Result < Option < Bytes > , CompressionError > {
954
+ match ( flags & FLAG_COMPRESSION != 0 , self . get ( ) . flatten ( ) ) {
955
+ ( true , Some ( compression) ) => {
956
+ let mut buf = Vec :: new ( ) ;
957
+ compress_append ( body, compression, & mut buf) . map_err ( |err| {
958
+ let CqlRequestSerializationError :: SnapCompressError ( err) = err else { unreachable ! ( "BUG: compress_append returned variant different than SnapCompressError" ) } ;
959
+ CompressionError :: SnapCompressError ( err)
960
+ } ) ?;
961
+ Ok ( Some ( Bytes :: from ( buf) ) )
962
+ }
963
+ ( true , None ) => Err ( CompressionError :: NoCompressionNegotiated ) ,
964
+ ( false , _) => Ok ( None ) ,
965
+ }
966
+ }
967
+
935
968
pub ( crate ) fn maybe_decompress_body (
936
969
& self ,
937
970
flags : u8 ,
@@ -1556,7 +1589,9 @@ mod tests {
1556
1589
let send_frame_to_shard = async {
1557
1590
let mut conn = TcpStream :: connect ( node1_proxy_addr) . await . unwrap ( ) ;
1558
1591
1559
- write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) ) . await . unwrap ( ) ;
1592
+ write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) )
1593
+ . await
1594
+ . unwrap ( ) ;
1560
1595
conn
1561
1596
} ;
1562
1597
@@ -2093,9 +2128,15 @@ mod tests {
2093
2128
2094
2129
let mock_node_action = async {
2095
2130
let ( mut conn, _) = mock_node_listener. accept ( ) . await . unwrap ( ) ;
2096
- write_frame ( params. for_response ( ) , response_opcode, & body, & mut conn, & no_compression ( ) )
2097
- . await
2098
- . unwrap ( ) ;
2131
+ write_frame (
2132
+ params. for_response ( ) ,
2133
+ response_opcode,
2134
+ & body,
2135
+ & mut conn,
2136
+ & no_compression ( ) ,
2137
+ )
2138
+ . await
2139
+ . unwrap ( ) ;
2099
2140
conn
2100
2141
} ;
2101
2142
@@ -2266,7 +2307,9 @@ mod tests {
2266
2307
2267
2308
let mut conn = TcpStream :: connect ( node1_proxy_addr) . await . unwrap ( ) ;
2268
2309
2269
- write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) ) . await . unwrap ( ) ;
2310
+ write_frame ( params, opcode, & body, & mut conn, & no_compression ( ) )
2311
+ . await
2312
+ . unwrap ( ) ;
2270
2313
// We assert that after sufficiently long time, no error happens inside proxy.
2271
2314
tokio:: time:: sleep ( Duration :: from_millis ( 3 ) ) . await ;
2272
2315
running_proxy. finish ( ) . await . unwrap ( ) ;
@@ -2457,9 +2500,15 @@ mod tests {
2457
2500
let socket = bind_socket_for_shard ( shards_count, driver_shard) . await ;
2458
2501
let mut conn = socket. connect ( node_proxy_addr) . await . unwrap ( ) ;
2459
2502
2460
- write_frame ( params, request_opcode, body_ref, & mut conn, & no_compression ( ) )
2461
- . await
2462
- . unwrap ( ) ;
2503
+ write_frame (
2504
+ params,
2505
+ request_opcode,
2506
+ body_ref,
2507
+ & mut conn,
2508
+ & no_compression ( ) ,
2509
+ )
2510
+ . await
2511
+ . unwrap ( ) ;
2463
2512
conn
2464
2513
} ;
2465
2514
@@ -2477,9 +2526,15 @@ mod tests {
2477
2526
& no_compression ( ) ,
2478
2527
)
2479
2528
. await ;
2480
- write_frame ( params. for_response ( ) , response_opcode, body_ref, & mut conn, & no_compression ( ) )
2481
- . await
2482
- . unwrap ( ) ;
2529
+ write_frame (
2530
+ params. for_response ( ) ,
2531
+ response_opcode,
2532
+ body_ref,
2533
+ & mut conn,
2534
+ & no_compression ( ) ,
2535
+ )
2536
+ . await
2537
+ . unwrap ( ) ;
2483
2538
conn
2484
2539
} )
2485
2540
. collect :: < Vec < _ > > ( ) ;
@@ -2587,7 +2642,7 @@ mod tests {
2587
2642
FrameOpcode :: Request ( req_opcode) ,
2588
2643
( body_base. to_string ( ) + "|request|" ) . as_bytes ( ) ,
2589
2644
client_socket_ref,
2590
- & no_compression ( )
2645
+ & no_compression ( ) ,
2591
2646
)
2592
2647
. await
2593
2648
. unwrap ( ) ;
@@ -2603,7 +2658,7 @@ mod tests {
2603
2658
FrameOpcode :: Response ( resp_opcode) ,
2604
2659
( body_base. to_string ( ) + "|response|" ) . as_bytes ( ) ,
2605
2660
server_socket_ref,
2606
- & no_compression ( )
2661
+ & no_compression ( ) ,
2607
2662
)
2608
2663
. await
2609
2664
. unwrap ( ) ;
0 commit comments