Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/cache_manager/cache_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CacheStatus(Enum):
CPU = 3
GPU2STORAGE = 4
STORAGE2GPU = 5
CTRL = -1


class BlockNode:
Expand Down
373 changes: 226 additions & 147 deletions fastdeploy/cache_manager/cache_transfer_manager.py

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def launch_cache_manager(
else:
storage_arg_str = " "

if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend:
if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend:
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
Expand All @@ -314,7 +314,6 @@ def launch_cache_manager(
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --ipc_suffix {ipc_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
Expand Down Expand Up @@ -353,9 +352,8 @@ def launch_cache_manager(

# Start additional threads
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
threading.Thread(target=self.recv_data_transfer_result, daemon=True).start()
if cache_config.enable_prefix_caching:
if cache_config.enable_prefix_caching and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
threading.Thread(target=self.clear_prefix_cache, daemon=True).start()

all_cache_processes = cache_messager_processes + cache_manager_processes
Expand Down
192 changes: 168 additions & 24 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from tqdm import tqdm

import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import (
ControlRequest,
ControlResponse,
Expand Down Expand Up @@ -84,7 +86,7 @@ class EngineService:
Base class containing common engine functionality
"""

def __init__(self, cfg, start_queue=True, use_async_llm=False):
def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False):
"""
Initializes the LLMEngine with the provided configuration.

Expand All @@ -104,14 +106,22 @@ def __init__(self, cfg, start_queue=True, use_async_llm=False):
self.is_paused = False # pause request generation
self._pause_cond = threading.Condition()

self._ctrl_worker_output_queues = []
self._ctrl_output_queues = {}
tp_size = cfg.parallel_config.tensor_parallel_size
dp_index = cfg.parallel_config.local_data_parallel_id
for rank in range(tp_size):
for tp_rank in range(tp_size):
# create worker control response queue
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")

# create cache control response queue
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port
name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}"
self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

跟worker、cache transfer的控制通信感觉最好区分开,因为不一定所有的控制信号都会发给他俩


self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
Expand Down Expand Up @@ -1318,6 +1328,7 @@ def _control_pause(self, control_request: ControlRequest):
with self._pause_cond:
if self.is_paused:
self.llm_logger.info("Pause Request Generation: already paused.")
return
self.is_paused = True

self.llm_logger.info("Start Abort Running Requests")
Expand Down Expand Up @@ -1351,7 +1362,16 @@ def _control_pause(self, control_request: ControlRequest):
self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
self.scheduler.reset()

# pause cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Pause cache transfer")
pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"]))

self.resource_manager.cache_manager.reset()
self.llm_logger.info("END Pause Request Generation")
return None

def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
Expand All @@ -1370,6 +1390,15 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
return None
self.is_paused = False
self._pause_cond.notify_all()

# resume cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Resume cache transfer")
resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"]))

self.llm_logger.info("END Resume Request Generation")
return None

Expand Down Expand Up @@ -1406,15 +1435,129 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d
raise Exception(error_msg)
return self._call_worker(control_request, 60)

async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.
def _control_sleep(self, control_request: ControlRequest):
"""
Offload gpu memory occupation for certain parts, e.g. weight, cache.

Args:
control_request: Control request object containing parameters for offloading memory
tags: list of tags to offload, supported values: ["weight", "cache"]

This method concurrently waits for responses from all control workers
and enforces an overall timeout to avoid leaking pending tasks.
TODO: support different level of offloading, to provide options for release memory forever
or merely offloading to cpu memory for now.
"""
timeout_ms = timeout * 1000
# Create one get() coroutine per worker output queue
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
# Args check
allowed_tags = ["weight", "kv_cache"]
tags = control_request.args.get("tags", "")

for tag in tags.split(","):
if tag not in allowed_tags:
raise ValueError(f"unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}")

# Make sure llm engine is paused.
self.llm_logger.warning(
"Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. "
"Please explicitly request to /pause the engine before /sleep."
)
self._control_pause(None)

# Determine which executors are needed for the sleep command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")
if self.cfg.cache_config.enable_prefix_caching:
self.resource_manager.cache_manager.reset()
self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}")

# Dispatch sleep request to executors
self._dispatch_control_request(control_request, executors)
return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors))

def _control_wakeup(self, control_request: ControlRequest):
"""
Reload offloaded gpu memory occupation for certain parts, e.g. weight, cache.

Args:
control_request: Control request object containing parameters for reloading memory
tags: list of tags to reload, supported values: ["weight", "kv_cache"]
"""
# Args check
allowed_tags = ["weight", "kv_cache"]
tags = control_request.args.get("tags", "")

for tag in tags.split(","):
if tag not in allowed_tags:
raise ValueError(f"unsupported tag {tag} in {tags}, expected one of {allowed_tags}")

# Determine which executors are needed for the wakeup command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")

# Dispatch wakeup request to executors
self._dispatch_control_request(control_request, executors)
result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors))

