-
Notifications
You must be signed in to change notification settings - Fork 309
Expand file tree
/
Copy pathgraph.py
More file actions
468 lines (404 loc) · 17.5 KB
/
graph.py
File metadata and controls
468 lines (404 loc) · 17.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
# CUDA Graph implementation modified from vLLM:
# https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
from dataclasses import dataclass
from functools import lru_cache
from statistics import median
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import numpy as np
import torch
from loguru import logger
from torch import nn
from tqdm import tqdm
from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
from lorax_server.adapters.lora import BatchLoraWeights, RankSegments
from lorax_server.adapters.types import LORA
from lorax_server.models.cache_manager import BLOCK_SIZE, get_cache_manager
from lorax_server.utils.sgmv import BGMV_MAX_RANK
if TYPE_CHECKING:
from lorax_server.models.flash_causal_lm import FlashCausalLMBatch
from lorax_server.models.model import Model
# TODO(travis): make this configurable by model / user
MAX_BATCH_SIZE = 256
MAX_RANK = BGMV_MAX_RANK

# Fill values for the unused (padding) entries of the slot and segment-index buffers.
SLOT_PAD_VALUE = -1
SEGMENT_PAD_VALUE = -1

# Cached batch sizes used in vLLM. This and the helper function `get_cached_batch_size` below
# must be kept in sync.
BATCH_SIZE_INCREMENT = 32
_SMALL_CACHED_BATCH_SIZES = [1, 2, 4, 8, 16]
_LARGE_CACHED_BATCH_SIZES = [BATCH_SIZE_INCREMENT * multiple for multiple in range(1, 9)]
CACHED_BATCH_SIZES = _SMALL_CACHED_BATCH_SIZES + _LARGE_CACHED_BATCH_SIZES

# Include 0 to ensure we can use cuda graphs without adapters
# TODO(travis): use padding to allow for more ranks without increasing memory usage
CACHED_MAX_RANKS = [0, 8, 16, 32, 64]
_allowed_ranks = set(CACHED_MAX_RANKS)
assert all(r <= BGMV_MAX_RANK for r in _allowed_ranks), f"Invalid ranks: {_allowed_ranks}"

# Upper bound on the memory samples taken by `GraphCache.get_estimated_cache_memory`.
MAX_SAMPLES = 3
def get_cached_batch_size(batch_size: int) -> int:
    """Map a live batch size onto the padded batch size a CUDA graph was traced for.

    Batches of exactly 1 or 2 map to themselves. Any other size up to 16 (including 0)
    rounds up to the first of 4, 8 or 16 that fits, and anything larger rounds up to the
    next multiple of ``BATCH_SIZE_INCREMENT``. These buckets mirror ``CACHED_BATCH_SIZES``
    and must be kept in sync with it.
    """
    if batch_size in (1, 2):
        return int(batch_size)

    small_bucket = next((bucket for bucket in (4, 8, 16) if batch_size <= bucket), None)
    if small_bucket is not None:
        return small_bucket

    step = BATCH_SIZE_INCREMENT
    # Ceiling-round to a multiple of `step`.
    return (batch_size + step - 1) // step * step
def pad_and_fill(dest: torch.Tensor, src: torch.Tensor, pad_value: int):
    """Copy ``src`` into the leading rows of ``dest`` and set the remaining rows to ``pad_value``.

    ``dest`` is updated in place. The copy is issued with ``non_blocking=True``.
    """
    num_rows = src.shape[0]
    head = dest[:num_rows]
    tail = dest[num_rows:]
    head.copy_(src, non_blocking=True)
    tail.fill_(pad_value)
def next_pow_2(x: int) -> int:
    """Return the smallest power of two that is greater than or equal to ``x`` (``x`` > 0)."""
    assert x > 0
    exponent = (x - 1).bit_length()
    return 1 << exponent
@dataclass
class GraphState:
    """Input tensors that a captured CUDA graph reads from.

    ``get_max_graph_state`` builds one instance sized for ``MAX_BATCH_SIZE``;
    ``GraphWrapper.trace`` builds per-graph instances whose tensors are ``[:batch_size]``
    slices of it. ``GraphWrapper.forward`` overwrites these tensors in place before
    replaying the graph.
    """

    input_ids: torch.Tensor
    position_ids: torch.Tensor
    block_tables: torch.Tensor
    slots: torch.Tensor
    input_lengths: torch.Tensor
    adapter_data: AdapterBatchData
    # Adapter layers whose LoRA buffers were included when the graph was traced.
    traced_adapter_layer_names: Set[str]
