Skip to content

Commit d0b74cd

Browse files
authored
Cuda graph padding (#174)
* Support padding for cuda graph * Fix select graph size and default max cuda graph size * Set default to 512 * Fix issues * Support moonlight
1 parent 368b136 commit d0b74cd

6 files changed

Lines changed: 134 additions & 15 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ python benchmarks/evaluate_MMLU_pro.py --model $MODEL
153153

154154
## Supported Models
155155

156-
- Kimi Series: K2-Base, K2-Instruct
156+
- Kimi Series: Moonlight, K2-Base, K2-Instruct
157157
- DeepSeek Series: DeepSeek R1, DeepSeek V3, DeepSeek V2
158158
- Qwen Series: Qwen3 VL, Qwen3, Qwen2.5 VL, Qwen2.5, Qwen2
159159
- Llama Series: Llama3.2, Llama3.1, Llama3, Llama2 and deepseek-coder

gllm/entrypoints/api_server.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,13 @@ async def run_server(args):
190190
parser.add_argument(
191191
"--max-cuda-graph-bs",
192192
type=int,
193-
help="Maximum batch size for cuda graph",
194-
default=32,
193+
help=(
194+
"Maximum batch size for CUDA graph capture. "
195+
"Larger values allow more decode batches to benefit from CUDA graphs "
196+
"but increase startup time and GPU memory usage during graph capture. "
197+
"Default: 512."
198+
),
199+
default=512,
195200
)
196201
# Parallelism
197202
parser.add_argument("--pp", type=int, help="Number of pipeline stages", default=1)

gllm/input_data.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,56 @@ def _cal_mla_metadata(self, seqs: List[Sequence]):
316316
dtype=torch.int32,
317317
)
318318

