diff --git a/crates/dekaf/src/api_client.rs b/crates/dekaf/src/api_client.rs
index 32463abfe8..86bcb81b4b 100644
--- a/crates/dekaf/src/api_client.rs
+++ b/crates/dekaf/src/api_client.rs
@@ -454,13 +454,13 @@ impl KafkaApiClient {
     #[instrument(skip_all)]
     pub async fn ensure_topics(
         &mut self,
-        topic_names: Vec<messages::TopicName>,
+        topics: Vec<(messages::TopicName, usize)>,
     ) -> anyhow::Result<()> {
         let req = messages::MetadataRequest::default()
             .with_topics(Some(
-                topic_names
+                topics
                     .iter()
-                    .map(|name| {
+                    .map(|(name, _)| {
                         messages::metadata_request::MetadataRequestTopic::default()
                             .with_name(Some(name.clone()))
                     })
@@ -472,43 +472,145 @@ impl KafkaApiClient {
         let resp = coord.send_request(req, None).await?;
         tracing::debug!(metadata=?resp, "Got metadata response");
 
-        if resp.topics.iter().all(|topic| {
-            topic
-                .name
-                .as_ref()
-                .map(|topic_name| topic_names.contains(topic_name) && topic.error_code == 0)
-                .unwrap_or(false)
-        }) {
-            return Ok(());
-        } else {
-            let mut topics_map = vec![];
-            for topic_name in topic_names.into_iter() {
-                topics_map.push(
-                    messages::create_topics_request::CreatableTopic::default()
-                        .with_name(topic_name)
-                        .with_replication_factor(2)
-                        .with_num_partitions(-1),
-                );
-            }
-            let create_req = messages::CreateTopicsRequest::default().with_topics(topics_map);
-            let create_resp = coord.send_request(create_req, None).await?;
-            tracing::debug!(create_response=?create_resp, "Got create response");
-
-            for topic in create_resp.topics {
-                if topic.error_code > 0 {
-                    let err = kafka_protocol::ResponseError::try_from_code(topic.error_code);
-                    tracing::warn!(
-                        topic = topic.name.to_string(),
-                        error = ?err,
-                        message = topic.error_message.map(|m|m.to_string()),
-                        "Failed to create topic"
+        let mut topics_to_update = Vec::new();
+        let mut topics_to_create = Vec::new();
+
+        for (topic_name, desired_partitions) in topics.iter() {
+            if let Some(topic) = resp
+                .topics
+                .iter()
+                .find(|t| t.name.as_ref() == Some(topic_name))
+            {
+                let current_partitions = topic.partitions.len();
+                if *desired_partitions > current_partitions {
+                    tracing::info!(
+                        topic = ?topic_name,
+                        current_partitions = current_partitions,
+                        desired_partitions = *desired_partitions,
+                        "Increasing partition count for topic",
+                    );
+                    topics_to_update.push((topic_name.clone(), *desired_partitions));
+                } else if *desired_partitions < current_partitions {
+                    anyhow::bail!("Topic {} has more partitions ({}) than requested ({}), cannot decrease partition count",
+                        topic_name.as_str(),
+                        current_partitions,
+                        desired_partitions
                     );
-                    bail!("Failed to create topic");
                 }
+            } else {
+                // Topic doesn't exist, add to creation list
+                tracing::info!(
+                    topic = ?topic_name,
+                    desired_partitions = *desired_partitions,
+                    "Creating new topic as it does not exist",
+                );
+                topics_to_create.push((topic_name.clone(), *desired_partitions));
+            }
+        }
+
+        if !topics_to_update.is_empty() {
+            self.increase_partition_counts(topics_to_update).await?;
+        }
+
+        if !topics_to_create.is_empty() {
+            self.create_new_topics(topics_to_create).await?;
+        }
+
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    async fn increase_partition_counts(
+        &mut self,
+        topics: Vec<(messages::TopicName, usize)>,
+    ) -> anyhow::Result<()> {
+        let coord = self.connect_to_controller().await?;
+
+        let mut topic_partitions = Vec::new();
+        for (topic_name, partition_count) in topics {
+            topic_partitions.push(
+                messages::create_partitions_request::CreatePartitionsTopic::default()
+                    .with_name(topic_name)
+                    .with_count(partition_count as i32)
+                    // Let Kafka auto-assign new partitions to brokers
+                    .with_assignments(None),
+            );
+        }
+
+        let create_partitions_req = messages::CreatePartitionsRequest::default()
+            .with_topics(topic_partitions)
+            .with_timeout_ms(30000) // This request will cause a rebalance, so it can take some time
+            .with_validate_only(false); // Actually perform the changes
+
+        let resp = coord.send_request(create_partitions_req, None).await?;
+        tracing::debug!(response = ?resp, "Got create partitions response");
+
+        for result in resp.results {
+            if result.error_code > 0 {
+                let err = kafka_protocol::ResponseError::try_from_code(result.error_code);
+                tracing::warn!(
+                    topic = result.name.to_string(),
+                    error = ?err,
+                    message = result.error_message.map(|m| m.to_string()),
+                    "Failed to increase partition count"
+                );
+                return Err(anyhow::anyhow!(
+                    "Failed to increase partition count for topic {}: {:?}",
+                    result.name.as_str(),
+                    err
+                ));
+            } else {
+                tracing::info!(
+                    topic = result.name.to_string(),
+                    "Successfully increased partition count",
+                );
             }
+        }
 
-            Ok(())
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    async fn create_new_topics(
+        &mut self,
+        topics: Vec<(messages::TopicName, usize)>,
+    ) -> anyhow::Result<()> {
+        let coord = self.connect_to_controller().await?;
+
+        let mut topics_map = vec![];
+        for (topic_name, desired_partitions) in topics {
+            topics_map.push(
+                messages::create_topics_request::CreatableTopic::default()
+                    .with_name(topic_name)
+                    .with_replication_factor(2)
+                    .with_num_partitions(desired_partitions as i32),
+            );
+        }
+
+        let create_req = messages::CreateTopicsRequest::default().with_topics(topics_map);
+        let create_resp = coord.send_request(create_req, None).await?;
+        tracing::debug!(create_response = ?create_resp, "Got create topics response");
+
+        for topic in create_resp.topics {
+            if topic.error_code > 0 {
+                let err = kafka_protocol::ResponseError::try_from_code(topic.error_code);
+                tracing::warn!(
+                    topic = topic.name.to_string(),
+                    error = ?err,
+                    message = topic.error_message.map(|m| m.to_string()),
+                    "Failed to create topic"
+                );
+                return Err(anyhow::anyhow!("Failed to create topic"));
+            } else {
+                tracing::info!(
+                    topic = topic.name.to_string(),
+                    "Successfully created topic with {} partitions",
+                    topic.num_partitions
+                );
+            }
         }
+
+        Ok(())
     }
 }
diff --git a/crates/dekaf/src/session.rs b/crates/dekaf/src/session.rs
index c350be90ff..3eb5df41a8 100644
--- a/crates/dekaf/src/session.rs
+++ b/crates/dekaf/src/session.rs
@@ -6,6 +6,7 @@ use crate::{
 };
 use anyhow::{bail, Context};
 use bytes::{BufMut, Bytes, BytesMut};
+use futures::TryFutureExt;
 use kafka_protocol::{
     error::{ParseResponseErrorCode, ResponseError},
     messages::{
@@ -18,6 +19,7 @@ use kafka_protocol::{
     },
     protocol::{buf::ByteBuf, Decodable, Encodable, Message, StrBytes},
 };
+use rustls::crypto::hash::Hash;
 use std::{cmp::max, sync::Arc};
 use std::{
     collections::{hash_map::Entry, HashMap},
@@ -1103,13 +1105,23 @@ impl Session {
     #[instrument(skip_all, fields(group=?req.group_id))]
     pub async fn offset_commit(
         &mut self,
-        req: messages::OffsetCommitRequest,
+        mut req: messages::OffsetCommitRequest,
         header: RequestHeader,
     ) -> anyhow::Result<messages::OffsetCommitResponse> {
-        let mut mutated_req = req.clone();
-        for topic in &mut mutated_req.topics {
+        let collections = self
+            .fetch_collections(req.topics.iter().map(|topic| &topic.name))
+            .await?;
+
+        let desired_topic_partitions = collections
+            .iter()
+            .map(|(topic_name, collection)| {
+                self.encrypt_topic_name(topic_name.clone())
+                    .map(|encrypted_name| (encrypted_name, collection.partitions.len()))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        for topic in &mut req.topics {
             let encrypted = self.encrypt_topic_name(topic.name.clone())?;
-            tracing::info!(topic_name = ?topic.name, partitions = ?topic.partitions, "Committing offset");
             topic.name = encrypted;
         }
 
@@ -1119,69 +1131,70 @@ impl Session {
             .connect_to_group_coordinator(req.group_id.as_str())
             .await?;
 
-        client
-            .ensure_topics(
-                mutated_req
-                    .topics
-                    .iter()
-                    .map(|t| t.name.to_owned())
-                    .collect(),
-            )
-            .await?;
-
-        let mut resp = client.send_request(mutated_req, Some(header)).await?;
-
-        let auth = self
-            .auth
-            .as_mut()
-            .ok_or(anyhow::anyhow!("Session not authenticated"))?;
-
-        let flow_client = auth.flow_client(&self.app).await?.clone();
+        client.ensure_topics(desired_topic_partitions).await?;
 
-        // Redeclare to drop mutability
-        let auth = self.auth.as_ref().unwrap();
+        let mut resp = client.send_request(req.clone(), Some(header)).await?;
 
         for topic in resp.topics.iter_mut() {
-            topic.name = self.decrypt_topic_name(topic.name.to_owned())?;
+            let encrypted_name = topic.name.clone();
+            let decrypted_name = self.decrypt_topic_name(topic.name.to_owned())?;
 
-            let collection_partitions = Collection::new(
-                &self.app,
-                auth,
-                &flow_client.pg_client(),
-                topic.name.as_str(),
-            )
-            .await?
-            .context(format!("unable to look up partitions for {:?}", topic.name))?
-            .partitions;
+            let collection_partitions = &collections
+                .iter()
+                .find(|(topic_name, _)| topic_name == &decrypted_name)
+                .context(format!(
+                    "unable to look up partitions for {:?}",
+                    decrypted_name
+                ))?
+                .1
+                .partitions;
 
             for partition in &topic.partitions {
                 if let Some(error) = partition.error_code.err() {
-                    tracing::warn!(topic=?topic.name,partition=partition.partition_index,?error,"Got error from upstream Kafka when trying to commit offsets");
+                    tracing::warn!(
+                        topic = decrypted_name.as_str(),
+                        partition = partition.partition_index,
+                        ?error,
+                        "Got error from upstream Kafka when trying to commit offsets"
+                    );
                 } else {
+                    let response_partition_index = partition.partition_index;
+
                     let journal_name = collection_partitions
-                        .get(partition.partition_index as usize)
+                        .get(response_partition_index as usize)
                         .context(format!(
-                            "unable to find partition {} in collection {:?}",
-                            partition.partition_index, topic.name
+                            "unable to find collection partition idx {} in collection {:?}",
+                            response_partition_index,
+                            decrypted_name.as_str()
                         ))?
                         .spec
                         .name
                         .to_owned();
 
-                    let committed_offset = req
+                    let request_partitions = &req
                         .topics
                         .iter()
-                        .find(|req_topic| req_topic.name == topic.name)
-                        .context(format!("unable to find topic in request {:?}", topic.name))?
-                        .partitions
-                        .get(partition.partition_index as usize)
+                        .find(|req_topic| req_topic.name == encrypted_name)
                         .context(format!(
-                            "unable to find partition {}",
-                            partition.partition_index
+                            "unable to find topic in request {:?}",
+                            decrypted_name.as_str()
+                        ))?
+                        .partitions;
+
+                    let committed_offset = request_partitions
+                        .iter()
+                        .find(|req_part| req_part.partition_index == response_partition_index)
+                        .context(format!(
+                            "Unable to find partition index {} in request partitions for topic {:?}, though response contained it. Request partitions: {:?}. Flow has: {:?}",
+                            response_partition_index,
+                            decrypted_name.as_str(),
+                            request_partitions,
+                            collection_partitions
                         ))?
                         .committed_offset;
 
                     metrics::gauge!("dekaf_committed_offset", "group_id"=>req.group_id.to_string(),"journal_name"=>journal_name).set(committed_offset as f64);
+                    tracing::info!(topic_name = ?topic.name, partitions = ?topic.partitions, committed_offset, "Committed offset");
                 }
             }
         }
@@ -1192,11 +1205,23 @@ impl Session {
     #[instrument(skip_all, fields(group=?req.group_id))]
     pub async fn offset_fetch(
         &mut self,
-        req: messages::OffsetFetchRequest,
+        mut req: messages::OffsetFetchRequest,
         header: RequestHeader,
     ) -> anyhow::Result<messages::OffsetFetchResponse> {
-        let mut mutated_req = req.clone();
-        if let Some(ref mut topics) = mutated_req.topics {
+        let collection_partitions = if let Some(topics) = &req.topics {
+            self.fetch_collections(topics.iter().map(|topic| &topic.name))
+                .await?
+                .into_iter()
+                .map(|(topic_name, collection)| {
+                    self.encrypt_topic_name(topic_name)
+                        .map(|encrypted_name| (encrypted_name, collection.partitions.len()))
+                })
+                .collect::<Result<Vec<_>, _>>()?
+        } else {
+            vec![]
+        };
+
+        if let Some(ref mut topics) = req.topics {
             for topic in topics {
                 topic.name = self.encrypt_topic_name(topic.name.clone())?;
             }
@@ -1208,12 +1233,11 @@ impl Session {
             .connect_to_group_coordinator(req.group_id.as_str())
             .await?;
 
-        if let Some(ref topics) = mutated_req.topics {
-            client
-                .ensure_topics(topics.iter().map(|t| t.name.to_owned()).collect())
-                .await?;
+        if !collection_partitions.is_empty() {
+            client.ensure_topics(collection_partitions).await?;
         }
-        let mut resp = client.send_request(mutated_req, Some(header)).await?;
+
+        let mut resp = client.send_request(req, Some(header)).await?;
 
         for topic in resp.topics.iter_mut() {
             topic.name = self.decrypt_topic_name(topic.name.to_owned())?;
@@ -1318,6 +1342,30 @@ impl Session {
         }
     }
 
+    async fn fetch_collections(
+        &mut self,
+        topics: impl IntoIterator<Item = &TopicName>,
+    ) -> anyhow::Result<Vec<(TopicName, Collection)>> {
+        let auth = self
+            .auth
+            .as_mut()
+            .ok_or(anyhow::anyhow!("Session not authenticated"))?;
+
+        let app = &self.app;
+        let flow_client = &auth.flow_client(app).await?.pg_client();
+
+        // Re-declare here to drop mutable reference
+        let auth = self.auth.as_ref().unwrap();
+
+        futures::future::try_join_all(topics.into_iter().map(|topic| async move {
+            let collection = Collection::new(app, auth, flow_client, topic.as_ref())
+                .await?
+                .context(format!("unable to look up partitions for {:?}", topic))?;
+            Ok::<(TopicName, Collection), anyhow::Error>((topic.clone(), collection))
+        }))
+        .await
+    }
+
     /// If the fetched offset is within a fixed number of offsets from the end of the journal,
     /// return Some with a PartitionOffset containing the beginning and end of the latest fragment.
     #[tracing::instrument(skip(self))]