@@ -139,3 +139,10 @@ def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
return self._worker.get_finished(finished_req_ids)

@override
def get_block_ids_with_load_errors(self) -> set[int]:
"""Return block IDs that failed to load asynchronously."""
if self._worker is None:
return set()
return self._worker.get_block_ids_with_load_errors()
@@ -203,3 +203,7 @@ def get_finished(
# finished_ids = [id for id in finished_req_ids]
# return set(sending_ids), set(receiving_ids)
return self._connector.get_finished(finished_req_ids)

def get_block_ids_with_load_errors(self) -> set[int]:
"""Get block IDs that failed to load and clear the set."""
return self._connector.get_block_ids_with_load_errors()
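
Both Python wrappers delegate to the Rust `Worker` trait method added in `worker.rs` below. The contract is get-and-clear: each failed block ID is reported exactly once, after which the tracked set is emptied (via `std::mem::take`). A minimal sketch of that contract, using a toy type rather than the crate's real `Worker`:

```rust
use std::collections::HashSet;

/// Toy stand-in for the Worker trait's error-reporting method.
struct ToyWorker {
    failed_block_ids: HashSet<u32>,
}

impl ToyWorker {
    /// Return the accumulated failures and leave an empty set behind.
    fn get_block_ids_with_load_errors(&mut self) -> HashSet<u32> {
        std::mem::take(&mut self.failed_block_ids)
    }
}

fn main() {
    let mut worker = ToyWorker {
        failed_block_ids: HashSet::from([3, 4]),
    };
    assert_eq!(worker.get_block_ids_with_load_errors().len(), 2);
    // A second poll is empty: each failure is reported exactly once.
    assert!(worker.get_block_ids_with_load_errors().is_empty());
}
```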
@@ -1002,6 +1002,7 @@ impl VllmConnectorSlot {
uuid: operation_id,
transfer_type: TransferType::Store,
request_type: RequestType::Scheduled,
block_ids: block_ids.iter().map(|b| *b as usize).collect(),
};

if let Err(e) = self.xfer_tx.send(xfer_req) {
@@ -1035,6 +1036,9 @@ impl VllmConnectorSlot {
let src_storage_pool = src_blocks.storage_pool();
let operation_id = uuid::Uuid::new_v4();

// Capture block_ids before moving dst_block_ids
let block_ids: Vec<usize> = dst_block_ids.iter().map(|b| *b as usize).collect();

let xfer_req = LocalTransferRequest::Onboard(LocalOnboardRequest::new(
self.request_id.clone(),
src_blocks,
@@ -1047,6 +1051,7 @@
uuid: operation_id,
transfer_type: TransferType::Load,
request_type: RequestType::Immediate,
block_ids,
};

if let Err(e) = self.xfer_tx.send(xfer_req) {
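
The `// Capture block_ids before moving dst_block_ids` comment is doing real work: `dst_block_ids` is moved into the onboard request on the next statement, after which it can no longer be borrowed. A standalone sketch of the pattern (type and field names are illustrative, not the crate's):

```rust
struct OnboardRequest {
    dst_block_ids: Vec<u64>,
}

fn main() {
    let dst_block_ids: Vec<u64> = vec![3, 4, 5];

    // Collect the usize copy while the source vec can still be borrowed;
    // the next statement moves it into the request.
    let block_ids: Vec<usize> = dst_block_ids.iter().map(|b| *b as usize).collect();

    let req = OnboardRequest { dst_block_ids }; // move happens here
    assert_eq!(block_ids.len(), req.dst_block_ids.len());
}
```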
55 changes: 54 additions & 1 deletion lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs
@@ -6,7 +6,7 @@ use dynamo_llm::block_manager::connector::scheduler::{
Scheduler, TransferSchedulerClient, WorkerSchedulerClient,
};

-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
use std::sync::{Arc, OnceLock};

use super::*;
@@ -48,6 +48,9 @@ pub trait Worker: Send + Sync {
&mut self,
finished_requests: HashSet<String>,
) -> (HashSet<String>, HashSet<String>);

/// Get block IDs that failed to load and clear the set.
fn get_block_ids_with_load_errors(&mut self) -> HashSet<u32>;
}

pub struct KvConnectorWorker {
@@ -73,6 +76,12 @@ pub struct KvConnectorWorker {

/// cuda events created by the python side
layer_events: Vec<u64>,

/// Map request_id to (uuid → block_ids) for error tracking
request_to_blocks: HashMap<String, HashMap<uuid::Uuid, Vec<usize>>>,

/// Block IDs that failed to load
failed_block_ids: HashSet<u32>,
}

impl KvConnectorWorker {
@@ -111,6 +120,8 @@
layers_complete: 0,
kv_cache_layers: Vec::new(),
layer_events: Vec::new(),
request_to_blocks: HashMap::new(),
failed_block_ids: HashSet::new(),
})
}
}
@@ -282,6 +293,17 @@ impl Worker for KvConnectorWorker {
// immediately enqueue the onboarding operations
for operation in onboarding_operations {
let request_id = operation.request_id.clone();
let uuid = operation.uuid;
let block_ids = operation.block_ids.clone();

// Store block_ids mapping for error tracking (only for load operations)
if !block_ids.is_empty() {
self.request_to_blocks
.entry(request_id.clone())
.or_default()
.insert(uuid, block_ids);
}

self.connector.enqueue_request(operation);
self.maybe_finished_onboarding.insert(request_id);
}
@@ -451,13 +473,40 @@
// remove the finished requests from the maybe finished set
for request_id in &is_finished_onboarding {
self.maybe_finished_onboarding.remove(request_id);
// Clean up block_ids tracking for this request
self.request_to_blocks.remove(request_id);
if self.connector.has_slot(request_id) {
self.connector.remove_slot(request_id);
}
}

(is_finished_offloading, is_finished_onboarding)
}

fn get_block_ids_with_load_errors(&mut self) -> HashSet<u32> {
// Drain failures from the scheduler and convert (request_id, uuid) -> block_ids
let failures = self.connector.drain_failures();
for (request_id, failed_uuids) in failures {
if let Some(uuid_to_blocks) = self.request_to_blocks.get(&request_id) {
for uuid in failed_uuids {
if let Some(block_ids) = uuid_to_blocks.get(&uuid) {
for &block_id in block_ids {
self.failed_block_ids.insert(block_id as u32);
}
tracing::warn!(
request_id = %request_id,
operation_id = %uuid,
block_ids = ?block_ids,
"load operation failed; marking blocks as failed"
);
}
}
}
}

// Return and clear the failed block IDs
std::mem::take(&mut self.failed_block_ids)
}
}

#[pyclass]
@@ -541,6 +590,10 @@
) -> (HashSet<String>, HashSet<String>) {
self.connector_worker.get_finished(finished_requests)
}

pub fn get_block_ids_with_load_errors(&mut self) -> HashSet<u32> {
self.connector_worker.get_block_ids_with_load_errors()
}
}

use cudarc::driver::sys::{
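
Taken together, the worker's bookkeeping is a two-level map, `request_id -> (operation uuid -> block IDs)`: written when a load is enqueued, read when the scheduler reports a failed `(request_id, uuid)` pair, and dropped when the request finishes so the map cannot grow without bound. A compressed sketch of that lifecycle (uuids reduced to `u64` for brevity):

```rust
use std::collections::{HashMap, HashSet};

fn main() {
    let mut request_to_blocks: HashMap<String, HashMap<u64, Vec<usize>>> = HashMap::new();

    // On enqueue: record which blocks operation 42 of req-1 touches.
    request_to_blocks
        .entry("req-1".to_string())
        .or_default()
        .insert(42, vec![7, 8]);

    // On a reported failure of (req-1, 42): resolve back to block IDs.
    let mut failed: HashSet<u32> = HashSet::new();
    if let Some(ops) = request_to_blocks.get("req-1") {
        if let Some(blocks) = ops.get(&42) {
            failed.extend(blocks.iter().map(|&b| b as u32));
        }
    }
    assert_eq!(failed, HashSet::from([7, 8]));

    // On request completion: drop the whole entry.
    request_to_blocks.remove("req-1");
}
```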
3 changes: 3 additions & 0 deletions lib/llm/src/block_manager/connector/protocol.rs
@@ -149,6 +149,9 @@ pub struct WorkerTransferRequest {
pub uuid: uuid::Uuid,
pub transfer_type: TransferType,
pub request_type: RequestType,
/// Block IDs for this transfer (for error tracking)
#[serde(default)]
pub block_ids: Vec<usize>,
}

/// Sent by Worker to Scheduler.
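
`#[serde(default)]` is what keeps this wire change backward compatible: a message serialized before `block_ids` existed still deserializes, with the field defaulting to an empty `Vec`. A quick sketch (simplified struct; assumes `serde` and `serde_json` are available):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct WorkerTransferRequestLite {
    uuid: String, // the real field is a uuid::Uuid; simplified here
    #[serde(default)]
    block_ids: Vec<usize>,
}

fn main() {
    // A message from an older peer that predates `block_ids`.
    let old = r#"{ "uuid": "op-1" }"#;
    let msg: WorkerTransferRequestLite = serde_json::from_str(old).unwrap();
    assert!(msg.block_ids.is_empty());
}
```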
38 changes: 37 additions & 1 deletion lib/llm/src/block_manager/connector/scheduler.rs
@@ -83,6 +83,8 @@ impl TransferSchedulerClient {
pub struct WorkerSchedulerClient {
slots: HashMap<String, WorkerSchedulerClientSlot>,
scheduler_tx: mpsc::UnboundedSender<SchedulerMessage>,
/// Channel to receive failure notifications from the scheduler
failure_rx: mpsc::UnboundedReceiver<(String, uuid::Uuid)>,
iteration: u64,
iteration_complete: bool,
layers_complete: u32,
@@ -91,17 +93,29 @@
impl WorkerSchedulerClient {
pub fn new(
scheduler_tx: mpsc::UnboundedSender<SchedulerMessage>,
failure_rx: mpsc::UnboundedReceiver<(String, uuid::Uuid)>,
_cancel_token: CancellationToken,
) -> Self {
Self {
slots: HashMap::new(),
scheduler_tx,
failure_rx,
iteration: 0,
iteration_complete: true,
layers_complete: 0,
}
}

/// Drain all pending failure notifications from the scheduler.
/// Returns a map of request_id -> set of failed operation uuids.
pub fn drain_failures(&mut self) -> HashMap<String, HashSet<uuid::Uuid>> {
let mut failures: HashMap<String, HashSet<uuid::Uuid>> = HashMap::new();
while let Ok((request_id, uuid)) = self.failure_rx.try_recv() {
failures.entry(request_id).or_default().insert(uuid);
}
failures
}

pub fn iteration(&self) -> u64 {
self.iteration
}
@@ -312,6 +326,10 @@ pub struct Scheduler {

// Messages from the transfer client arrive on this channel
transfer_rx: mpsc::Receiver<TransferToSchedulerMessage>,

// Channel to send failure notifications to the worker
failure_tx: mpsc::UnboundedSender<(String, uuid::Uuid)>,

iteration: u64,
layers_complete: u32,
iteration_complete: bool,
@@ -323,7 +341,8 @@
) -> (Self, WorkerSchedulerClient, TransferSchedulerClient) {
let (scheduler_tx, scheduler_rx) = mpsc::unbounded_channel();
let (transfer_tx, transfer_rx) = mpsc::channel(128);
-let worker_client = WorkerSchedulerClient::new(scheduler_tx, cancel_token);
+let (failure_tx, failure_rx) = mpsc::unbounded_channel();
+let worker_client = WorkerSchedulerClient::new(scheduler_tx, failure_rx, cancel_token);
let transfer_client = TransferSchedulerClient::new(transfer_tx);
(
Scheduler {
@@ -333,6 +352,7 @@
enqueued_requests: HashMap::new(),
worker_rx: scheduler_rx,
transfer_rx,
failure_tx,
iteration: 0,
layers_complete: 0,
iteration_complete: true,
@@ -508,6 +528,18 @@

#[tracing::instrument(level = "debug", skip_all, fields(request_id = %result.request_id, operation_id = %result.uuid))]
fn handle_immediate_result(&mut self, result: ImmediateTransferResult) {
// Check if this result indicates a failure
if result.status.is_err() {
tracing::warn!(
request_id = %result.request_id,
operation_id = %result.uuid,
error = ?result.status,
"immediate transfer failed; notifying worker"
);
// Send failure notification to worker (ignore send errors during shutdown)
let _ = self.failure_tx.send((result.request_id.clone(), result.uuid));
}

match self.slots.get_mut(&result.request_id) {
Some(slot) => {
slot.completed.fetch_add(1, Ordering::Relaxed);
@@ -808,6 +840,7 @@ mod tests {
uuid: operation_id,
transfer_type: TransferType::Load,
request_type: RequestType::Immediate,
block_ids: vec![],
};
worker_client.enqueue_request(worker_request);
assert_eq!(worker_client.slots.get("test").unwrap().operations.len(), 1);
@@ -864,6 +897,7 @@
uuid: operation_id,
transfer_type: TransferType::Load,
request_type: RequestType::Immediate,
block_ids: vec![],
};

// immediate requests are not passed to the scheduler, but the completion will be automatically
@@ -949,6 +983,7 @@
uuid: operation_id,
transfer_type: TransferType::Store,
request_type: RequestType::Scheduled,
block_ids: vec![],
};

// worker arrives last
@@ -1005,6 +1040,7 @@
uuid: operation_id,
transfer_type: TransferType::Store,
request_type: RequestType::Scheduled,
block_ids: vec![],
};

// worker arrives first
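
The failure path is deliberately fire-and-forget: the scheduler `send`s on an unbounded channel (ignoring send errors during shutdown) and the worker drains with non-blocking `try_recv` on its next poll, so neither side can stall the other. A self-contained sketch of the drain pattern (operation IDs reduced to `u64`; assumes `tokio` with the `sync` feature):

```rust
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc;

/// Drain all queued failure notifications without blocking,
/// grouping operation IDs by request ID.
fn drain_failures(
    rx: &mut mpsc::UnboundedReceiver<(String, u64)>,
) -> HashMap<String, HashSet<u64>> {
    let mut failures: HashMap<String, HashSet<u64>> = HashMap::new();
    while let Ok((request_id, op_id)) = rx.try_recv() {
        failures.entry(request_id).or_default().insert(op_id);
    }
    failures
}

fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    // Scheduler side: fire-and-forget sends.
    let _ = tx.send(("req-1".to_string(), 7));
    let _ = tx.send(("req-1".to_string(), 9));
    // Worker side: non-blocking drain on its next poll.
    let failures = drain_failures(&mut rx);
    assert_eq!(failures["req-1"], HashSet::from([7, 9]));
}
```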