Skip to content

Commit 0a3a347

Browse files
authored
[Feat] Support offload cuda graph (#354)
1 parent 3f87d04 commit 0a3a347

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

slime/ray/rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def init_rollout_engines(args, pg, all_rollout_engines):
272272
"SGLANG_JIT_DEEPGEMM_PRECOMPILE": "false",
273273
"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
274274
"SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
275+
"SGLANG_MEMORY_SAVER_CUDA_GRAPH": "true",
275276
}
276277
},
277278
).remote(args, rank=i)

train.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import ray
22
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
33

4+
try:
5+
from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
6+
except ImportError:
7+
GPU_MEMORY_TYPE_CUDA_GRAPH = None
8+
49
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
510
from slime.utils.arguments import parse_args
611
from slime.utils.wandb_utils import init_wandb_primary
@@ -25,6 +30,8 @@ def train(args):
2530
actor_model.update_weights()
2631

2732
if args.offload:
33+
if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
34+
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
2835
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
2936

3037
# train loop.
@@ -70,6 +77,8 @@ def train(args):
7077
actor_model.update_weights()
7178

7279
if args.offload:
80+
if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
81+
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
7382
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
7483

7584
if args.eval_interval is not None and (

0 commit comments

Comments
 (0)