-
Notifications
You must be signed in to change notification settings - Fork 722
[Feature] support v1 update/clear api for RL #6761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
7bcdc1d
350a315
4ddddf6
1f36a38
5a67e28
0dbda1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,6 +32,7 @@ class CacheStatus(Enum): | |
| CPU = 3 | ||
| GPU2STORAGE = 4 | ||
| STORAGE2GPU = 5 | ||
| CTRL = -1 | ||
|
|
||
|
|
||
| class BlockNode: | ||
|
|
||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,8 @@ | |
| from tqdm import tqdm | ||
|
|
||
| import fastdeploy.metrics.trace as tracing | ||
| from fastdeploy.cache_manager.cache_data import CacheStatus | ||
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.engine.request import ( | ||
| ControlRequest, | ||
| ControlResponse, | ||
|
|
@@ -84,7 +86,7 @@ class EngineService: | |
| Base class containing common engine functionality | ||
| """ | ||
|
|
||
| def __init__(self, cfg, start_queue=True, use_async_llm=False): | ||
| def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): | ||
| """ | ||
| Initializes the LLMEngine with the provided configuration. | ||
|
|
||
|
|
@@ -104,14 +106,22 @@ def __init__(self, cfg, start_queue=True, use_async_llm=False): | |
| self.is_paused = False # pause request generation | ||
| self._pause_cond = threading.Condition() | ||
|
|
||
| self._ctrl_worker_output_queues = [] | ||
| self._ctrl_output_queues = {} | ||
| tp_size = cfg.parallel_config.tensor_parallel_size | ||
| dp_index = cfg.parallel_config.local_data_parallel_id | ||
| for rank in range(tp_size): | ||
| for tp_rank in range(tp_size): | ||
| # create worker control response queue | ||
| engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port | ||
| name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}" | ||
| self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)") | ||
| self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer")) | ||
| name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}" | ||
| self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)") | ||
| self._ctrl_output_queues[name] = FMQ().queue(name, "consumer") | ||
|
|
||
| # create cache control response queue | ||
| if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: | ||
| engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port | ||
| name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}" | ||
| self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)") | ||
| self._ctrl_output_queues[name] = FMQ().queue(name, "consumer") | ||
|
|
||
| self.scheduler = cfg.scheduler_config.scheduler() | ||
| self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" | ||
|
|
@@ -1318,6 +1328,7 @@ def _control_pause(self, control_request: ControlRequest): | |
| with self._pause_cond: | ||
| if self.is_paused: | ||
| self.llm_logger.info("Pause Request Generation: already paused.") | ||
| return | ||
| self.is_paused = True | ||
|
|
||
| self.llm_logger.info("Start Abort Running Requests") | ||
|
|
@@ -1351,7 +1362,16 @@ def _control_pause(self, control_request: ControlRequest): | |
| self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.") | ||
| self.scheduler.reset() | ||
|
|
||
| # pause cache transfer | ||
| if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: | ||
| self.llm_logger.info("Pause cache transfer") | ||
| pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause") | ||
| self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request)) | ||
| # Wait for cache_transfer responses | ||
| asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"])) | ||
|
|
||
| self.resource_manager.cache_manager.reset() | ||
| self.llm_logger.info("END Pause Request Generation") | ||
| return None | ||
|
|
||
| def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: | ||
|
|
@@ -1370,6 +1390,15 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: | |
| return None | ||
| self.is_paused = False | ||
| self._pause_cond.notify_all() | ||
|
|
||
| # resume cache transfer | ||
| if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: | ||
| self.llm_logger.info("Resume cache transfer") | ||
| resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume") | ||
| self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request)) | ||
| # Wait for cache_transfer responses | ||
| asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"])) | ||
|
|
||
| self.llm_logger.info("END Resume Request Generation") | ||
| return None | ||
|
|
||
|
|
@@ -1406,15 +1435,129 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d | |
| raise Exception(error_msg) | ||
| return self._call_worker(control_request, 60) | ||
|
|
||
| async def _wait_all_control_responses(self, request_id: str, timeout: int): | ||
| """Wait for control responses from all workers with a global timeout. | ||
| def _control_sleep(self, control_request: ControlRequest): | ||
| """ | ||
| Offload gpu memory occupation for certain parts, e.g. weight, cache. | ||
|
|
||
| Args: | ||
| control_request: Control request object containing parameters for offloading memory | ||
| tags: list of tags to offload, supported values: ["weight", "cache"] | ||
|
|
||
| This method concurrently waits for responses from all control workers | ||
| and enforces an overall timeout to avoid leaking pending tasks. | ||
| TODO: support different level of offloading, to provide options for release memory forever | ||
| or merely offloading to cpu memory for now. | ||
| """ | ||
| timeout_ms = timeout * 1000 | ||
| # Create one get() coroutine per worker output queue | ||
| tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues] | ||
| # Args check | ||
| allowed_tags = ["weight", "kv_cache"] | ||
| tags = control_request.args.get("tags", "") | ||
|
|
||
| for tag in tags.split(","): | ||
| if tag not in allowed_tags: | ||
| raise ValueError(f"unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}") | ||
|
|
||
| # Make sure llm engine is paused. | ||
| self.llm_logger.warning( | ||
| "Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. " | ||
| "Please explicitly request to /pause the engine before /sleep." | ||
| ) | ||
| self._control_pause(None) | ||
|
|
||
| # Determine which executors are needed for the sleep command | ||
| executors = set() | ||
| if "weight" in tags: | ||
| executors.add("worker") | ||
| if "kv_cache" in tags: | ||
| executors.add("worker") | ||
| if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: | ||
| executors.add("cache_transfer") | ||
| if self.cfg.cache_config.enable_prefix_caching: | ||
| self.resource_manager.cache_manager.reset() | ||
| self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}") | ||
|
|
||
| # Dispatch sleep request to executors | ||
| self._dispatch_control_request(control_request, executors) | ||
| return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors)) | ||
|
|
||
| def _control_wakeup(self, control_request: ControlRequest): | ||
| """ | ||
| Reload offloaded gpu memory occupation for certain parts, e.g. weight, cache. | ||
|
|
||
| Args: | ||
| control_request: Control request object containing parameters for reloading memory | ||
| tags: list of tags to reload, supported values: ["weight", "kv_cache"] | ||
| """ | ||
| # Args check | ||
| allowed_tags = ["weight", "kv_cache"] | ||
| tags = control_request.args.get("tags", "") | ||
|
|
||
| for tag in tags.split(","): | ||
| if tag not in allowed_tags: | ||
| raise ValueError(f"unsupported tag {tag} in {tags}, expected one of {allowed_tags}") | ||
|
|
||
| # Determine which executors are needed for the wakeup command | ||
| executors = set() | ||
| if "weight" in tags: | ||
| executors.add("worker") | ||
| if "kv_cache" in tags: | ||
| executors.add("worker") | ||
| if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: | ||
| executors.add("cache_transfer") | ||
|
|
||
| # Dispatch wakeup request to executors | ||
| self._dispatch_control_request(control_request, executors) | ||
| result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors)) | ||
|
|
||
| # Resume the engine after wakeup | ||
| self._control_resume(None) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. sleep 和 wake_up 内需要包含 pause 和 resume 吗?是不是应该交由上游中控来调用,这样 sleep 和 wake_up 的语义更加明确。
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里一方面是想做个兜底,另一方面是为了兼容老接口语义:调一次 clear_load_weight 和调一次 sleep 可以实现相同的效果。确实最好拆分一下,可以用参数控制 sleep 时是否需要隐式包含 pause。 |
||
|
|
||
| return result | ||
|
|
||
| def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]): | ||
| """ | ||
| Dispatch control requests to workers, cache managers or engine itself. | ||
|
|
||
| Args: | ||
| control_request: ControlRequest | ||
| executors: List | ||
| """ | ||
| if "worker" in executors: | ||
| self.engine_worker_queue.put_tasks(([control_request], 1)) | ||
| if "cache_transfer" in executors: | ||
| if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: | ||
| self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request)) | ||
| return | ||
|
|
||
| async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None): | ||
| """Wait for control responses from specified queues. | ||
|
|
||
| Args: | ||
| request_id: The request ID to match responses against | ||
| timeout: Global timeout in seconds | ||
| executors: List of executors to wait for, e.g., ["worker", "cache_transfer"] | ||
| If None, waits for all queues | ||
| """ | ||
| timeout_ms = timeout * 1000 if timeout else None | ||
|
|
||
| # determine which queues to wait for by executors | ||
| queues = {} | ||
| if executors is None: | ||
| queues = self._ctrl_output_queues | ||
| else: | ||
| if "worker" in executors: | ||
| for name, queue in self._ctrl_output_queues.items(): | ||
| if "w2e" in name: | ||
| queues[name] = queue | ||
| if "cache_transfer" in executors: | ||
| for name, queue in self._ctrl_output_queues.items(): | ||
| if "c2e" in name: | ||
| queues[name] = queue | ||
|
|
||
| if not queues: | ||
| self.llm_logger.info(f"No queues to wait for, executors: {executors}") | ||
| return | ||
| self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}") | ||
|
|
||
| # Create one get() coroutine per queue | ||
| tasks = [q.get(timeout=timeout_ms) for q in queues.values()] | ||
|
|
||
| try: | ||
| results = await asyncio.wait_for( | ||
|
|
@@ -1423,32 +1566,31 @@ async def _wait_all_control_responses(self, request_id: str, timeout: int): | |
| ) | ||
| except asyncio.TimeoutError: | ||
| # Keep the error message consistent with previous behavior | ||
| raise Exception("Worker Update Weights Timeouted after 600s") | ||
| raise Exception(f"Control request {request_id} timeouted after {timeout}s") | ||
|
|
||
| responses = [] | ||
| for output_queue, msg in zip(self._ctrl_worker_output_queues, results): | ||
| for name, msg in zip(queues.keys(), results): | ||
| if isinstance(msg, Exception): | ||
| self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}") | ||
| raise Exception(f"Call Worker error: {repr(msg)}") | ||
| self.llm_logger.error(f"Call {name} failed: {repr(msg)}") | ||
| raise Exception(f"Call {name} error: {repr(msg)}") | ||
| if msg is None: | ||
| # Preserve original semantics when no message is received | ||
| raise Exception("Worker Update Weights Timeouted after 600s") | ||
| raise Exception(f"No message received from {name}") | ||
| response: ControlResponse = msg.payload | ||
| if response.request_id != request_id: | ||
| self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}") | ||
| self.llm_logger.info(f"ignore old control response from {name}: {response}") | ||
| continue | ||
| if response.error_code != 200: | ||
| self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}") | ||
| raise Exception(f"Call Worker error: {response.error_message}") | ||
| self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}") | ||
| self.llm_logger.info(f"Call {name} failed: {response.error_message}") | ||
| raise Exception(f"Call {name} error: {response.error_message}") | ||
| self.llm_logger.info(f"Call {name} succeed: {response.result}") | ||
| responses.append(response.result) | ||
| return responses | ||
|
|
||
| def _call_worker(self, control_request: ControlRequest, timeout: int): | ||
| request_id = control_request.request_id | ||
| self.engine_worker_queue.put_tasks(([control_request], 1)) | ||
| # Use a single asyncio.run() to concurrently wait for all worker responses. | ||
| return asyncio.run(self._wait_all_control_responses(request_id, timeout)) | ||
| return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"])) | ||
|
|
||
| def _send_error_response(self, request_id, error_msg, error_code: int = 500): | ||
| self.llm_logger.error( | ||
|
|
@@ -1487,6 +1629,8 @@ def _zmq_send_generated_tokens(self): | |
| """ | ||
| while self.running: | ||
| try: | ||
| with self._pause_cond: | ||
| self._pause_cond.wait_for(lambda: not self.is_paused) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. output 为什么需要感知 pause 呢?如果没有新请求,output 就不会有新 token 要处理;对于正在处理的请求,应该要等到 preempted 调度后自行结束,否则可能会有中间 token 阻塞在输出队列里。
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里忘删了。 |
||
| results = self.scheduler.get_results() | ||
| if len(results) == 0: | ||
| time.sleep(0.005) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -570,6 +570,16 @@ def __init__( | |
| self.method = method | ||
| self.args = args or {} | ||
|
|
||
| self._post_init() | ||
|
|
||
| def _post_init(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这个感觉放到接口自己维护比较好,request 逻辑尽量保持通用。
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 好的。 |
||
| if self.method in ("sleep", "wakeup"): | ||
| tags = self.args.get("tags") | ||
| if tags is None or tags == "": | ||
| self.args["tags"] = "weight,kv_cache" | ||
| elif not isinstance(tags, str): | ||
| raise TypeError("tags must be a string") | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, d: dict): | ||
| """Create ControlRequest instance from dictionary.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
跟worker、cache transfer的控制通信感觉最好区分开,因为不一定所有的控制信号都会发给他俩