Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,9 @@ def __init__(self, args: dict):
else:
self.metrics_port = self.api_server_port

def __str__(self):
    """Return the configuration serialized as a JSON object of its attributes."""
    # `self.__dict__` already maps attribute names to values; json.dumps
    # serializes it directly without an intermediate copy.
    return json.dumps(self.__dict__)


class CommitConfig:
"""
Expand Down Expand Up @@ -1858,6 +1861,9 @@ def to_json_string(self):
"""
return json.dumps({key: value for key, value in self.__dict__.items()})

def __str__(self):
    """Return the JSON representation of this config (delegates to `to_json_string`)."""
    return self.to_json_string()


class FDConfig:
"""
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ def _validate_split_kv_size(value: int) -> int:
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
),
# Suspend rollout routing replay
"FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))),
# train-infer consistency, used in RL
# Whether to align RoPE and moe gate precision with training
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
Expand Down
109 changes: 66 additions & 43 deletions fastdeploy/model_executor/layers/moe/routing_indices_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _save_routing_kernel(
TOP_K,
NUM_HIDDEN_LAYERS,
MAX_MODEL_LEN,
MAX_NUM_SEQS,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
Expand All @@ -63,45 +64,37 @@ def _save_routing_kernel(
token_mask = token_offsets < TOKEN_NUM

k_offsets = tl.arange(0, BLOCK_SIZE_K)

k_mask = k_offsets < TOP_K

topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
# [BLOCK_SIZE_M, BLOCK_SIZE_K]

load_mask = token_mask[:, None] & k_mask[None, :]
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask)

batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
pad_mask = token_mask & (batch_ids != -1)
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1)

batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1)

batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS)
pad_mask = token_mask & (batch_ids != -1) & batch_mask

start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0)
token_relative_index = token_offsets - start_offsets

# [BLOCK_SIZE_M]
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0)
token_seq_pos = len_decoder + token_relative_index

STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K
STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K
STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
STRIDE_BUF_LAYER = TOP_K

# [BLOCK_SIZE_M, BLOCK_SIZE_K]
output_ptrs = (
ROUTING_REPLAY_TABLE_PTR
+ batch_ids[:, None] * STRIDE_BUF_SEQ
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
+ LAYER_IDX * STRIDE_BUF_LAYER
+ tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ
+ tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN
+ tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER
+ k_offsets[None, :]
)

pos_mask = token_seq_pos < MAX_MODEL_LEN
pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN)
pos_mask = pos_mask & pad_mask

# [BLOCK_SIZE_M, BLOCK_SIZE_K]
pos_mask = pos_mask[:, None] & k_mask[None, :]

final_mask = load_mask & pos_mask
Expand All @@ -120,20 +113,21 @@ def save_routing_to_buffer(
ep_size: int,
tp_group: dist.communication.group.Group,
):
token_num_per_rank = topk_ids.shape[0]
if token_num_per_rank == 0:
return
if tp_size > 1 and ep_size > 1:
token_num_per_rank = topk_ids.shape[0]
if token_num_per_rank == 0:
return
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]

token_num, top_k = topk_ids.shape
max_num_seqs, max_model_len, num_hidden_layers, _ = routing_replay_table.shape
assert token_num > 0
assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3])
assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num)
assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs)
assert (
topk_ids.shape[1] == routing_replay_table.shape[3]
), f"({topk_ids.shape[1]}, {routing_replay_table.shape[3]})"
assert batch_id_per_token.shape[0] == token_num, f"({batch_id_per_token.shape[0]}, {token_num})"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 原有的 assert seq_lens_decoder.shape[0] == max_num_seqs 断言被移除了。虽然 kernel 内部已新增 batch_mask 做越界防护,但 Python 层的断言能在 kernel 启动前更早地捕获 tensor shape 不匹配问题,并给出更清晰的错误信息。建议保留此断言:

assert seq_lens_decoder.shape[0] >= max_num_seqs, f"({seq_lens_decoder.shape[0]}, {max_num_seqs})"


BLOCK_SIZE_M = 128
BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k
Expand All @@ -150,6 +144,7 @@ def save_routing_to_buffer(
TOP_K=top_k,
NUM_HIDDEN_LAYERS=num_hidden_layers,
MAX_MODEL_LEN=max_model_len,
MAX_NUM_SEQS=max_num_seqs,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
Expand All @@ -166,6 +161,7 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
self.only_last_turn = fd_config.routing_replay_config.only_last_turn
self.use_fused_put = fd_config.routing_replay_config.use_fused_put
logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}")

This comment was marked as outdated.

This comment was marked as outdated.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 日志中 "Congfig" 拼写错误,应为 "Config"。

Suggested change
logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}")
logger.info(f"[R3] Rollout Routing Replay Config: {fd_config.routing_replay_config}")

if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
self.moe_top_k = fd_config.model_config.num_experts_per_tok
else:
Expand All @@ -186,6 +182,17 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
)
self._store_wrapper.start_store_warpper()

# Suspend Routing Replay
self.suspend_routing_replay = False
self.update_suspend_routing_replay()

def update_suspend_routing_replay(self):
"""Allow RL to use R3 in different training rounds"""
# TODO(gongshaotian): Delete this func
suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 FD_SUSPEND_ROUTING_REPLAY 已在 envs.py 中注册,但这里通过 os.environ.get 直接读取。建议统一使用 envs 模块的方式读取,保持项目内环境变量访问模式的一致性:

from fastdeploy import envs
self.suspend_routing_replay = envs.FD_SUSPEND_ROUTING_REPLAY

如果此处需要在运行时动态感知环境变量变更(不走缓存),可以忽略此建议,但建议在注释中说明原因。

self.suspend_routing_replay = bool(int(suspend_routing_replay))
logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}")

def _init_routing_cache(self, dtype: str, total_block_num: int):
"""Initialize the device buffer and host buffer."""

Expand Down Expand Up @@ -341,6 +348,11 @@ def _put_request_to_store(
seq_lens_decoder,
):
if self.tp_rank == 0:
# TODO(gongshaotian): Delete the suspend func
if self.suspend_routing_replay:
logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store")
return

before_put_request_time = time.perf_counter()

# Collect the routing of finished request
Expand All @@ -351,16 +363,19 @@ def _put_request_to_store(

if self.use_fused_put:
self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id)
# Only store the routing of last turn
if self.only_last_turn:
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id)

else:
for layer_id in range(self.num_moe_layers):
layer_buffer = batch_buffer[layer_id]
self._store_wrapper.submit_put_task(
routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id
)

# Only store the routing of last turn
if self.only_last_turn:
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id)
# Only store the routing of last turn
if self.only_last_turn:
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id)

logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}")

Expand Down Expand Up @@ -481,7 +496,6 @@ def _monitor_queue_load(self):
if qsize > self.queue_max_size * 0.8:
logger.warning(
f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. "
f"Dropped tasks so far: {self._dropped_tasks}. "
"Consider increasing max_workers or queue_max_size."
)
logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}")
Expand Down Expand Up @@ -523,22 +537,26 @@ def submit_clear_store_task(self) -> None:
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s")

def submit_clear_prefix_batch_task(self, rollout_id) -> None:
def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None:
"""Submit clear prefix batch task"""
if not self._sotre_process_running:
raise RuntimeError("Store not started.")
prefix_batch = self.get_needed_clear_ids(rollout_id)

if prefix_batch is None:
prefix_batch_id = self.get_needed_clear_ids(rollout_id)
if prefix_batch_id is None:
return
start_time = time.perf_counter()
task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None}
if layer_idx is not None:
rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}"
else:
rdma_rollout_key = prefix_batch_id

task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None}
try:
self._task_queue.put_nowait(task)
except Exception:
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
logger.info(
f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s"
f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s"

This comment was marked as outdated.

)

def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]:
Expand Down Expand Up @@ -615,7 +633,7 @@ def run(self):
self._task_queue.task_done()
raise RuntimeError(f"Error during processing task. {e}")

logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.")
logger.info("RoutingReplay Consumer Process Shutdown.")

def process_put_task(self, store_task: StoreTask) -> None:
try:
Expand Down Expand Up @@ -838,13 +856,18 @@ def __init__(self, routing_replay_config) -> None:
async def put(self, routing_key: str, routing_indices: np.ndarray) -> None:
    """Put the routing indices into store.

    A 3-D array (the fused multi-layer layout) is serialized to raw bytes
    with `tobytes()` before the put; any other shape is passed through to
    the client unchanged. Logs the wall-clock cost of the operation.

    NOTE(review): `tobytes()` drops dtype/shape metadata — presumably the
    reader reconstructs them from config; confirm on the consumer side.
    """
    time_before_put = time.perf_counter()
    if len(routing_indices.shape) == 3:
        # NOTE(gongshaotian) Fused put with bytes data
        routing_bytes = routing_indices.tobytes()
        result = await self.p2p_client.put(routing_key, routing_bytes)
    else:
        result = await self.p2p_client.put(routing_key, routing_indices)
    logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s")
    return result

async def clear_prefix_batch(self, routing_prefix_key: str):
time_before_clear = time.perf_counter()
result = await self.p2p_client.delete_prefix_batch([routing_prefix_key])
result = await self.p2p_client.delete_batch([routing_prefix_key])
Comment thread
gongshaotian marked this conversation as resolved.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 此处将 delete_prefix_batch 改为了 delete_batch,语义从「按前缀批量删除」变为「按精确 key 删除」。

结合 submit_clear_prefix_batch_task 中 non-fused 路径现在传入 layer_idx 构建精确 key(如 {rollout_id}_{layer_idx}),逐层删除可以工作。

但方法名 clear_prefix_batch 和参数名 routing_prefix_key 仍暗示前缀语义,与实际的精确删除行为不一致,容易造成后续维护者误解。建议同步更新方法名和参数名,例如改为 clear_batch / routing_key

logger.info(
f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s"
)
Expand Down
6 changes: 5 additions & 1 deletion fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2737,9 +2737,13 @@ def update_parameters(self, pid):
# Recapture CUDAGraph
if self.use_cudagraph:
self.capture_model()
# Rollout Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
# TODO(gongshaotian): Delete suspend func
self.routing_replay_manager.update_suspend_routing_replay()

# Send signal
self.dynamic_weight_manager.finalize_update(pid)

self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

def update_weights(self, version: str = None, verify_checksum: bool = False):
Expand Down
Loading