Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,24 @@ def recycle_cpu_blocks(self, cpu_block_ids):
else:
heapq.heappush(self.cpu_free_block_list, cpu_block_ids)

def _acquire_kvcache_lock(self):
    """Take the GPU KV cache lock shared with the worker process.

    The lock is file-based (fcntl.flock), so it provides mutual exclusion
    even across processes spawned via subprocess. It prevents the CPU
    transfer process and the worker from touching the GPU KV cache at the
    same time, which has been observed to cause NaN errors under certain
    DP+EP configurations. No-op unless FD_USE_KVCACHE_LOCK is enabled.
    """
    if envs.FD_USE_KVCACHE_LOCK:
        self.gpu_cache_lock.acquire()

def _release_kvcache_lock(self):
    """Give back the GPU KV cache lock taken by `_acquire_kvcache_lock`.

    No-op unless FD_USE_KVCACHE_LOCK is enabled.
    """
    if envs.FD_USE_KVCACHE_LOCK:
        self.gpu_cache_lock.release()

def issue_swap_task(
self,
transfer_task_id,
Expand All @@ -541,7 +559,8 @@ def issue_swap_task(
event_type: CacheStatus.SWAP2GPU or CacheStatus.SWAP2CPU
is_sync: bool, whether to wait for the result of the swap task
"""

assert is_sync, "Only support is sync for swap_task now."
self._acquire_kvcache_lock()
self.task_swapping_event[transfer_task_id] = Event()
self.cache_task_queue.put_transfer_task(
(
Expand All @@ -554,6 +573,7 @@ def issue_swap_task(
)
if is_sync:
self.sync_swap_task(transfer_task_id)
self._release_kvcache_lock()

def sync_swap_task(self, transfer_task_id):
"""
Expand Down
13 changes: 13 additions & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
IPCLock,
IPCSignal,
ZmqIpcServer,
ZmqTcpServer,
Expand Down Expand Up @@ -144,6 +145,10 @@ def __init__(self, cfg, start_queue=True):
)
self._init_worker_monitor_signals()

# Pass the GPU KV cache lock to cache_manager for mutual exclusion
# between the CPU transfer process and the worker process.
self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock

if self.cfg.eplb_config.enable_eplb:
current_suffix = int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
Expand Down Expand Up @@ -275,6 +280,14 @@ def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进
create=True,
)

# gpu_cache_lock: file-based lock for mutual exclusion between worker
# and CPU transfer when accessing GPU KV cache.
self.gpu_cache_lock = IPCLock(
name="gpu_cache_lock",
suffix=current_suffix,
create=True,
)

def start_worker_queue_service(self, start_queue):
"""
start queue service for engine worker communication
Expand Down
6 changes: 6 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@
"FD_CONFIG_ROOT": lambda: os.path.expanduser(
os.getenv("FD_CONFIG_ROOT", os.path.join(os.path.expanduser("~"), ".config", "fastdeploy"))
),
# Whether to enable KV cache lock, enforcing mutual exclusion between
# PrefixCacheManager and Worker when accessing GPU KV cache.
# Under certain DP+EP configurations, concurrent access (even read-only)
# has been observed to cause NaN computation errors.
# Set to 1 to enable the lock; defaults to 0 (disabled).
"FD_USE_KVCACHE_LOCK": lambda: bool(int(os.getenv("FD_USE_KVCACHE_LOCK", "0"))),
# Suspend rollouting routing replay
"FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))),
}
Expand Down
3 changes: 2 additions & 1 deletion fastdeploy/inter_communicator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal, shared_memory_exists
from .ipc_signal import IPCLock, IPCSignal, shared_memory_exists
from .ipc_signal_const import (
ExistTaskStatus,
KVCacheStatus,
Expand All @@ -31,6 +31,7 @@
"ZmqIpcClient",
"ZmqIpcServer",
"ZmqTcpServer",
"IPCLock",
"IPCSignal",
"EngineWorkerQueue",
"EngineCacheQueue",
Expand Down
57 changes: 57 additions & 0 deletions fastdeploy/inter_communicator/ipc_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
"""

import fcntl
import os
from multiprocessing.shared_memory import SharedMemory

import numpy as np
Expand Down Expand Up @@ -114,3 +116,58 @@ def clear(self) -> None:
if shared_memory_exists(self.shm.name):
self.shm.close()
self.shm.unlink()