# Resume the engine after wakeup
self._control_resume(None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sleep和wake_up 内需要包含 pause和resume吗?是不是交由上游中控来调用,这样sleep和wake_up的语义更加明确

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里一方面是想做个兜底,另一方面是为了兼容老接口语义,调一次 clear_load_weight 和调一次 sleep 可以实现相同的效果。最好确实需要拆分一下,可以用参数控制 sleep 时是否需要隐式包含 pause


return result

def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]):
    """
    Route a control request to the selected executor processes.

    Args:
        control_request: the ControlRequest to deliver.
        executors: executor names to target. Recognized values are
            "worker" (sent via the engine worker queue) and
            "cache_transfer" (sent via the cache task queue, but only
            when CPU cache blocks or a KV-cache storage backend are
            configured — otherwise no cache-transfer process exists).
    """
    wants_worker = "worker" in executors
    wants_cache_transfer = "cache_transfer" in executors

    if wants_worker:
        self.engine_worker_queue.put_tasks(([control_request], 1))

    if wants_cache_transfer:
        # Re-check the cache config defensively: cache-transfer processes
        # are only launched when swap space or a storage backend is enabled.
        cache_cfg = self.cfg.cache_config
        if cache_cfg.num_cpu_blocks > 0 or cache_cfg.kvcache_storage_backend:
            self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request))

async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None):
"""Wait for control responses from specified queues.

Args:
request_id: The request ID to match responses against
timeout: Global timeout in seconds
executors: List of executors to wait for, e.g., ["worker", "cache_transfer"]
If None, waits for all queues
"""
timeout_ms = timeout * 1000 if timeout else None

# determine which queues to wait for by executors
queues = {}
if executors is None:
queues = self._ctrl_output_queues
else:
if "worker" in executors:
for name, queue in self._ctrl_output_queues.items():
if "w2e" in name:
queues[name] = queue
if "cache_transfer" in executors:
for name, queue in self._ctrl_output_queues.items():
if "c2e" in name:
queues[name] = queue

if not queues:
self.llm_logger.info(f"No queues to wait for, executors: {executors}")
return
self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}")

# Create one get() coroutine per queue
tasks = [q.get(timeout=timeout_ms) for q in queues.values()]

try:
results = await asyncio.wait_for(
Expand All @@ -1423,32 +1566,31 @@ async def _wait_all_control_responses(self, request_id: str, timeout: int):
)
except asyncio.TimeoutError:
# Keep the error message consistent with previous behavior
raise Exception("Worker Update Weights Timeouted after 600s")
raise Exception(f"Control request {request_id} timeouted after {timeout}s")