@lru_cache(maxsize=1)
def get_max_graph_state(
    device: torch.device,
    adapter_layers: Tuple[str],
    max_total_tokens: int,
    sliding_window_blocks: Optional[int] = None,
) -> GraphState:
    """Allocate the input buffers sized for ``MAX_BATCH_SIZE`` sequences.

    ``GraphWrapper.trace`` slices each graph's inputs out of these buffers. Thanks to
    ``lru_cache(maxsize=1)``, calling again with the same arguments returns the same
    tensors instead of allocating new ones. All arguments must therefore be hashable,
    which is why ``adapter_layers`` is a tuple.

    Args:
        device: Device on which every buffer is allocated.
        adapter_layers: Layer names to allocate LoRA index/pointer buffers for.
        max_total_tokens: Maximum sequence length; determines the block-table width.
        sliding_window_blocks: Optional sliding-window size, in blocks.

    Returns:
        A ``GraphState`` whose tensors all have leading dimension ``MAX_BATCH_SIZE``.
    """
    max_num_blocks = (max_total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE
    if sliding_window_blocks is not None:
        # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
        # NOTE(review): this takes the *max* of the two widths, which contradicts the comment
        # above (a cap would be `min`). A wider buffer is harmless for the block-table copy in
        # GraphWrapper.forward; confirm whether `min` was intended before changing it.
        max_num_blocks = max(max_num_blocks, sliding_window_blocks)
    block_tables_arr = np.zeros((MAX_BATCH_SIZE, max_num_blocks), dtype=np.int32)
    block_tables = torch.from_numpy(block_tables_arr).to(device=device)
    input_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device)
    position_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)
    # Slots start out as padding; forward later overwrites the live prefix.
    slots = torch.full((MAX_BATCH_SIZE,), SLOT_PAD_VALUE, dtype=torch.int64, device=device)
    # NOTE(review): initialised to 1 rather than 0 — presumably so padded rows look like
    # non-empty sequences during capture; confirm.
    input_lengths = torch.ones((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)

    # One RankSegments per layer at MAX_RANK; every traced rank slices views out of these.
    adapter_weight_data = {}
    for layer_name in adapter_layers:
        adapter_weight_data[layer_name] = BatchLoraWeights(
            lora_a={},
            lora_b={},
            adapter_index_configs={},
            rank_data={
                MAX_RANK: RankSegments(
                    rank=MAX_RANK,
                    adapter_index_map=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                    lora_a_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                    lora_b_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                    indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                    segment_starts=None,
                    segment_ends=None,
                    tmp_shrink=None,
                    tmp_expand=None,
                ),
            },
            use_sgmv=False,  # bgmv during decode
        )

    return GraphState(
        input_ids=input_ids,
        position_ids=position_ids,
        block_tables=block_tables,
        slots=slots,
        input_lengths=input_lengths,
        adapter_data=AdapterBatchData(
            meta=AdapterBatchMetadata(
                adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                adapter_set=set(),
                adapter_segments=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                segment_indices=[],
            ),
            data=adapter_weight_data,
            prefill=False,
        ),
        traced_adapter_layer_names=set(adapter_layers),
    )
class GraphWrapper:
    """A CUDA graph captured for one ``(batch size, max LoRA rank)`` decode configuration.

    The graph reads from the tensors in ``input_state`` and writes to ``output_states``.
    Both are kept on the wrapper and reused on every replay: ``forward`` refreshes the
    inputs in place, replays the graph, and returns views of the outputs.
    """

    def __init__(
        self,
        graph: torch.cuda.CUDAGraph,
        memory_pool: Tuple[int, int],
        input_state: GraphState,
        output_states: Tuple[torch.Tensor, Optional[torch.Tensor]],
        model: nn.Module,
    ):
        self.graph = graph
        # Pool the graph was captured into; callers pass it to later captures to share memory.
        self.memory_pool = memory_pool
        self.input_state = input_state
        self.output_states = output_states
        self.model = model

    @staticmethod
    def trace(
        model: nn.Module,
        device: torch.device,
        adapter_layers: Tuple[str],
        batch_size: int,
        max_rank: int,
        memory_pool: Tuple[int, int],
        max_total_tokens: int,
        sliding_window_blocks: Optional[int] = None,
        traced_adapter_layer_names: Optional[Set[str]] = None,
    ) -> "GraphWrapper":
        """Capture a decode-step CUDA graph of ``model`` for ``batch_size`` sequences.

        Args:
            model: Model whose ``forward`` is captured.
            device: CUDA device to capture on.
            adapter_layers: All adapter layer names; forwarded to ``get_max_graph_state``.
            batch_size: Number of sequences the graph is captured for.
            max_rank: LoRA rank baked into the graph; 0 captures without any LoRA segments.
            memory_pool: Memory pool to capture into (``None`` lets PyTorch pick a private pool).
            max_total_tokens: Maximum sequence length; passed as ``max_s`` during capture.
            sliding_window_blocks: Optional sliding-window size, in blocks.
            traced_adapter_layer_names: Adapter layers to include in the graph. Layers not
                listed get no LoRA buffers. Defaults to none.

        Returns:
            A ``GraphWrapper`` holding the graph, its pool and its input/output tensors.
        """
        max_input_state = get_max_graph_state(device, adapter_layers, max_total_tokens, sliding_window_blocks)

        # WARNING: for some reason the SGMV kernel can hang if we don't use a power of 2
        # as the segment size. This is a workaround until we can figure out why.
        # Specifically, this issue has been observed with batch_size=96.
        # I suspect it is related to synchronization and the chunk size (256) used in the kernel.
        # But we need to investigate further.
        segment_size = next_pow_2(batch_size)

        traced_adapter_layer_names = traced_adapter_layer_names or set()

        adapter_weight_data = {}
        for layer_name, weight_data in max_input_state.adapter_data.data.items():
            if layer_name not in traced_adapter_layer_names:
                continue

            rank_data = {}
            if max_rank > 0:
                # Views into the shared MAX_RANK buffers, trimmed to this graph's size.
                max_rank_data = weight_data.rank_data[MAX_RANK]
                rank_data[max_rank] = RankSegments(
                    rank=max_rank,
                    adapter_index_map=max_rank_data.adapter_index_map[:batch_size],
                    lora_a_ptr=max_rank_data.lora_a_ptr[:segment_size],
                    lora_b_ptr=max_rank_data.lora_b_ptr[:segment_size],
                    indices=max_rank_data.indices[:batch_size],
                    segment_starts=None,
                    segment_ends=None,
                    tmp_shrink=None,
                    tmp_expand=None,
                )

            adapter_weight_data[layer_name] = {
                LORA: BatchLoraWeights(
                    lora_a={},
                    lora_b={},
                    adapter_index_configs={},
                    rank_data=rank_data,
                    use_sgmv=False,  # bgmv during decode
                )
            }

        # Every tensor here is a view of the shared max-size buffers, so forward() can refill them.
        input_state = GraphState(
            input_ids=max_input_state.input_ids[:batch_size],
            position_ids=max_input_state.position_ids[:batch_size],
            block_tables=max_input_state.block_tables[:batch_size],
            slots=max_input_state.slots[:batch_size],
            input_lengths=max_input_state.input_lengths[:batch_size],
            adapter_data=AdapterBatchData(
                meta=AdapterBatchMetadata(
                    adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size],
                    adapter_set=max_input_state.adapter_data.meta.adapter_set,
                    adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size],
                    segment_indices=max_input_state.adapter_data.meta.segment_indices,
                ),
                data=adapter_weight_data,
                prefill=False,
            ),
            traced_adapter_layer_names=traced_adapter_layer_names,
        )

        torch.cuda.synchronize(device)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=memory_pool):  # noqa: SIM117
            output_states = model.forward(
                input_ids=input_state.input_ids,
                position_ids=input_state.position_ids,
                cu_seqlen_prefill=None,
                kv_cache=get_cache_manager().kv_cache,
                block_tables=input_state.block_tables,
                slots=input_state.slots,
                input_lengths=input_state.input_lengths,
                max_s=max_total_tokens,
                adapter_data=input_state.adapter_data,
                lm_head_indices=None,
            )

        torch.cuda.synchronize(device)

        return GraphWrapper(graph, graph.pool(), input_state, output_states, model)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        adapter_data: AdapterBatchData,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        """Copy a decode batch into the graph's input buffers, replay it, and return its outputs.

        ``cu_seqlen_prefill``, ``kv_cache``, ``max_s`` and ``lm_head_indices`` are accepted to
        mirror the model's ``forward`` signature but are not read. The replay uses the
        ``kv_cache`` and ``max_s`` captured in ``trace``.

        Returns:
            The captured output tensors (``None`` entries preserved), each trimmed to the first
            ``input_ids.shape[0]`` rows so padding rows are dropped.
        """
        pad_and_fill(self.input_state.input_ids, input_ids, 0)
        pad_and_fill(self.input_state.position_ids, position_ids, 0)
        pad_and_fill(self.input_state.slots, slots, SLOT_PAD_VALUE)
        pad_and_fill(self.input_state.input_lengths, input_lengths, 0)

        # The incoming block table may be narrower/shorter than the buffer; zero the rest.
        self.input_state.block_tables.zero_()
        self.input_state.block_tables[: block_tables.shape[0], : block_tables.shape[1]] = block_tables

        for layer_name, weight_data in self.input_state.adapter_data.data.items():
            # TODO(travis): generalize this to support other adapter types
            lora_data = weight_data[LORA]
            if layer_name not in adapter_data.data:
                # zero out all the segments
                for rank_data in lora_data.rank_data.values():
                    rank_data.indices.fill_(SEGMENT_PAD_VALUE)
                continue

            source_data = adapter_data.data[layer_name][LORA]
            dest_data = lora_data
            for rank, source_rank_data in source_data.rank_data.items():
                dest_rank_data = dest_data.rank_data[rank]

                pad_and_fill(dest_rank_data.lora_a_ptr, source_rank_data.lora_a_ptr, 0)
                pad_and_fill(dest_rank_data.lora_b_ptr, source_rank_data.lora_b_ptr, 0)
                pad_and_fill(dest_rank_data.indices, source_rank_data.indices, SEGMENT_PAD_VALUE)

        self.graph.replay()

        return tuple(state[: input_ids.shape[0]] if state is not None else None for state in self.output_states)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
