55
66#![ cfg( feature = "test-util" ) ]
77
8+ use aws_config:: retry:: RetryConfig ;
9+ use aws_sdk_s3:: config:: retry:: RetryPartition ;
810use aws_sdk_s3:: { config:: Region , Client , Config } ;
911use aws_smithy_async:: test_util:: ManualTimeSource ;
1012use aws_smithy_async:: time:: SharedTimeSource ;
1113use aws_smithy_http_client:: test_util:: { ReplayEvent , StaticReplayClient } ;
1214use aws_smithy_runtime:: client:: retries:: TokenBucket ;
1315use aws_smithy_runtime_api:: box_error:: BoxError ;
14- use aws_smithy_runtime_api:: client:: interceptors:: context:: BeforeTransmitInterceptorContextMut ;
1516use aws_smithy_runtime_api:: client:: interceptors:: Intercept ;
1617use aws_smithy_runtime_api:: client:: runtime_components:: RuntimeComponents ;
1718use aws_smithy_types:: body:: SdkBody ;
1819use aws_smithy_types:: config_bag:: ConfigBag ;
19- use std:: sync:: LazyLock ;
20+ use std:: sync:: Mutex ;
21+ use std:: sync:: { Arc , LazyLock } ;
2022use std:: time:: { Duration , SystemTime , UNIX_EPOCH } ;
2123
2224static THE_TIME : LazyLock < SystemTime > =
2325 LazyLock :: new ( || UNIX_EPOCH + Duration :: from_secs ( 12344321 ) ) ;
2426
2527#[ derive( Debug ) ]
26- struct TimeSourceValidationInterceptor ;
28+ struct TimeSourceValidationInterceptor {
29+ current_attempt : Arc < Mutex < u32 > > ,
30+ }
2731
2832impl Intercept for TimeSourceValidationInterceptor {
2933 fn name ( & self ) -> & ' static str {
3034 "TimeSourceValidationInterceptor"
3135 }
3236
33- fn modify_before_transmit (
37+ fn read_before_attempt (
3438 & self ,
35- _context : & mut BeforeTransmitInterceptorContextMut < ' _ > ,
39+ _context : & aws_sdk_s3 :: config :: interceptors :: BeforeTransmitInterceptorContextRef < ' _ > ,
3640 _runtime_components : & RuntimeComponents ,
3741 cfg : & mut ConfigBag ,
3842 ) -> Result < ( ) , BoxError > {
3943 if let Some ( token_bucket) = cfg. load :: < TokenBucket > ( ) {
40- let token_bucket_time_source = token_bucket. time_source ( ) ;
41- let token_time = token_bucket_time_source. now ( ) ;
44+ * self . current_attempt . lock ( ) . unwrap ( ) += 1 ;
4245
43- assert_eq ! (
44- * THE_TIME , token_time,
45- "Token source should match the configured time source"
46- ) ;
46+ if * self . current_attempt . lock ( ) . unwrap ( ) == 1 {
47+ let last_refill = token_bucket
48+ . last_refill_time_secs ( )
49+ . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
50+ assert_eq ! ( last_refill, 0 ) ;
51+ } else if * self . current_attempt . lock ( ) . unwrap ( ) == 2 {
52+ let last_refill = token_bucket
53+ . last_refill_time_secs ( )
54+ . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
55+ assert_eq ! ( last_refill, 12344321 ) ;
56+ } else {
57+ panic ! ( "No attempts past the second should happen" ) ;
58+ }
4759 }
4860 Ok ( ( ) )
4961 }
@@ -54,22 +66,42 @@ async fn test_token_bucket_gets_time_source_from_config() {
5466 let time_source = ManualTimeSource :: new ( * THE_TIME ) ;
5567 let shared_time_source = SharedTimeSource :: new ( time_source) ;
5668
57- let http_client = StaticReplayClient :: new ( vec ! [ ReplayEvent :: new(
58- http_1x:: Request :: builder( )
59- . uri( "https://www.doesntmatter.com" )
60- . body( SdkBody :: empty( ) )
61- . unwrap( ) ,
62- http_1x:: Response :: builder( )
63- . status( 200 )
64- . body( SdkBody :: from( "<ListBucketResult></ListBucketResult>" ) )
65- . unwrap( ) ,
66- ) ] ) ;
69+ let http_client = StaticReplayClient :: new ( vec ! [
70+ ReplayEvent :: new(
71+ http_1x:: Request :: builder( )
72+ . uri( "https://www.doesntmatter.com" )
73+ . body( SdkBody :: empty( ) )
74+ . unwrap( ) ,
75+ http_1x:: Response :: builder( )
76+ . status( 500 )
77+ . body( SdkBody :: from( "This was an error" ) )
78+ . unwrap( ) ,
79+ ) ,
80+ ReplayEvent :: new(
81+ http_1x:: Request :: builder( )
82+ . uri( "https://www.doesntmatter.com" )
83+ . body( SdkBody :: empty( ) )
84+ . unwrap( ) ,
85+ http_1x:: Response :: builder( )
86+ . status( 200 )
87+ . body( SdkBody :: from( "<ListBucketResult></ListBucketResult>" ) )
88+ . unwrap( ) ,
89+ ) ,
90+ ] ) ;
6791
6892 let config = Config :: builder ( )
6993 . region ( Region :: new ( "us-east-1" ) )
7094 . http_client ( http_client)
7195 . time_source ( shared_time_source)
72- . interceptor ( TimeSourceValidationInterceptor )
96+ . interceptor ( TimeSourceValidationInterceptor {
97+ current_attempt : Arc :: new ( Mutex :: new ( 0 ) ) ,
98+ } )
99+ . retry_config ( RetryConfig :: standard ( ) )
100+ . retry_partition (
101+ RetryPartition :: custom ( "test" )
102+ . token_bucket ( TokenBucket :: builder ( ) . refill_rate ( 100.0 ) . build ( ) )
103+ . build ( ) ,
104+ )
73105 . build ( ) ;
74106
75107 let client = Client :: from_conf ( config) ;
0 commit comments