319+
def pad_for_cuda_graph(self, padded_size: int):
320+
"""Pad input buffers to padded_size using dummy values.
321+
322+
This enables CUDA graph replay for a fixed batch size (a power-of-two
323+
bucket) even when the actual number of decode tokens is smaller.
324+
325+
The dummy tokens write their KV entries to memory_manager.dummy_page,
326+
which is permanently reserved and never used by real sequences.
327+
328+
Returns:
329+
num_real_tokens (int): the actual (unpadded) token count, so that
330+
the caller can slice output_hidden_states[:num_real_tokens] when
331+
computing logits after graph replay.
332+
"""
333+
assert self.use_buffer, "pad_for_cuda_graph requires use_buffer=True"
334+
num_real_tokens = self.tokens_cpu.shape[0]
335+
if num_real_tokens >= padded_size:
336+
return num_real_tokens
337+
338+
dummy_page = self.memory_manager.dummy_page
339+
dummy_slot = dummy_page * self.page_size # slot index within dummy page
340+
341+
num_pad = padded_size - num_real_tokens
342+
343+
# tokens: pad with 0
344+
self.tokens[num_real_tokens:padded_size].zero_()
345+
# positions: pad with 0
346+
self.positions[num_real_tokens:padded_size].zero_()
347+
# mrope_positions: pad with 0
348+
self.mrope_positions[:, num_real_tokens:padded_size].zero_()
349+
# slot_mapping: pad with dummy slot so writes go to the reserved page
350+
self.slot_mapping[num_real_tokens:padded_size].fill_(dummy_slot)
351+
# seq_lens: pad with 1 (avoid division-by-zero in attention kernels)
352+
self.seq_lens[len(self.seqs):len(self.seqs) + num_pad].fill_(1)
353+
# block_table: pad rows with dummy_page
354+
self.block_table[len(self.seqs):len(self.seqs) + num_pad].fill_(dummy_page)
355+
# query_start_loc: continue the cumulative sum — each dummy token counts
356+
# as 1 query token, so the padded entries are last_loc+1, last_loc+2, ...
357+
last_loc = self.query_start_loc[len(self.seqs)]
358+
self.query_start_loc[len(self.seqs) + 1:len(self.seqs) + num_pad + 1].copy_(
359+
last_loc + torch.arange(1, num_pad + 1, dtype=torch.int32)
360+
)
361+
362+
if self.use_mla:
363+
# Pad decode_seq_lens for the dummy sequences so that MLA decode
364+
# kernels see a valid (non-zero) sequence length for every row.
365+
self.decode_seq_lens[len(self.seqs):len(self.seqs) + num_pad].fill_(1)
366+
367+
return num_real_tokens
368+
319369
def _set_mla_metadata(self):
320370
if self.num_prefills > 0:
321371
self.prefill_query_start_loc[

gllm/memory_manager.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(
8686
self.vocab_size = vocab_size
8787
self.use_mla = use_mla
8888

89-
def init(self, segment_cls=Segment):
89+
def init(self, segment_cls=Segment, reserve_dummy_page: bool = False):
9090
free_mem_size, _ = torch.cuda.mem_get_info()
9191
num_max_pages = free_mem_size // self.get_sizeof_KV_per_page()
9292
num_pages = int(num_max_pages * self.gpu_memory_util)
@@ -113,6 +113,12 @@ def init(self, segment_cls=Segment):
113113
self.use_mla,
114114
)
115115

116+
# Reserve a dedicated dummy page for CUDA graph padding only when
117+
# CUDA graphs are enabled. This page is never returned to normal use,
118+
# so real sequences will never overwrite it, and padding dummy tokens
119+
# can safely write here.
120+
self.dummy_page: int = self.segment.allocate() if reserve_dummy_page else None
121+
116122
self.kv_cache_dtype = "auto"
117123
self.k_scale = torch.tensor(1.0, dtype=torch.float32)
118124
self.v_scale = self.k_scale
@@ -182,8 +188,8 @@ class PrefixMemoryManager(MemoryManager):
182188
def __init__(self, *args, **kwargs):
183189
super().__init__(*args, **kwargs)
184190

185-
def init(self):
186-
super().init(segment_cls=PrefixSegment)
191+
def init(self, reserve_dummy_page: bool = False):
192+
super().init(segment_cls=PrefixSegment, reserve_dummy_page=reserve_dummy_page)
187193

188194
# for prefill stage
189195
self.num_allocated_pages = 0

gllm/model_runner.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def __init__(
127127
self.disable_cuda_graph = disable_cuda_graph
128128
self.max_cuda_graph_bs = max_cuda_graph_bs
129129
self.size_to_graph: Dict[int, torch.cuda.CUDAGraph] = dict()
130-
self.capture_sizes = list(range(self.max_cuda_graph_bs, 0, -1))
130+
# Use power-of-two bucket sizes to reduce the number of captured graphs.
131+
# At runtime the actual batch is padded up to the nearest bucket.
132+
self.capture_sizes = self._build_capture_sizes(self.max_cuda_graph_bs)
131133

132134
# max length
133135
self.model_max_length = self.resolve_model_max_length(model_max_length)
@@ -142,6 +144,26 @@ def resolve_model_max_length(self, model_max_length):
142144
logger.info(f"Model max length: {model_max_length}")
143145
return model_max_length
144146

147+
@staticmethod
148+
def _build_capture_sizes(max_bs: int):
149+
"""Return power-of-two bucket sizes up to max_bs, in descending order.
150+
151+
For example, max_bs=20 → [20, 16, 8, 4, 2, 1].
152+
We always include 1 as a floor bucket.
153+
"""
154+
if max_bs <= 0:
155+
return []
156+
sizes = []
157+
s = 1
158+
while s <= max_bs:
159+
sizes.append(s)
160+
s *= 2
161+
# If max_bs is not itself a power of two, add it as the top bucket so
162+
# that batches of exactly max_bs can still use CUDA graph.
163+
if sizes[-1] != max_bs:
164+
sizes.append(max_bs)
165+
return list(reversed(sizes))
166+
145167
def init(self, mp_load_progress=None):
146168
self.model = self.model_loader.load_model(mp_load_progress)
147169
memory_manager_cls = (
@@ -171,8 +193,9 @@ def init(self, mp_load_progress=None):
171193
self.output_residual = torch.zeros((self.max_num_batched_tokens, self.hidden_size))
172194
# Profile run
173195
self.profile_run()
174-
# Init KV cache at last
175-
self.memory_manager.init()
196+
# Init KV cache at last; only reserve the dummy page when CUDA graphs
197+
# are actually enabled so we don't waste memory otherwise.
198+
self.memory_manager.init(reserve_dummy_page=not self.disable_cuda_graph)
176199

177200
if not self.disable_cuda_graph:
178201
self.capture_graph()
@@ -373,7 +396,7 @@ def profile_run(self):
373396
def capture_graph(self):
374397
iterator = self.capture_sizes
375398
if get_local_rank() == 0:
376-
# logger.info(f"Capturing cuda graph for sizes {self.capture_sizes}")
399+
logger.info(f"Capturing CUDA graphs for bucket sizes: {list(reversed(self.capture_sizes))}")
377400
iterator = tqdm(self.capture_sizes, desc="Capturing CUDA Graphs", ncols=100)
378401
memory_pool = torch.cuda.graph_pool_handle()
379402
for size in iterator:
@@ -426,9 +449,22 @@ def check_decode_batch(self):
426449
@torch.inference_mode()
427450
def step_once(self):
428451
num_cal_tokens = self.input_data.tokens_cpu.shape[0]
429-
# Only decode batch use cuda graph
430-
if self.check_decode_batch() and num_cal_tokens in self.size_to_graph:
431-
self.size_to_graph[num_cal_tokens].replay()
452+
# Only pure decode batches use CUDA graph.
453+
if self.check_decode_batch():
454+
# Find the smallest captured bucket >= actual batch size.
455+
padded_size = None
456+
for bucket in self.capture_sizes:
457+
if bucket >= num_cal_tokens:
458+
padded_size = bucket
459+
if padded_size is not None and padded_size in self.size_to_graph:
460+
# Pad input buffers to the bucket size with dummy values, then
461+
# replay the pre-captured graph.
462+
num_real_tokens = self.input_data.pad_for_cuda_graph(padded_size)
463+
self.size_to_graph[padded_size].replay()
464+
# After replay, use only the real-token slice for logits.
465+
num_cal_tokens = num_real_tokens
466+
else:
467+
self.forward()
432468
else:
433469
self.forward()
434470
if is_last_pp_rank():

gllm/models/deepseek_v2.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,18 @@ def __init__(self, layer_id: int, config):
157157
self.kv_lora_rank = config.kv_lora_rank
158158
self.rope_theta = getattr(config, "rope_theta", 10000)
159159
self.max_poistion_embeddings = getattr(config, "max_position_embeddings", 8192)
160-
self.rope_scaling = getattr(config, "rope_scaling", None)
160+
rope_scaling = getattr(config, "rope_scaling", None)
161+
if rope_scaling is None:
162+
self.rope_scaling = {
163+
"factor": 1.0,
164+
"original_max_position_embeddings": self.max_poistion_embeddings,
165+
}
166+
else:
167+
self.rope_scaling = dict(rope_scaling)
168+
self.rope_scaling.setdefault("factor", 1.0)
169+
self.rope_scaling.setdefault(
170+
"original_max_position_embeddings", self.max_poistion_embeddings
171+
)
161172

162173
if self.q_lora_rank is not None:
163174
self.q_a_proj = ReplicatedLinear(
@@ -286,7 +297,18 @@ def __init__(self, layer_id: int, config):
286297
self.kv_lora_rank = config.kv_lora_rank
287298
self.rope_theta = getattr(config, "rope_theta", 10000)
288299
self.max_poistion_embeddings = getattr(config, "max_position_embeddings", 8192)
289-
self.rope_scaling = getattr(config, "rope_scaling", None)
300+
rope_scaling = getattr(config, "rope_scaling", None)
301+
if rope_scaling is None:
302+
self.rope_scaling = {
303+
"factor": 1.0,
304+
"original_max_position_embeddings": self.max_poistion_embeddings,
305+
}
306+
else:
307+
self.rope_scaling = dict(rope_scaling)
308+
self.rope_scaling.setdefault("factor", 1.0)
309+
self.rope_scaling.setdefault(
310+
"original_max_position_embeddings", self.max_poistion_embeddings
311+
)
290312
self.layer_id = layer_id
291313

292314
if self.q_lora_rank is not None:

0 commit comments

Comments
 (0)