class GraphCache:
    """Cache of decode-step CUDA graphs keyed by ``(padded batch size, max LoRA rank)``.

    ``warmup`` traces a graph for every ``CACHED_BATCH_SIZES`` x ``CACHED_MAX_RANKS`` pair.
    ``forward`` selects the matching graph and traces a new one when that graph is missing
    or does not cover the adapter layers the batch uses.
    """

    def __init__(
        self,
        model: "Model",
        device: torch.device,
        adapter_layers: List[str],
        default_traced_adapter_layers: List[str],
        max_total_tokens: int,
        sliding_window_blocks: Optional[int] = None,
    ):
        self.model = model
        self.device = device
        # Tuple so it can be passed to the lru_cache'd get_max_graph_state (args must be hashable).
        self.adapter_layers = tuple(adapter_layers)
        self.default_traced_adapter_layers = set(default_traced_adapter_layers)
        # Pool for graphs traced on demand in forward(); warmup() chains its own pool instead.
        self.memory_pool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
        self.cache: Dict[Tuple[int, int], GraphWrapper] = {}
        self.max_total_tokens = max_total_tokens
        self.sliding_window_blocks = sliding_window_blocks

    def can_use_graph(
        self,
        batch: "FlashCausalLMBatch",
        adapter_data: AdapterBatchData,
    ) -> bool:
        """Return whether ``batch`` can be served by a cached CUDA graph.

        Requires CUDA, at most ``MAX_BATCH_SIZE`` sequences, ``max_seqlen`` within
        ``max_total_tokens``, at most one distinct LoRA rank drawn from ``CACHED_MAX_RANKS``,
        and LoRA-only adapters.
        """
        ranks = adapter_data.ranks()
        nranks = len(ranks)
        max_rank = max(ranks) if len(ranks) > 0 else 0

        batch_size = batch.input_ids.shape[0]
        max_s = batch.max_seqlen

        # Only allow LoRA adapters for now
        adapter_keys = set(adapter_data.adapter_keys())

        # TODO(travis): allow using CUDA graphs with multi-rank batches
        return (
            torch.cuda.is_available()
            and batch_size <= MAX_BATCH_SIZE
            and max_s <= self.max_total_tokens
            and max_rank <= MAX_RANK
            and nranks <= 1
            and max_rank in _allowed_ranks
            and all(k == LORA for k in adapter_keys)
        )

    def get_estimated_cache_memory(self) -> int:
        """Estimate the GPU memory, in bytes, needed to hold every cached graph.

        Traces graphs at the largest cached batch size for a few ranks, measures the drop
        in free device memory per trace, and multiplies the median by the total number of
        graphs. The traced graphs are discarded when this method returns.
        """
        # Store off graphs into temporary cache to discard after estimation
        tmp_cache = {}
        pool = None

        # Use the largest batch size to overestimate memory overhead
        batch_size = CACHED_BATCH_SIZES[-1]

        samples = []
        for i, max_rank in enumerate(reversed(CACHED_MAX_RANKS)):
            torch.cuda.synchronize(self.device)
            free_memory_before, _ = torch.cuda.mem_get_info(self.device)

            key = (batch_size, max_rank)
            graph = GraphWrapper.trace(
                self.model,
                self.device,
                self.adapter_layers,
                batch_size,
                max_rank,
                pool,
                self.max_total_tokens,
                self.sliding_window_blocks,
                self.adapter_layers,  # estimate memory assuming all adapters are traced
            )
            tmp_cache[key] = graph
            pool = graph.memory_pool

            torch.cuda.synchronize(self.device)
            free_memory_after, _ = torch.cuda.mem_get_info(self.device)

            # Measure memory difference after tracing the graph,
            # discard first sample to account for global state initialization
            delta_memory = free_memory_before - free_memory_after
            if i > 0:
                samples.append(delta_memory)

            # Tracing all graphs can take a while, so limit the number of samples
            if len(samples) == MAX_SAMPLES:
                break

        # Estimate memory usage for all batch sizes and ranks
        ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS)
        per_graph_memory = median(samples)
        return ngraphs * per_graph_memory

    def warmup(self):
        """Trace and cache a graph for every (batch size, rank) pair, largest batch size first.

        All warmup graphs are chained onto one memory pool: each trace reuses the pool of the
        previous graph.
        """
        ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS)
        pool = None
        with tqdm(total=ngraphs, desc="Trace CUDA graphs") as pbar:
            for batch_size in reversed(CACHED_BATCH_SIZES):
                pbar.set_postfix({"batch_size": batch_size})
                for max_rank in reversed(CACHED_MAX_RANKS):
                    key = (batch_size, max_rank)
                    graph = GraphWrapper.trace(
                        self.model,
                        self.device,
                        self.adapter_layers,
                        batch_size,
                        max_rank,
                        pool,
                        self.max_total_tokens,
                        self.sliding_window_blocks,
                        self.default_traced_adapter_layers,
                    )
                    self.cache[key] = graph
                    pool = graph.memory_pool
                    pbar.update(1)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        adapter_data: AdapterBatchData,
        lm_head_indices: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        """Run one decode step through the cached graph for this batch's size and LoRA rank.

        If no graph is cached for ``(padded batch size, max rank)``, or the cached one was
        traced without some adapter layer this batch uses, a new graph is traced for exactly
        the batch's adapter layers and replaces the cache entry.

        Returns:
            The graph's output tensors trimmed to the real (unpadded) batch size.
        """
        batch_size = get_cached_batch_size(input_ids.shape[0])
        max_rank = adapter_data.max_rank
        key = (batch_size, max_rank)
        graph = self.cache.get(key)

        required_layer_names = adapter_data.layer_names()
        if graph is None or not graph.input_state.traced_adapter_layer_names.issuperset(required_layer_names):
            # Bug fix: the original dereferenced `graph.input_state` here even when `graph` was
            # None (no cached entry for this key), raising AttributeError instead of tracing.
            current_layer_names = graph.input_state.traced_adapter_layer_names if graph is not None else set()
            logger.info(
                "Retrace graph with new adapter layers: {} -> {}",
                current_layer_names,
                required_layer_names,
            )
            graph = GraphWrapper.trace(
                self.model,
                self.device,
                self.adapter_layers,
                batch_size,
                max_rank,
                self.memory_pool,
                self.max_total_tokens,
                self.sliding_window_blocks,
                required_layer_names,
            )
            self.cache[key] = graph

        output_states = graph.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            adapter_data=adapter_data,
            lm_head_indices=lm_head_indices,
        )

        return output_states

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)