Commit d3d4d35

Add MSK IAM auth support
1 parent a8c2739 commit d3d4d35

File tree

9 files changed: +211, −19 lines


Cargo.lock

Lines changed: 3 additions & 0 deletions
(Generated lockfile; diff not rendered.)

Cargo.toml

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ tokio-util = "0.7"
 reqwest = { version = "0.11", features = ["json"] }
 aws-config = { version = "1.6", features = ["behavior-version-latest"] }
 aws-sdk-s3 = { version = "1.85", features = ["behavior-version-latest"] }
+aws-sigv4 = "1.3"
+aws-credential-types = "1.2"
+url = "2.5"
 prometheus-client = "0.22"
 sentry = "0.36"
 jemallocator = "0.5"

misc/rds_iam_bootstrap

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@ DATABASE_PASSWORD="$(aws rds generate-db-auth-token \
 
 export DATABASE_URL="postgres://$DATABASE_USER:$DATABASE_PASSWORD@$DATABASE_HOST"
 
+unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY AWS_SESSION_TOKEN
+
 constellation-processors $@
 
 get_auth_token

src/aggregator/mod.rs

Lines changed: 7 additions & 6 deletions
@@ -10,8 +10,7 @@ use crate::epoch::EpochConfig;
 use crate::models::{DBConnectionType, DBPool, DBStorageConnections, PgStoreError};
 use crate::profiler::{Profiler, ProfilerStat};
 use crate::record_stream::{
-  get_data_channel_topic_from_env, KafkaRecordStream, KafkaRecordStreamConfig, RecordStream,
-  RecordStreamArc, RecordStreamError,
+  get_data_channel_topic_from_env, KafkaRecordStreamConfig, KafkaRecordStreamFactory, RecordStream, RecordStreamArc, RecordStreamError
 };
 use crate::star::AppSTARError;
 use crate::util::parse_env_var;

@@ -48,14 +47,15 @@ pub enum AggregatorError {
 }
 
 fn create_output_stream(
+  rec_stream_factory: &KafkaRecordStreamFactory,
   output_measurements_to_stdout: bool,
   channel_name: &str,
 ) -> Result<Option<RecordStreamArc>, AggregatorError> {
   let topic = get_data_channel_topic_from_env(true, channel_name);
   Ok(if output_measurements_to_stdout {
     None
   } else {
-    let out_stream = Arc::new(KafkaRecordStream::new(KafkaRecordStreamConfig {
+    let out_stream = Arc::new(rec_stream_factory.create_record_stream(KafkaRecordStreamConfig {
       enable_producer: true,
       enable_consumer: false,
       topic,

@@ -96,12 +96,13 @@ pub async fn start_aggregation(
 
   info!("Starting aggregation...");
 
-  let out_stream = create_output_stream(output_measurements_to_stdout, channel_name)?;
+  let rec_stream_factory = KafkaRecordStreamFactory::new();
+  let out_stream = create_output_stream(&rec_stream_factory, output_measurements_to_stdout, channel_name)?;
 
   let mut in_streams: Vec<RecordStreamArc> = Vec::new();
   let in_stream_topic = get_data_channel_topic_from_env(false, channel_name);
   for _ in 0..CONSUMER_COUNT {
-    in_streams.push(Arc::new(KafkaRecordStream::new(KafkaRecordStreamConfig {
+    in_streams.push(Arc::new(rec_stream_factory.create_record_stream(KafkaRecordStreamConfig {
       enable_producer: false,
       enable_consumer: true,
       topic: in_stream_topic.clone(),

@@ -208,7 +209,7 @@ pub async fn start_aggregation(
   // Delete pending/recovered messages from DB.
   info!("Checking/processing expired epochs");
   let profiler = Arc::new(Profiler::default());
-  let out_stream = create_output_stream(output_measurements_to_stdout, channel_name)?;
+  let out_stream = create_output_stream(&rec_stream_factory, output_measurements_to_stdout, channel_name)?;
   let db_conn = Arc::new(Mutex::new(db_pool.get().await?));
   process_expired_epochs(db_conn.clone(), &epoch_config, out_stream, profiler.clone()).await?;
   info!("Profiler summary:\n{}", profiler.summary().await);

src/lakesink.rs

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 use crate::lake::{DataLake, DataLakeError};
 use crate::prometheus::DataLakeMetrics;
 use crate::record_stream::{
-  DynRecordStream, KafkaRecordStream, KafkaRecordStreamConfig, RecordStream, RecordStreamError,
+  DynRecordStream, KafkaRecordStreamConfig, KafkaRecordStreamFactory, RecordStream, RecordStreamError
 };
 use crate::util::parse_env_var;
 use derive_more::{Display, Error, From};

@@ -53,7 +53,8 @@ pub async fn start_lakesink(
 ) -> Result<(), LakeSinkError> {
   let batch_size = parse_env_var::<usize>(BATCH_SIZE_ENV_KEY, BATCH_SIZE_DEFAULT);
 
-  let rec_stream = KafkaRecordStream::new(KafkaRecordStreamConfig {
+  let rec_stream_factory = KafkaRecordStreamFactory::new();
+  let rec_stream = rec_stream_factory.create_record_stream(KafkaRecordStreamConfig {
     enable_producer: false,
     enable_consumer: true,
     topic: stream_topic,

src/main.rs

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ mod epoch;
 mod lake;
 mod lakesink;
 mod models;
+mod msk_iam;
 mod profiler;
 mod prometheus;
 mod record_stream;

src/msk_iam.rs

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+use aws_config::SdkConfig;
+use aws_credential_types::provider::ProvideCredentials;
+use aws_sigv4::{http_request::{self, SignableBody, SignableRequest, SigningSettings}, sign::v4};
+use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
+use futures::executor;
+use time::{macros::format_description, Duration, OffsetDateTime};
+use std::{error::Error, time::SystemTime};
+use url::Url;
+
+const ACTION_TYPE: &str = "Action";
+const ACTION_NAME: &str = "kafka-cluster:Connect";
+const SIGNING_NAME: &str = "kafka-cluster";
+const USER_AGENT_KEY: &str = "User-Agent";
+const DATE_QUERY_KEY: &str = "X-Amz-Date";
+const EXPIRES_QUERY_KEY: &str = "X-Amz-Expires";
+const DEFAULT_EXPIRY_SECONDS: u64 = 900;
+
+const APP_NAME: &str = "constellation-processors";
+const USER_AGENT_VERSION: &str = "1.0";
+
+#[derive(Clone)]
+pub struct TokenInfo {
+  pub token: String,
+  pub expiration_time: OffsetDateTime,
+}
+pub struct MSKIAMAuthManager {
+  token_info: Option<TokenInfo>,
+}
+
+impl MSKIAMAuthManager {
+  pub fn new() -> Self {
+    Self {
+      token_info: None,
+    }
+  }
+
+  pub fn get_auth_token(&mut self) -> Result<TokenInfo, Box<dyn Error>> {
+    if let Some(token_info) = &self.token_info {
+      if token_info.expiration_time > OffsetDateTime::now_utc() {
+        return Ok(token_info.clone());
+      }
+    }
+
+    let token_info = executor::block_on(generate_auth_token_async())?;
+    self.token_info = Some(token_info.clone());
+    Ok(token_info)
+  }
+}
+
+async fn generate_auth_token_async() -> Result<TokenInfo, Box<dyn Error>> {
+  let config = aws_config::from_env()
+    .load()
+    .await;
+
+  let mut url = build_request_url(&config)?;
+
+  sign_request_url(&mut url, &config).await?;
+
+  let expiration_time = get_expiration_time(&url)?;
+
+  add_user_agent(&mut url);
+
+  let encoded = URL_SAFE_NO_PAD.encode(url.as_str().as_bytes());
+
+  Ok(TokenInfo {
+    token: encoded,
+    expiration_time,
+  })
+}
+
+fn build_request_url(config: &SdkConfig) -> Result<Url, Box<dyn Error>> {
+  let endpoint_url = format!("https://kafka.{}.amazonaws.com/", config.region().ok_or_else(|| "AWS region is not set")?.to_string());
+  let mut url = Url::parse(&endpoint_url)?;
+
+  {
+    let mut query_pairs = url.query_pairs_mut();
+    query_pairs.append_pair(ACTION_TYPE, ACTION_NAME);
+    query_pairs.append_pair(EXPIRES_QUERY_KEY, &DEFAULT_EXPIRY_SECONDS.to_string());
+  }
+
+  Ok(url)
+}
+
+async fn sign_request_url(
+  url: &mut Url,
+  config: &SdkConfig,
+) -> Result<(), Box<dyn Error>> {
+  let credentials_provider = config.credentials_provider().ok_or_else(|| "AWS credentials provider is not set")?;
+  let credentials = credentials_provider.provide_credentials().await?;
+
+  let signable_request = SignableRequest::new("GET", url.as_str(), std::iter::empty(), SignableBody::Bytes(&[]))?;
+
+  let identity = credentials.into();
+  let region = config.region().ok_or_else(|| "AWS region is not set")?.to_string();
+  let signing_params = v4::SigningParams::builder()
+    .identity(&identity)
+    .region(&region)
+    .name(SIGNING_NAME)
+    .time(SystemTime::now())
+    .settings(SigningSettings::default())
+    .build()?
+    .into();
+
+  let signing_output = http_request::sign(
+    signable_request,
+    &signing_params,
+  )?;
+
+  for (key, value) in signing_output.output().params() {
+    url.query_pairs_mut().append_pair(key, value);
+  }
+
+  Ok(())
+}
+
+fn get_expiration_time(url: &Url) -> Result<OffsetDateTime, Box<dyn Error>> {
+  let date_str = url.query_pairs()
+    .find_map(|(k, v)| if k == DATE_QUERY_KEY { Some(v.to_string()) } else { None })
+    .ok_or_else(|| "failed to find AWS signed date parameter")?;
+
+  let date_format_description = format_description!("[year][month][day]T[hour][minute][second]Z");
+  let date = OffsetDateTime::parse(&date_str, date_format_description)?;
+
+  let expiry_duration_seconds = url.query_pairs()
+    .find_map(|(k, v)| if k == EXPIRES_QUERY_KEY { Some(v.to_string()) } else { None })
+    .ok_or_else(|| "failed to find X-Amz-Expires parameter")?
+    .parse::<i64>()?;
+
+  let expiry_duration = Duration::seconds(expiry_duration_seconds);
+  let expiry_time = date + expiry_duration;
+
+  Ok(expiry_time)
+}
+
+fn add_user_agent(url: &mut Url) {
+  let user_agent = format!("{}/{}/{}", APP_NAME, USER_AGENT_VERSION, USER_AGENT_VERSION);
+  url.query_pairs_mut().append_pair(USER_AGENT_KEY, &user_agent);
+}
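Note: the token built by this module is a SigV4-presigned GET URL for `https://kafka.<region>.amazonaws.com/?Action=kafka-cluster:Connect`, base64url-encoded without padding, which appears to follow the scheme used by AWS's MSK IAM SASL signer libraries; `X-Amz-Expires` is set to 900 seconds, and the parsed `X-Amz-Date` plus that window drives the cache in `get_auth_token`. A minimal usage sketch of the new module, assuming a region and credentials are resolvable from the environment (IRSA, instance profile, or env vars); the calling function is illustrative, not part of the diff:

use std::error::Error;
use crate::msk_iam::MSKIAMAuthManager;

// Illustrative caller only; MSKIAMAuthManager comes from src/msk_iam.rs above.
fn fetch_token() -> Result<(), Box<dyn Error>> {
  let mut auth = MSKIAMAuthManager::new();
  let first = auth.get_auth_token()?;  // presigns, encodes, and caches a token
  let second = auth.get_auth_token()?; // served from the cache until expiration_time passes
  assert_eq!(first.token, second.token);
  println!("MSK IAM token valid until {}", first.expiration_time);
  Ok(())
}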

src/record_stream.rs

Lines changed: 50 additions & 8 deletions
@@ -2,7 +2,7 @@ use async_trait::async_trait;
 use derive_more::{Display, Error, From};
 use futures::future::try_join_all;
 use rand::{seq::SliceRandom, thread_rng};
-use rdkafka::client::ClientContext;
+use rdkafka::client::{ClientContext, OAuthToken};
 use rdkafka::config::ClientConfig;
 use rdkafka::consumer::{
   stream_consumer::StreamConsumer, CommitMode, Consumer, ConsumerContext, Rebalance,

@@ -14,20 +14,23 @@ use rdkafka::types::RDKafkaErrorCode;
 use rdkafka::TopicPartitionList;
 use std::collections::HashMap;
 use std::env;
-use std::sync::Arc;
+use std::error::Error as StdError;
+use std::sync::{Arc, Mutex as StdMutex};
 use std::time::Duration;
 use tokio::sync::mpsc::{error::SendError, unbounded_channel, UnboundedSender};
 use tokio::sync::{Mutex, RwLock};
 use tokio::task::{JoinError, JoinHandle};
 use tokio::time::sleep;
 
 use crate::channel::{get_data_channel_map_from_env, get_data_channel_value_from_env};
+use crate::msk_iam::MSKIAMAuthManager;
 use crate::util::parse_env_var;
 
 const KAFKA_ENC_TOPICS_ENV_KEY: &str = "KAFKA_ENCRYPTED_TOPICS";
 const KAFKA_OUT_TOPICS_ENV_KEY: &str = "KAFKA_OUTPUT_TOPICS";
 const DEFAULT_ENC_KAFKA_TOPICS: &str = "typical=p3a-star-enc";
 const DEFAULT_OUT_KAFKA_TOPICS: &str = "typical=p3a-star-out";
+const KAFKA_IAM_BROKERS_ENV_KEY: &str = "KAFKA_IAM_BROKERS";
 const KAFKA_BROKERS_ENV_KEY: &str = "KAFKA_BROKERS";
 const KAFKA_ENABLE_PLAINTEXT_ENV_KEY: &str = "KAFKA_ENABLE_PLAINTEXT";
 const KAFKA_PRODUCER_QUEUE_TASK_COUNT_ENV_KEY: &str = "KAFKA_PRODUCE_QUEUE_TASK_COUNT";

@@ -54,9 +57,24 @@ pub enum RecordStreamError {
   Join(JoinError),
 }
 
-struct KafkaContext;
+#[derive(Clone)]
+struct KafkaContext {
+  msk_iam_auth_manager: Arc<StdMutex<MSKIAMAuthManager>>,
+}
+
 
-impl ClientContext for KafkaContext {}
+impl ClientContext for KafkaContext {
+  const ENABLE_REFRESH_OAUTH_TOKEN: bool = true;
+
+  fn generate_oauth_token(&self, _oauthbearer_config: Option<&str>) -> Result<OAuthToken, Box<dyn StdError>> {
+    let token_info = self.msk_iam_auth_manager.lock().unwrap().get_auth_token()?;
+    Ok(OAuthToken {
+      token: token_info.token,
+      lifetime_ms: (token_info.expiration_time.unix_timestamp_nanos() / 1_000_000) as i64,
+      principal_name: String::new(),
+    })
+  }
+}
 
 impl ConsumerContext for KafkaContext {
   fn pre_rebalance(&self, rebalance: &Rebalance) {

@@ -115,6 +133,25 @@ pub struct KafkaRecordStreamConfig {
   pub use_output_group_id: bool,
 }
 
+pub struct KafkaRecordStreamFactory {
+  msk_iam_auth_manager: Arc<StdMutex<MSKIAMAuthManager>>,
+}
+
+impl KafkaRecordStreamFactory {
+  pub fn new() -> Self {
+    Self {
+      msk_iam_auth_manager: Arc::new(StdMutex::new(MSKIAMAuthManager::new())),
+    }
+  }
+
+  pub fn create_record_stream(&self, stream_config: KafkaRecordStreamConfig) -> KafkaRecordStream {
+    let context = KafkaContext {
+      msk_iam_auth_manager: self.msk_iam_auth_manager.clone(),
+    };
+    KafkaRecordStream::new(stream_config, context)
+  }
+}
+
 pub struct KafkaRecordStream {
   producer: Option<Arc<FutureProducer<KafkaContext>>>,
   consumer: Option<StreamConsumer<KafkaContext>>,

@@ -150,7 +187,7 @@ pub fn get_data_channel_topic_from_env(use_output_topic: bool, channel_name: &st
 }
 
 impl KafkaRecordStream {
-  pub fn new(stream_config: KafkaRecordStreamConfig) -> Self {
+  fn new(stream_config: KafkaRecordStreamConfig, context: KafkaContext) -> Self {
     let group_id = match stream_config.use_output_group_id {
       true => "star-agg-dec",
       false => "star-agg-enc",

@@ -163,7 +200,6 @@ impl KafkaRecordStream {
       producer_queues: RwLock::new(Vec::new()),
     };
     if stream_config.enable_producer {
-      let context = KafkaContext;
       let mut config = Self::new_client_config();
       let mut config_ref = &mut config;
       if stream_config.use_output_group_id {

@@ -175,13 +211,12 @@ impl KafkaRecordStream {
           .set("transaction.timeout.ms", "3600000")
           .set("request.timeout.ms", "900000")
           .set("socket.timeout.ms", "300000")
-          .create_with_context(context)
+          .create_with_context(context.clone())
           .unwrap(),
       ));
       info!("Producing to topic: {}", stream_config.topic);
     }
     if stream_config.enable_consumer {
-      let context = KafkaContext;
      let mut config = Self::new_client_config();
      result.consumer = Some(
        config

@@ -210,6 +245,13 @@ impl KafkaRecordStream {
   }
 
   fn new_client_config() -> ClientConfig {
+    if let Some(brokers) = env::var(KAFKA_IAM_BROKERS_ENV_KEY).ok() {
+      let mut result = ClientConfig::new();
+      result.set("bootstrap.servers", brokers);
+      result.set("security.protocol", "SASL_SSL");
+      result.set("sasl.mechanism", "OAUTHBEARER");
+      return result;
+    }
     let brokers = env::var(KAFKA_BROKERS_ENV_KEY)
       .unwrap_or_else(|_| panic!("{} env var must be defined", KAFKA_BROKERS_ENV_KEY));
     let mut result = ClientConfig::new();
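Note: broker selection stays environment-driven. When `KAFKA_IAM_BROKERS` is set, `new_client_config()` points `bootstrap.servers` at it and switches the client to `SASL_SSL` with the `OAUTHBEARER` mechanism, and librdkafka then calls `generate_oauth_token` on the shared `KafkaContext` to obtain and refresh the MSK IAM token; when it is unset, the existing `KAFKA_BROKERS` path is unchanged. A rough wiring sketch, where the broker address is a placeholder rather than a value from this change (MSK IAM listeners typically use port 9098):

use std::env;
use crate::record_stream::KafkaRecordStreamFactory;

// Illustrative only; normally KAFKA_IAM_BROKERS would be set in the deployment
// environment rather than in code.
fn configure_iam_kafka() -> KafkaRecordStreamFactory {
  // With this variable set, new_client_config() emits SASL_SSL + OAUTHBEARER
  // and librdkafka fetches tokens via KafkaContext::generate_oauth_token.
  env::set_var("KAFKA_IAM_BROKERS", "b-1.example-cluster.kafka.us-west-2.amazonaws.com:9098");
  KafkaRecordStreamFactory::new()
}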
