 import torch
 import torch.distributed as dist
 import vllm.forward_context as vfc
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   batchsize_logging_interval,
                                   create_forward_context,

 @dataclass
 class RBLNDPMetadata(DPMetadata):
-    max_pads_across_dp: int = 0
+    max_pads_across_dp: torch.Tensor | None = None

     @staticmethod
     def num_tokens_across_dp(num_tokens: int, dp_size: int,
@@ -53,26 +53,66 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
         return num_tokens_tensor

+    @staticmethod
+    def num_tokens_across_dp_with_max_decode_tokens(
+            num_tokens: int, dp_size: int, dp_rank: int,
+            is_prefill: bool) -> tuple[torch.Tensor, int | None]:
+        # Bit 16 flags a prefilling rank; the low 16 bits carry the raw
+        # token count, so the count must stay below pad_flag.
+        pad_flag = 1 << 16
+        pad_mask = pad_flag - 1
+        assert num_tokens < pad_flag, \
+            "num_tokens should be less than pad_flag"
+
+        if is_prefill:
+            num_tokens |= pad_flag
+
+        tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp(
+            num_tokens, dp_size, dp_rank)
+        max_across_dp = torch.max(tokens_across_dp_cpu).item()
+
+        # If any rank is prefilling, strip the flag bits to recover the
+        # raw counts and report no common max-decode-token value.
+        if is_prefill or max_across_dp > pad_flag:
+            mask_tensor = torch.tensor([pad_mask] * dp_size,
+                                       device="cpu",
+                                       dtype=torch.int32)
+            num_tokens_across_dp_cpu = tokens_across_dp_cpu & mask_tensor
+            max_across_dp = None
+        else:
+            num_tokens_across_dp_cpu = tokens_across_dp_cpu
+
+        return num_tokens_across_dp_cpu, max_across_dp
+
     @staticmethod
     def make(
-        vllm_config: VllmConfig,
+        parallel_config: ParallelConfig,
         num_tokens: int,
+        num_tokens_across_dp: torch.Tensor | None = None,
+        num_padded_tokens: int | None = None,
     ) -> "RBLNDPMetadata":
-        parallel_config = vllm_config.parallel_config
         dp_size = parallel_config.data_parallel_size
-        dp_rank = parallel_config.data_parallel_rank
-
-        scheduler_config = vllm_config.scheduler_config
-        max_pad = scheduler_config.max_num_batched_tokens
-        batchsize = num_tokens

-        num_tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp(
-            batchsize, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
+        if dp_size > 1:
+            assert num_tokens_across_dp is not None, \
+                "num_tokens_across_dp must be provided in the DP case"
+            assert num_padded_tokens is not None, \
+                "num_padded_tokens must be provided in the DP case"
+            num_tokens_across_dp_cpu = num_tokens_across_dp
+            max_pad = num_padded_tokens
+
+            max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
+            max_pads_across_dp = torch.empty(max_pad, device="cpu")
+        else:
+            assert num_tokens_across_dp is None, \
+                "num_tokens_across_dp must not be passed in the non-DP case"
+            assert num_padded_tokens is None, \
+                "num_padded_tokens must not be passed in the non-DP case"
+            num_tokens_across_dp_cpu = torch.tensor([num_tokens],
+                                                    device="cpu",
+                                                    dtype=torch.int32)
+            max_tokens_across_dp_cpu = num_tokens
+            max_pads_across_dp = None

         return RBLNDPMetadata(max_tokens_across_dp_cpu,
                               num_tokens_across_dp_cpu,
-                              max_pads_across_dp=max_pad)
+                              max_pads_across_dp=max_pads_across_dp)


 @contextmanager
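
A self-contained sketch of the bit-16 encoding used by num_tokens_across_dp_with_max_decode_tokens above (the values are made up and no process group is needed): every rank contributes its token count, prefill ranks additionally set bit 16, so a single gathered tensor both reveals whether any rank is prefilling and, after masking, yields the raw per-rank counts.

import torch

pad_flag = 1 << 16       # bit 16 marks a prefilling rank
pad_mask = pad_flag - 1  # the low 16 bits hold the raw token count

# Hypothetical gathered counts from 4 DP ranks; rank 2 is mid-prefill.
tokens_across_dp = torch.tensor([8, 3, 512 | pad_flag, 1], dtype=torch.int32)

any_prefill = torch.max(tokens_across_dp).item() > pad_flag  # True
raw_counts = tokens_across_dp & pad_mask                     # [8, 3, 512, 1]
print(any_prefill, raw_counts.tolist())
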
@@ -85,6 +125,7 @@ def _set_forward_context(
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
+    num_padded_tokens: int | None = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -99,7 +140,10 @@ def _set_forward_context(
     use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
     if (enable_dp or use_moe_tokens_mask) and (attn_metadata is not None
                                                or num_tokens is not None):
-        dp_metadata = RBLNDPMetadata.make(vllm_config, num_tokens or 0)
+        dp_metadata = RBLNDPMetadata.make(vllm_config.parallel_config,
+                                          num_tokens or 0,
+                                          num_tokens_across_dp,
+                                          num_padded_tokens)

     forward_context = create_forward_context(
         attn_metadata,
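
For context, a hedged sketch of how the revised make() contract can be exercised (the config object below is an illustrative stand-in, not part of this PR): in the non-DP case neither optional argument may be passed, while in the DP case the caller forwards both results of num_tokens_across_dp_with_max_decode_tokens.

from types import SimpleNamespace

# Stand-in for vllm's ParallelConfig; only the one attribute that
# make() reads is provided.
non_dp_config = SimpleNamespace(data_parallel_size=1)

# Non-DP path: make() builds the single-rank tensor locally and leaves
# max_pads_across_dp as None.
meta = RBLNDPMetadata.make(non_dp_config, num_tokens=8)
assert meta.max_pads_across_dp is None

# DP path (requires an initialized DP process group), sketched only:
# tokens_across_dp, max_decode = \
#     RBLNDPMetadata.num_tokens_across_dp_with_max_decode_tokens(
#         num_tokens=8, dp_size=4, dp_rank=0, is_prefill=False)
# meta = RBLNDPMetadata.make(dp_config, 8, tokens_across_dp, max_decode)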