Skip to content

Commit ff312b5

Browse files
committed
Updating integ test for TokenBucket time source
1 parent b43a32e commit ff312b5

2 files changed

Lines changed: 65 additions & 24 deletions

File tree

aws/sdk/integration-tests/s3/tests/token_bucket_time_source.rs

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,57 @@
55

66
#![cfg(feature = "test-util")]
77

8+
use aws_config::retry::RetryConfig;
9+
use aws_sdk_s3::config::retry::RetryPartition;
810
use aws_sdk_s3::{config::Region, Client, Config};
911
use aws_smithy_async::test_util::ManualTimeSource;
1012
use aws_smithy_async::time::SharedTimeSource;
1113
use aws_smithy_http_client::test_util::{ReplayEvent, StaticReplayClient};
1214
use aws_smithy_runtime::client::retries::TokenBucket;
1315
use aws_smithy_runtime_api::box_error::BoxError;
14-
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
1516
use aws_smithy_runtime_api::client::interceptors::Intercept;
1617
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
1718
use aws_smithy_types::body::SdkBody;
1819
use aws_smithy_types::config_bag::ConfigBag;
19-
use std::sync::LazyLock;
20+
use std::sync::Mutex;
21+
use std::sync::{Arc, LazyLock};
2022
use std::time::{Duration, SystemTime, UNIX_EPOCH};
2123

2224
static THE_TIME: LazyLock<SystemTime> =
2325
LazyLock::new(|| UNIX_EPOCH + Duration::from_secs(12344321));
2426

2527
#[derive(Debug)]
26-
struct TimeSourceValidationInterceptor;
28+
struct TimeSourceValidationInterceptor {
29+
current_attempt: Arc<Mutex<u32>>,
30+
}
2731

2832
impl Intercept for TimeSourceValidationInterceptor {
2933
fn name(&self) -> &'static str {
3034
"TimeSourceValidationInterceptor"
3135
}
3236

33-
fn modify_before_transmit(
37+
fn read_before_attempt(
3438
&self,
35-
_context: &mut BeforeTransmitInterceptorContextMut<'_>,
39+
_context: &aws_sdk_s3::config::interceptors::BeforeTransmitInterceptorContextRef<'_>,
3640
_runtime_components: &RuntimeComponents,
3741
cfg: &mut ConfigBag,
3842
) -> Result<(), BoxError> {
3943
if let Some(token_bucket) = cfg.load::<TokenBucket>() {
40-
let token_bucket_time_source = token_bucket.time_source();
41-
let token_time = token_bucket_time_source.now();
44+
*self.current_attempt.lock().unwrap() += 1;
4245

43-
assert_eq!(
44-
*THE_TIME, token_time,
45-
"Token source should match the configured time source"
46-
);
46+
if *self.current_attempt.lock().unwrap() == 1 {
47+
let last_refill = token_bucket
48+
.last_refill_time_secs()
49+
.load(std::sync::atomic::Ordering::Relaxed);
50+
assert_eq!(last_refill, 0);
51+
} else if *self.current_attempt.lock().unwrap() == 2 {
52+
let last_refill = token_bucket
53+
.last_refill_time_secs()
54+
.load(std::sync::atomic::Ordering::Relaxed);
55+
assert_eq!(last_refill, 12344321);
56+
} else {
57+
panic!("No attempts past the second should happen");
58+
}
4759
}
4860
Ok(())
4961
}
@@ -54,22 +66,42 @@ async fn test_token_bucket_gets_time_source_from_config() {
5466
let time_source = ManualTimeSource::new(*THE_TIME);
5567
let shared_time_source = SharedTimeSource::new(time_source);
5668

57-
let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
58-
http_1x::Request::builder()
59-
.uri("https://www.doesntmatter.com")
60-
.body(SdkBody::empty())
61-
.unwrap(),
62-
http_1x::Response::builder()
63-
.status(200)
64-
.body(SdkBody::from("<ListBucketResult></ListBucketResult>"))
65-
.unwrap(),
66-
)]);
69+
let http_client = StaticReplayClient::new(vec![
70+
ReplayEvent::new(
71+
http_1x::Request::builder()
72+
.uri("https://www.doesntmatter.com")
73+
.body(SdkBody::empty())
74+
.unwrap(),
75+
http_1x::Response::builder()
76+
.status(500)
77+
.body(SdkBody::from("This was an error"))
78+
.unwrap(),
79+
),
80+
ReplayEvent::new(
81+
http_1x::Request::builder()
82+
.uri("https://www.doesntmatter.com")
83+
.body(SdkBody::empty())
84+
.unwrap(),
85+
http_1x::Response::builder()
86+
.status(200)
87+
.body(SdkBody::from("<ListBucketResult></ListBucketResult>"))
88+
.unwrap(),
89+
),
90+
]);
6791

6892
let config = Config::builder()
6993
.region(Region::new("us-east-1"))
7094
.http_client(http_client)
7195
.time_source(shared_time_source)
72-
.interceptor(TimeSourceValidationInterceptor)
96+
.interceptor(TimeSourceValidationInterceptor {
97+
current_attempt: Arc::new(Mutex::new(0)),
98+
})
99+
.retry_config(RetryConfig::standard())
100+
.retry_partition(
101+
RetryPartition::custom("test")
102+
.token_bucket(TokenBucket::builder().refill_rate(100.0).build())
103+
.build(),
104+
)
73105
.build();
74106

75107
let client = Client::from_conf(config);

rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ impl TokenBucket {
186186
/// Refills tokens based on elapsed time since last refill.
187187
/// This method implements lazy evaluation - tokens are only calculated when accessed.
188188
/// Uses a single compare-and-swap to ensure only one thread processes each time window.
189+
#[inline]
189190
fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
190191
if self.refill_rate > 0.0 {
191192
let current_time_secs = time_source
@@ -274,6 +275,14 @@ impl TokenBucket {
274275
pub(crate) fn available_permits(&self) -> usize {
275276
self.semaphore.available_permits()
276277
}
278+
279+
/// Only used in tests
280+
#[allow(dead_code)]
281+
#[doc(hidden)]
282+
#[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
283+
pub fn last_refill_time_secs(&self) -> Arc<AtomicU64> {
284+
self.last_refill_time_secs.clone()
285+
}
277286
}
278287

279288
/// Builder for constructing a `TokenBucket`.
@@ -350,7 +359,7 @@ impl TokenBucketBuilder {
350359
mod tests {
351360

352361
use super::*;
353-
use aws_smithy_async::{test_util::ManualTimeSource, time::SharedTimeSource};
362+
use aws_smithy_async::test_util::ManualTimeSource;
354363
use std::{sync::LazyLock, time::UNIX_EPOCH};
355364

356365
static TIME_SOURCE: LazyLock<ManualTimeSource> =
@@ -907,7 +916,7 @@ mod tests {
907916

908917
// Advance time by 10 seconds
909918
time_source.advance(Duration::from_secs(10));
910-
let shared_time_source = SharedTimeSource::new(time_source);
919+
let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);
911920

912921
// Launch 100 threads that all try to refill simultaneously
913922
let barrier = Arc::new(Barrier::new(100));

0 commit comments

Comments
 (0)