responses = []
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
for name, msg in zip(queues.keys(), results):
if isinstance(msg, Exception):
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
raise Exception(f"Call Worker error: {repr(msg)}")
self.llm_logger.error(f"Call {name} failed: {repr(msg)}")
raise Exception(f"Call {name} error: {repr(msg)}")
if msg is None:
# Preserve original semantics when no message is received
raise Exception("Worker Update Weights Timeouted after 600s")
raise Exception(f"No message received from {name}")
response: ControlResponse = msg.payload
if response.request_id != request_id:
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
self.llm_logger.info(f"ignore old control response from {name}: {response}")
continue
if response.error_code != 200:
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
raise Exception(f"Call Worker error: {response.error_message}")
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
self.llm_logger.info(f"Call {name} failed: {response.error_message}")
raise Exception(f"Call {name} error: {response.error_message}")
self.llm_logger.info(f"Call {name} succeed: {response.result}")
responses.append(response.result)
return responses

def _call_worker(self, control_request: ControlRequest, timeout: int):
request_id = control_request.request_id
self.engine_worker_queue.put_tasks(([control_request], 1))
# Use a single asyncio.run() to concurrently wait for all worker responses.
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"]))

def _send_error_response(self, request_id, error_msg, error_code: int = 500):
self.llm_logger.error(
Expand Down Expand Up @@ -1487,6 +1629,8 @@ def _zmq_send_generated_tokens(self):
"""
while self.running:
try:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output为什么需要感知pause呢?如果没有新请求,output就不会有新token要处理;对于正在处理的请求,应该要等到preempted调度后自行结束,否则可能会有中间token阻塞在输出队列里

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里忘删了

results = self.scheduler.get_results()
if len(results) == 0:
time.sleep(0.005)
Expand Down
10 changes: 10 additions & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,16 @@ def __init__(
self.method = method
self.args = args or {}

self._post_init()

def _post_init(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个感觉放到接口自己维护比较好,request逻辑尽量保持通用

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

if self.method in ("sleep", "wakeup"):
tags = self.args.get("tags")
if tags is None or tags == "":
self.args["tags"] = "weight,kv_cache"
elif not isinstance(tags, str):
raise TypeError("tags must be a string")

@classmethod
def from_dict(cls, d: dict):
"""Create ControlRequest instance from dictionary."""
Expand Down
23 changes: 20 additions & 3 deletions fastdeploy/entrypoints/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def check_health(self, time_interval_threashold=30):
return True, ""

async def run_control_method(self, request: ControlRequest):
api_server_logger.info(f"Start Run Control Method: {request}")
api_server_logger.info(f"Received control request: {request}")
self.zmq_client.send_json(request.to_dict())
request_id = request.request_id
dealer, response_queue = await self.connection_manager.get_connection(request_id)
Expand All @@ -601,12 +601,29 @@ async def run_control_method(self, request: ControlRequest):
# todo: support user specified timeout. default 600s is enough for most control cases
response = await asyncio.wait_for(response_queue.get(), timeout=600)
response = ControlResponse.from_dict(response[0])
api_server_logger.info(f"End Run Control Method: {response}")
api_server_logger.info(f"Return control response: {response}")
return response
except asyncio.TimeoutError:
error_response = ControlResponse(request_id, 500, "Timeout waiting for control method response")
api_server_logger.error(f"Error Run Control Method: {error_response}")
api_server_logger.error(f"Control request timed out: {error_response}")
return error_response
except Exception as e:
import traceback

api_server_logger.error(f"Unknown error in control method: {str(e)}\n{traceback.format_exc()}")
error_response = ControlResponse(request_id, 500, str(e))
return error_response

def run_control_method_sync(self, request: ControlRequest, event_loop):
    """
    Execute a control method on behalf of a synchronous caller.

    Bridges synchronous code into the asynchronous ``run_control_method``
    coroutine by submitting it to ``event_loop`` and blocking on the
    resulting future until the response arrives.

    NOTE: asyncio.Queue operations must happen on the event loop that
    owns them, so the coroutine is scheduled onto the provided loop
    (via ``run_coroutine_threadsafe``) rather than run in a new one.
    """
    coroutine = self.run_control_method(request)
    pending = asyncio.run_coroutine_threadsafe(coroutine, event_loop)
    return pending.result()

def is_workers_alive(self):
"""
Expand Down
Loading
Loading