class IPCLock:
    """A file-based inter-process lock using fcntl.flock.

    Provides mutual exclusion between processes that may be spawned via
    subprocess (not just fork/multiprocessing). Lock files are stored in
    /dev/shm/ for performance, falling back to /tmp/.

    The lock can also be used as a context manager::

        with lock:
            ...  # critical section

    Args:
        name: Unique identifier for the lock.
        suffix: Optional suffix appended to the name to avoid conflicts
            when multiple engines are launched.
        create: If True, creates the lock file; otherwise opens an
            existing one.

    Raises:
        RuntimeError: When ``create`` is False and the lock file does not
            exist (i.e. no process created it first).
    """

    def __init__(self, name: str, suffix: "int | None" = None, create: bool = True) -> None:
        if suffix is not None:
            name = f"{name}.{suffix}"

        # /dev/shm is memory-backed (tmpfs); prefer it when available.
        lock_dir = "/dev/shm" if os.path.isdir("/dev/shm") else "/tmp"
        self._lock_path = os.path.join(lock_dir, f"fd_lock_{name}")

        if create:
            llm_logger.debug(f"creating ipc lock: {self._lock_path}")
            # Use restrictive permissions to avoid other users acquiring the lock.
            self._fd = os.open(self._lock_path, os.O_CREAT | os.O_RDWR, 0o600)
        else:
            llm_logger.debug(f"attaching ipc lock: {self._lock_path}")
            try:
                self._fd = os.open(self._lock_path, os.O_RDWR)
            except FileNotFoundError as e:
                llm_logger.error(
                    f"Failed to attach IPC lock: {self._lock_path} does not exist. "
                    "Ensure that the lock has been created (create=True) with the same "
                    "name and suffix before attaching."
                )
                raise RuntimeError(f"IPC lock file not found: {self._lock_path}") from e

    def acquire(self) -> None:
        """Acquire the lock (blocking). Uses kernel-level flock for atomicity."""
        fcntl.flock(self._fd, fcntl.LOCK_EX)

    def release(self) -> None:
        """Release the lock."""
        fcntl.flock(self._fd, fcntl.LOCK_UN)

    def __enter__(self) -> "IPCLock":
        """Acquire the lock on entering a ``with`` block."""
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Release the lock on leaving a ``with`` block, even on exception."""
        self.release()

    def clear(self) -> None:
        """Close the file descriptor and remove the lock file.

        Idempotent: a second call is a no-op instead of raising OSError
        on an already-closed descriptor.
        """
        if self._fd is not None:
            os.close(self._fd)
            # Mark closed so repeated clear() calls do not double-close.
            self._fd = None
        try:
            os.unlink(self._lock_path)
        except FileNotFoundError:
            pass
41 changes: 41 additions & 0 deletions fastdeploy/worker/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import (
ExistTaskStatus,
IPCLock,
IPCSignal,
ModelWeightsStatus,
RearrangeExpertStatus,
Expand Down Expand Up @@ -275,6 +276,14 @@ def init_health_status(self) -> None:
create=False,
)

# gpu_cache_lock: file-based lock for mutual exclusion between worker
# and CPU transfer when accessing GPU KV cache.
self.gpu_cache_lock = IPCLock(
name="gpu_cache_lock",
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)

def update_weights_from_tensor(self, mmap_infos):
"""
update_weights_from_tensor
Expand Down Expand Up @@ -417,6 +426,35 @@ def _run_eplb(self, tp_rank):
self.rearrange_experts_signal.value[0] = RearrangeExpertStatus.DONE.value
logger.info("redundant_expert: done")

def _acquire_kvcache_lock(self, tp_rank):
    """Take the GPU KV cache lock before model execution.

    The lock is file-based (fcntl.flock) and is shared with the CPU
    transfer process, so the two never access the GPU KV cache
    concurrently. Only tensor-parallel rank 0 takes the lock; letting
    every TP rank grab it would deadlock the workers against each
    other. No-op unless FD_USE_KVCACHE_LOCK is enabled.

    Args:
        tp_rank: Tensor parallel rank of the current worker. Only rank 0
            acquires the lock.
    """
    if envs.FD_USE_KVCACHE_LOCK and tp_rank == 0:
        self.gpu_cache_lock.acquire()

def _release_kvcache_lock(self, tp_rank):
    """Give back the GPU KV cache lock after model execution.

    Mirrors `_acquire_kvcache_lock`: only tensor-parallel rank 0
    releases, and nothing happens unless FD_USE_KVCACHE_LOCK is enabled.

    Args:
        tp_rank: Tensor parallel rank of the current worker. Only rank 0
            releases the lock.
    """
    if envs.FD_USE_KVCACHE_LOCK and tp_rank == 0:
        self.gpu_cache_lock.release()

def event_loop_normal(self) -> None:
"""Main event loop for Paddle Distributed Workers.
TODO(gongshaotian): support remote calling of functions that control worker.
Expand Down Expand Up @@ -547,7 +585,10 @@ def event_loop_normal(self) -> None:
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
start_execute_time = time.time()

self._acquire_kvcache_lock(tp_rank)
self.worker.execute_model(req_dicts, max_occupied_batch_index)
self._release_kvcache_lock(tp_rank)
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")

Expand Down
Loading