diff --git a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py
index 47cbb5a2182..d6aa1adf8ca 100644
--- a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py
+++ b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py
@@ -139,3 +139,10 @@ def get_finished(
         self, finished_req_ids: set[str]
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         return self._worker.get_finished(finished_req_ids)
+
+    @override
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        """Return block IDs that failed to load asynchronously."""
+        if self._worker is None:
+            return set()
+        return self._worker.get_block_ids_with_load_errors()
diff --git a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py
index 0deadf48b70..9ef7e4d85e8 100644
--- a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py
+++ b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py
@@ -203,3 +203,7 @@ def get_finished(
         # finished_ids = [id for id in finished_req_ids]
         # return set(sending_ids), set(receiving_ids)
         return self._connector.get_finished(finished_req_ids)
+
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        """Get block IDs that failed to load and clear the set."""
+        return self._connector.get_block_ids_with_load_errors()
diff --git a/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs
index 6b75dd64dce..4e09e80d315 100644
--- a/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs
+++ b/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs
@@ -1002,6 +1002,7 @@ impl VllmConnectorSlot {
             uuid: operation_id,
             transfer_type: TransferType::Store,
             request_type: RequestType::Scheduled,
+            block_ids: block_ids.iter().map(|b| *b as usize).collect(),
         };
 
         if let Err(e) = self.xfer_tx.send(xfer_req) {
@@ -1035,6 +1036,9 @@ impl VllmConnectorSlot {
         let src_storage_pool = src_blocks.storage_pool();
         let operation_id = uuid::Uuid::new_v4();
 
+        // Capture block_ids before moving dst_block_ids
+        let block_ids: Vec<usize> = dst_block_ids.iter().map(|b| *b as usize).collect();
+
         let xfer_req = LocalTransferRequest::Onboard(LocalOnboardRequest::new(
             self.request_id.clone(),
             src_blocks,
@@ -1047,6 +1051,7 @@ impl VllmConnectorSlot {
             uuid: operation_id,
             transfer_type: TransferType::Load,
             request_type: RequestType::Immediate,
+            block_ids,
         };
 
         if let Err(e) = self.xfer_tx.send(xfer_req) {
diff --git a/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs b/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs
index 8dac9954bfc..b57fb07ae3a 100644
--- a/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs
+++ b/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs
@@ -6,7 +6,7 @@
 use dynamo_llm::block_manager::connector::scheduler::{
     Scheduler, TransferSchedulerClient, WorkerSchedulerClient,
 };
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::sync::{Arc, OnceLock};
 
 use super::*;
@@ -48,6 +48,9 @@ pub trait Worker: Send + Sync {
         &mut self,
         finished_requests: HashSet<String>,
     ) -> (HashSet<String>, HashSet<String>);
+
+    /// Get block IDs that failed to load and clear the set.
+    fn get_block_ids_with_load_errors(&mut self) -> HashSet<u32>;
 }
 
 pub struct KvConnectorWorker {
@@ -73,6 +76,12 @@ pub struct KvConnectorWorker {
 
     /// cuda events created by the python side
     layer_events: Vec,
+
+    /// Map request_id to (uuid → block_ids) for error tracking
+    request_to_blocks: HashMap<String, HashMap<uuid::Uuid, Vec<usize>>>,
+
+    /// Block IDs that failed to load
+    failed_block_ids: HashSet<u32>,
 }
 
 impl KvConnectorWorker {
@@ -111,8 +120,49 @@ impl KvConnectorWorker {
             layers_complete: 0,
             kv_cache_layers: Vec::new(),
             layer_events: Vec::new(),
+            request_to_blocks: HashMap::new(),
+            failed_block_ids: HashSet::new(),
         })
     }
+
+    // TODO: Move this out of the bindings
+    /// Drains pending failures from the scheduler and accumulates failed block IDs.
+    fn process_pending_failures(&mut self) {
+        let failures = self.connector.drain_failures();
+        for (request_id, failed_uuids) in failures {
+            if let Some(uuid_to_blocks) = self.request_to_blocks.get(&request_id) {
+                for uuid in failed_uuids {
+                    if let Some(block_ids) = uuid_to_blocks.get(&uuid) {
+                        for &block_id in block_ids {
+                            self.failed_block_ids.insert(block_id as u32);
+                        }
+                        tracing::warn!(
+                            request_id = %request_id,
+                            operation_id = %uuid,
+                            block_ids = ?block_ids,
+                            "load operation failed; marking blocks as failed"
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    /// Cleans up tracking state for a finished onboarding request.
+    ///
+    /// This method ensures failures are processed before removing the block ID mapping,
+    /// preserving the failure->block correlation.
+    fn cleanup_onboarding_request(&mut self, request_id: &str) {
+        // Process any pending failures first (ensures we can still map to block IDs)
+        self.process_pending_failures();
+
+        // Now safe to remove tracking state
+        self.maybe_finished_onboarding.remove(request_id);
+        self.request_to_blocks.remove(request_id);
+        if self.connector.has_slot(request_id) {
+            self.connector.remove_slot(request_id);
+        }
+    }
 }
 
 impl Worker for KvConnectorWorker {
@@ -282,6 +332,17 @@ impl Worker for KvConnectorWorker {
         // immediately enqueue the onboarding operations
         for operation in onboarding_operations {
             let request_id = operation.request_id.clone();
+            let uuid = operation.uuid;
+            let block_ids = operation.block_ids.clone();
+
+            // Store block_ids mapping for error tracking (only for load operations)
+            if !block_ids.is_empty() {
+                self.request_to_blocks
+                    .entry(request_id.clone())
+                    .or_default()
+                    .insert(uuid, block_ids);
+            }
+
             self.connector.enqueue_request(operation);
             self.maybe_finished_onboarding.insert(request_id);
         }
@@ -448,16 +509,21 @@ impl Worker for KvConnectorWorker {
             }
         }
 
-        // remove the finished requests from the maybe finished set
+        // Clean up finished onboarding requests
        for request_id in &is_finished_onboarding {
-            self.maybe_finished_onboarding.remove(request_id);
-            if self.connector.has_slot(request_id) {
-                self.connector.remove_slot(request_id);
-            }
+            self.cleanup_onboarding_request(request_id);
         }
 
         (is_finished_offloading, is_finished_onboarding)
     }
+
+    fn get_block_ids_with_load_errors(&mut self) -> HashSet<u32> {
+        // Process any remaining failures that haven't been handled yet
+        self.process_pending_failures();
+
+        // Return and clear the accumulated failed block IDs
+        std::mem::take(&mut self.failed_block_ids)
+    }
 }
 
 #[pyclass]
@@ -541,6 +607,10 @@
     ) -> (HashSet<String>, HashSet<String>) {
         self.connector_worker.get_finished(finished_requests)
     }
+
+    pub fn get_block_ids_with_load_errors(&mut self) -> HashSet<u32> {
+        self.connector_worker.get_block_ids_with_load_errors()
+    }
 }
 
 use cudarc::driver::sys::{
diff --git a/lib/llm/src/block_manager/connector/protocol.rs b/lib/llm/src/block_manager/connector/protocol.rs
index 8aa85a552ec..5709fd96e5f 100644
--- a/lib/llm/src/block_manager/connector/protocol.rs
+++ b/lib/llm/src/block_manager/connector/protocol.rs
@@ -149,6 +149,9 @@ pub struct WorkerTransferRequest {
     pub uuid: uuid::Uuid,
     pub transfer_type: TransferType,
     pub request_type: RequestType,
+    /// Block IDs for this transfer (for error tracking)
+    #[serde(default)]
+    pub block_ids: Vec<usize>,
 }
 
 /// Sent by Worker to Scheduler.
diff --git a/lib/llm/src/block_manager/connector/scheduler.rs b/lib/llm/src/block_manager/connector/scheduler.rs
index a19d6e4fb48..f2256b5597c 100644
--- a/lib/llm/src/block_manager/connector/scheduler.rs
+++ b/lib/llm/src/block_manager/connector/scheduler.rs
@@ -83,6 +83,8 @@ impl TransferSchedulerClient {
 pub struct WorkerSchedulerClient {
     slots: HashMap,
     scheduler_tx: mpsc::UnboundedSender,
+    /// Channel to receive failure notifications from the scheduler
+    failure_rx: mpsc::UnboundedReceiver<(String, uuid::Uuid)>,
     iteration: u64,
     iteration_complete: bool,
     layers_complete: u32,
 }
@@ -91,17 +93,29 @@
 impl WorkerSchedulerClient {
     pub fn new(
         scheduler_tx: mpsc::UnboundedSender,
+        failure_rx: mpsc::UnboundedReceiver<(String, uuid::Uuid)>,
         _cancel_token: CancellationToken,
     ) -> Self {
         Self {
             slots: HashMap::new(),
             scheduler_tx,
+            failure_rx,
             iteration: 0,
             iteration_complete: true,
             layers_complete: 0,
         }
     }
 
+    /// Drain all pending failure notifications from the scheduler.
+    /// Returns a map of request_id -> set of failed operation uuids.
+    pub fn drain_failures(&mut self) -> HashMap<String, HashSet<uuid::Uuid>> {
+        let mut failures: HashMap<String, HashSet<uuid::Uuid>> = HashMap::new();
+        while let Ok((request_id, uuid)) = self.failure_rx.try_recv() {
+            failures.entry(request_id).or_default().insert(uuid);
+        }
+        failures
+    }
+
     pub fn iteration(&self) -> u64 {
         self.iteration
     }
@@ -312,6 +326,10 @@ pub struct Scheduler {
 
     // Messages from the transfer client arrive on this channel
     transfer_rx: mpsc::Receiver,
+
+    // Channel to send failure notifications to the worker
+    failure_tx: mpsc::UnboundedSender<(String, uuid::Uuid)>,
+
     iteration: u64,
     layers_complete: u32,
     iteration_complete: bool,
@@ -323,7 +341,8 @@ impl Scheduler {
     ) -> (Self, WorkerSchedulerClient, TransferSchedulerClient) {
         let (scheduler_tx, scheduler_rx) = mpsc::unbounded_channel();
         let (transfer_tx, transfer_rx) = mpsc::channel(128);
-        let worker_client = WorkerSchedulerClient::new(scheduler_tx, cancel_token);
+        let (failure_tx, failure_rx) = mpsc::unbounded_channel();
+        let worker_client = WorkerSchedulerClient::new(scheduler_tx, failure_rx, cancel_token);
         let transfer_client = TransferSchedulerClient::new(transfer_tx);
         (
             Scheduler {
@@ -333,6 +352,7 @@ impl Scheduler {
                 enqueued_requests: HashMap::new(),
                 worker_rx: scheduler_rx,
                 transfer_rx,
+                failure_tx,
                 iteration: 0,
                 layers_complete: 0,
                 iteration_complete: true,
@@ -508,6 +528,20 @@ impl Scheduler {
 
     #[tracing::instrument(level = "debug", skip_all, fields(request_id = %result.request_id, operation_id = %result.uuid))]
     fn handle_immediate_result(&mut self, result: ImmediateTransferResult) {
+        // Check if this result indicates a failure
+        if result.status.is_err() {
+            tracing::warn!(
+                request_id = %result.request_id,
+                operation_id = %result.uuid,
+                error = ?result.status,
+                "immediate transfer failed; notifying worker"
+            );
+            // Send failure notification to worker (ignore send errors during shutdown)
+            let _ = self
+                .failure_tx
+                .send((result.request_id.clone(), result.uuid));
+        }
+
         match self.slots.get_mut(&result.request_id) {
             Some(slot) => {
                 slot.completed.fetch_add(1, Ordering::Relaxed);
@@ -808,6 +842,7 @@ mod tests {
             uuid: operation_id,
             transfer_type: TransferType::Load,
             request_type: RequestType::Immediate,
+            block_ids: vec![],
         };
         worker_client.enqueue_request(worker_request);
         assert_eq!(worker_client.slots.get("test").unwrap().operations.len(), 1);
@@ -864,6 +899,7 @@ mod tests {
             uuid: operation_id,
             transfer_type: TransferType::Load,
             request_type: RequestType::Immediate,
+            block_ids: vec![],
         };
 
         // immediate requests are not passed to the scheduler, but the completion will be automatically
@@ -949,6 +985,7 @@ mod tests {
             uuid: operation_id,
             transfer_type: TransferType::Store,
             request_type: RequestType::Scheduled,
+            block_ids: vec![],
         };
 
         // worker arrives last
@@ -1005,6 +1042,7 @@ mod tests {
             uuid: operation_id,
             transfer_type: TransferType::Store,
            request_type: RequestType::Scheduled,
+            block_ids: vec![],
         };
 
         // worker arrives first