-
Notifications
You must be signed in to change notification settings - Fork 740
[RL][Cherry-Pick] Fix the out-of-bounds issue caused by int32 in the R3 kernel #7496
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/2.6
Are you sure you want to change the base?
Changes from all commits
f52bd00
140a7fa
77ee863
c3a74cc
ef107f9
9d351c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -54,6 +54,7 @@ def _save_routing_kernel( | |||||
| TOP_K, | ||||||
| NUM_HIDDEN_LAYERS, | ||||||
| MAX_MODEL_LEN, | ||||||
| MAX_NUM_SEQS, | ||||||
| BLOCK_SIZE_M: tl.constexpr, | ||||||
| BLOCK_SIZE_K: tl.constexpr, | ||||||
| ): | ||||||
|
|
@@ -63,45 +64,37 @@ def _save_routing_kernel( | |||||
| token_mask = token_offsets < TOKEN_NUM | ||||||
|
|
||||||
| k_offsets = tl.arange(0, BLOCK_SIZE_K) | ||||||
|
|
||||||
| k_mask = k_offsets < TOP_K | ||||||
|
|
||||||
| topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] | ||||||
| # [BLOCK_SIZE_M, BLOCK_SIZE_K] | ||||||
|
|
||||||
| load_mask = token_mask[:, None] & k_mask[None, :] | ||||||
| topk_vals = tl.load(topk_ids_ptrs, mask=load_mask) | ||||||
|
|
||||||
| batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) | ||||||
| pad_mask = token_mask & (batch_ids != -1) | ||||||
| # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] | ||||||
| # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] | ||||||
| # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] | ||||||
| # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] | ||||||
| start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask) | ||||||
| topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1) | ||||||
|
|
||||||
| batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1) | ||||||
|
|
||||||
| batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS) | ||||||
| pad_mask = token_mask & (batch_ids != -1) & batch_mask | ||||||
|
|
||||||
| start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0) | ||||||
| token_relative_index = token_offsets - start_offsets | ||||||
|
|
||||||
| # [BLOCK_SIZE_M] | ||||||
| len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask) | ||||||
| len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0) | ||||||
| token_seq_pos = len_decoder + token_relative_index | ||||||
|
|
||||||
| STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K | ||||||
| STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K | ||||||
| STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64) | ||||||
| STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64) | ||||||
| STRIDE_BUF_LAYER = TOP_K | ||||||
|
|
||||||
| # [BLOCK_SIZE_M, BLOCK_SIZE_K] | ||||||
| output_ptrs = ( | ||||||
| ROUTING_REPLAY_TABLE_PTR | ||||||
| + batch_ids[:, None] * STRIDE_BUF_SEQ | ||||||
| + token_seq_pos[:, None] * STRIDE_BUF_TOKEN | ||||||
| + LAYER_IDX * STRIDE_BUF_LAYER | ||||||
| + tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ | ||||||
| + tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN | ||||||
| + tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER | ||||||
| + k_offsets[None, :] | ||||||
| ) | ||||||
|
|
||||||
| pos_mask = token_seq_pos < MAX_MODEL_LEN | ||||||
| pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN) | ||||||
| pos_mask = pos_mask & pad_mask | ||||||
|
|
||||||
| # [BLOCK_SIZE_M, BLOCK_SIZE_K] | ||||||
| pos_mask = pos_mask[:, None] & k_mask[None, :] | ||||||
|
|
||||||
| final_mask = load_mask & pos_mask | ||||||
|
|
@@ -120,20 +113,21 @@ def save_routing_to_buffer( | |||||
| ep_size: int, | ||||||
| tp_group: dist.communication.group.Group, | ||||||
| ): | ||||||
| token_num_per_rank = topk_ids.shape[0] | ||||||
| if token_num_per_rank == 0: | ||||||
| return | ||||||
| if tp_size > 1 and ep_size > 1: | ||||||
| token_num_per_rank = topk_ids.shape[0] | ||||||
| if token_num_per_rank == 0: | ||||||
| return | ||||||
| topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) | ||||||
| paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) | ||||||
| topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] | ||||||
|
|
||||||
| token_num, top_k = topk_ids.shape | ||||||
| max_num_seqs, max_model_len, num_hidden_layers, _ = routing_replay_table.shape | ||||||
| assert token_num > 0 | ||||||
| assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3]) | ||||||
| assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num) | ||||||
| assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs) | ||||||
| assert ( | ||||||
| topk_ids.shape[1] == routing_replay_table.shape[3] | ||||||
| ), f"({topk_ids.shape[1]}, {routing_replay_table.shape[3]})" | ||||||
| assert batch_id_per_token.shape[0] == token_num, f"({batch_id_per_token.shape[0]}, {token_num})" | ||||||
|
|
||||||
| BLOCK_SIZE_M = 128 | ||||||
| BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k | ||||||
|
|
@@ -150,6 +144,7 @@ def save_routing_to_buffer( | |||||
| TOP_K=top_k, | ||||||
| NUM_HIDDEN_LAYERS=num_hidden_layers, | ||||||
| MAX_MODEL_LEN=max_model_len, | ||||||
| MAX_NUM_SEQS=max_num_seqs, | ||||||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||||||
| BLOCK_SIZE_K=BLOCK_SIZE_K, | ||||||
| ) | ||||||
|
|
@@ -166,6 +161,7 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num): | |||||
| self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index | ||||||
| self.only_last_turn = fd_config.routing_replay_config.only_last_turn | ||||||
| self.use_fused_put = fd_config.routing_replay_config.use_fused_put | ||||||
| logger.info(f"[R3] Rollout Routing Replay Config: {fd_config.routing_replay_config}") |
This comment was marked as outdated.
Sorry, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议 日志中 "Congfig" 拼写错误，应为 "Config"。
Suggested change
|
||||||
| if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM": | ||||||
| self.moe_top_k = fd_config.model_config.num_experts_per_tok | ||||||
| else: | ||||||
|
|
@@ -186,6 +182,17 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num): | |||||
| ) | ||||||
| self._store_wrapper.start_store_warpper() | ||||||
|
|
||||||
| # Suspend Routing Replay | ||||||
| self.suspend_routing_replay = False | ||||||
| self.update_suspend_routing_replay() | ||||||
|
|
||||||
| def update_suspend_routing_replay(self): | ||||||
| """Allow RL to use R3 in different training rounds""" | ||||||
| # TODO(gongshaotian): Delete this func | ||||||
| suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0") | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 建议 from fastdeploy import envs
self.suspend_routing_replay = envs.FD_SUSPEND_ROUTING_REPLAY。如果此处需要在运行时动态感知环境变量变更（不走缓存），可以忽略此建议，但建议在注释中说明原因。 |
||||||
| self.suspend_routing_replay = bool(int(suspend_routing_replay)) | ||||||
| logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}") | ||||||
|
|
||||||
| def _init_routing_cache(self, dtype: str, total_block_num: int): | ||||||
| """Initialize the device buffer and host buffer.""" | ||||||
|
|
||||||
|
|
@@ -341,6 +348,11 @@ def _put_request_to_store( | |||||
| seq_lens_decoder, | ||||||
| ): | ||||||
| if self.tp_rank == 0: | ||||||
| # TODO(gongshaotian): Delete the suspend func | ||||||
| if self.suspend_routing_replay: | ||||||
| logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store") | ||||||
| return | ||||||
|
|
||||||
| before_put_request_time = time.perf_counter() | ||||||
|
|
||||||
| # Collect the routing of finished request | ||||||
|
|
@@ -351,16 +363,19 @@ def _put_request_to_store( | |||||
|
|
||||||
| if self.use_fused_put: | ||||||
| self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id) | ||||||
| # Only store the routing of last turn | ||||||
| if self.only_last_turn: | ||||||
| self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) | ||||||
|
|
||||||
| else: | ||||||
| for layer_id in range(self.num_moe_layers): | ||||||
| layer_buffer = batch_buffer[layer_id] | ||||||
| self._store_wrapper.submit_put_task( | ||||||
| routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id | ||||||
| ) | ||||||
|
|
||||||
| # Only store the routing of last turn | ||||||
| if self.only_last_turn: | ||||||
| self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) | ||||||
| # Only store the routing of last turn | ||||||
| if self.only_last_turn: | ||||||
| self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) | ||||||
|
|
||||||
| logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}") | ||||||
|
|
||||||
|
|
@@ -481,7 +496,6 @@ def _monitor_queue_load(self): | |||||
| if qsize > self.queue_max_size * 0.8: | ||||||
| logger.warning( | ||||||
| f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " | ||||||
| f"Dropped tasks so far: {self._dropped_tasks}. " | ||||||
| "Consider increasing max_workers or queue_max_size." | ||||||
| ) | ||||||
| logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") | ||||||
|
|
@@ -523,22 +537,26 @@ def submit_clear_store_task(self) -> None: | |||||
| raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") | ||||||
| logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") | ||||||
|
|
||||||
| def submit_clear_prefix_batch_task(self, rollout_id) -> None: | ||||||
| def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: | ||||||
| """Submit clear prefix batch task""" | ||||||
| if not self._sotre_process_running: | ||||||
| raise RuntimeError("Store not started.") | ||||||
| prefix_batch = self.get_needed_clear_ids(rollout_id) | ||||||
|
|
||||||
| if prefix_batch is None: | ||||||
| prefix_batch_id = self.get_needed_clear_ids(rollout_id) | ||||||
| if prefix_batch_id is None: | ||||||
| return | ||||||
| start_time = time.perf_counter() | ||||||
| task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None} | ||||||
| if layer_idx is not None: | ||||||
| rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" | ||||||
| else: | ||||||
| rdma_rollout_key = prefix_batch_id | ||||||
|
|
||||||
| task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} | ||||||
| try: | ||||||
| self._task_queue.put_nowait(task) | ||||||
| except Exception: | ||||||
| raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") | ||||||
| logger.info( | ||||||
| f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s" | ||||||
| f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" | ||||||
This comment was marked as outdated.
Sorry, something went wrong. |
||||||
| ) | ||||||
|
|
||||||
| def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]: | ||||||
|
|
@@ -615,7 +633,7 @@ def run(self): | |||||
| self._task_queue.task_done() | ||||||
| raise RuntimeError(f"Error during processing task. {e}") | ||||||
|
|
||||||
| logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.") | ||||||
| logger.info("RoutingReplay Consumer Process Shutdown.") | ||||||
|
|
||||||
| def process_put_task(self, store_task: StoreTask) -> None: | ||||||
| try: | ||||||
|
|
@@ -838,13 +856,18 @@ def __init__(self, routing_replay_config) -> None: | |||||
| async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: | ||||||
| """Put the routing indices into store""" | ||||||
| time_before_put = time.perf_counter() | ||||||
| result = await self.p2p_client.put(routing_key, routing_indices) | ||||||
| if len(routing_indices.shape) == 3: | ||||||
| # NOTE(gongshaotian) Fused put with bytes data | ||||||
| routing_bytes = routing_indices.tobytes() | ||||||
| result = await self.p2p_client.put(routing_key, routing_bytes) | ||||||
| else: | ||||||
| result = await self.p2p_client.put(routing_key, routing_indices) | ||||||
| logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") | ||||||
| return result | ||||||
|
|
||||||
| async def clear_prefix_batch(self, routing_prefix_key: str): | ||||||
| time_before_clear = time.perf_counter() | ||||||
| result = await self.p2p_client.delete_prefix_batch([routing_prefix_key]) | ||||||
| result = await self.p2p_client.delete_batch([routing_prefix_key]) | ||||||
|
gongshaotian marked this conversation as resolved.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❓ 疑问 此处将 结合 但方法名 |
||||||
| logger.info( | ||||||
| f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" | ||||||
| ) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 建议 原有的
assert seq_lens_decoder.shape[0] == max_num_seqs断言被移除了。虽然 kernel 内部已新增batch_mask做越界防护,但 Python 层的断言能在 kernel 启动前更早地捕获 tensor shape 不匹配问题,并给出更清晰的错误信息。建议保留此断言: