@@ -127,7 +127,9 @@ def __init__(
127127 self .disable_cuda_graph = disable_cuda_graph
128128 self .max_cuda_graph_bs = max_cuda_graph_bs
129129 self .size_to_graph : Dict [int , torch .cuda .CUDAGraph ] = dict ()
130- self .capture_sizes = list (range (self .max_cuda_graph_bs , 0 , - 1 ))
130+ # Use power-of-two bucket sizes to reduce the number of captured graphs.
131+ # At runtime the actual batch is padded up to the nearest bucket.
132+ self .capture_sizes = self ._build_capture_sizes (self .max_cuda_graph_bs )
131133
132134 # max length
133135 self .model_max_length = self .resolve_model_max_length (model_max_length )
@@ -142,6 +144,26 @@ def resolve_model_max_length(self, model_max_length):
142144 logger .info (f"Model max length: { model_max_length } " )
143145 return model_max_length
144146
@staticmethod
def _build_capture_sizes(max_bs: int):
    """Compute the batch-size buckets for which CUDA graphs are captured.

    The buckets are the powers of two that do not exceed ``max_bs`` (1 is
    always among them), plus ``max_bs`` itself when it is not a power of
    two, so that a batch of exactly ``max_bs`` still has a bucket.

    Args:
        max_bs: Largest batch size that should be covered by a bucket.

    Returns:
        The bucket sizes sorted from largest to smallest, e.g.
        ``max_bs=20`` gives ``[20, 16, 8, 4, 2, 1]``. An empty list is
        returned when ``max_bs`` is zero or negative.
    """
    if max_bs <= 0:
        return []

    # Collect 1, 2, 4, ... in ascending order while they still fit.
    ascending = []
    bucket = 1
    while bucket <= max_bs:
        ascending.append(bucket)
        bucket += bucket

    # A non-power-of-two limit becomes its own top bucket.
    if ascending[-1] != max_bs:
        ascending.append(max_bs)

    # Callers expect the largest bucket first.
    return ascending[::-1]
166+
145167 def init (self , mp_load_progress = None ):
146168 self .model = self .model_loader .load_model (mp_load_progress )
147169 memory_manager_cls = (
@@ -171,8 +193,9 @@ def init(self, mp_load_progress=None):
171193 self .output_residual = torch .zeros ((self .max_num_batched_tokens , self .hidden_size ))
172194 # Profile run
173195 self .profile_run ()
174- # Init KV cache at last
175- self .memory_manager .init ()
196+ # Init KV cache at last; only reserve the dummy page when CUDA graphs
197+ # are actually enabled so we don't waste memory otherwise.
198+ self .memory_manager .init (reserve_dummy_page = not self .disable_cuda_graph )
176199
177200 if not self .disable_cuda_graph :
178201 self .capture_graph ()
@@ -373,7 +396,7 @@ def profile_run(self):
373396 def capture_graph (self ):
374397 iterator = self .capture_sizes
375398 if get_local_rank () == 0 :
376- # logger.info(f"Capturing cuda graph for sizes { self.capture_sizes}")
399+ logger .info (f"Capturing CUDA graphs for bucket sizes: { list ( reversed ( self .capture_sizes )) } " )
377400 iterator = tqdm (self .capture_sizes , desc = "Capturing CUDA Graphs" , ncols = 100 )
378401 memory_pool = torch .cuda .graph_pool_handle ()
379402 for size in iterator :
@@ -426,9 +449,22 @@ def check_decode_batch(self):
426449 @torch .inference_mode ()
427450 def step_once (self ):
428451 num_cal_tokens = self .input_data .tokens_cpu .shape [0 ]
429- # Only decode batch use cuda graph
430- if self .check_decode_batch () and num_cal_tokens in self .size_to_graph :
431- self .size_to_graph [num_cal_tokens ].replay ()
452+ # Only pure decode batches use CUDA graph.
453+ if self .check_decode_batch ():
454+ # Find the smallest captured bucket >= actual batch size.
455+ padded_size = None
456+ for bucket in self .capture_sizes :
457+ if bucket >= num_cal_tokens :
458+ padded_size = bucket
459+ if padded_size is not None and padded_size in self .size_to_graph :
460+ # Pad input buffers to the bucket size with dummy values, then
461+ # replay the pre-captured graph.
462+ num_real_tokens = self .input_data .pad_for_cuda_graph (padded_size )
463+ self .size_to_graph [padded_size ].replay ()
464+ # After replay, use only the real-token slice for logits.
465+ num_cal_tokens = num_real_tokens
466+ else :
467+ self .forward ()
432468 else :
433469 self .forward ()
434470 if is_last_pp_rank ():
0 commit comments