Skip to content

Commit f58bcbf

Browse files
[Hardware] AMD - Replace vllm CuMemAllocator dependency with torch_memory_saver (#444)
* remove vllm CuMemAllocator dependency * conduct pre-commit
1 parent 2710445 commit f58bcbf

File tree

8 files changed

+104
-115
lines changed

8 files changed

+104
-115
lines changed

docker/Dockerfile.rocm

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#### Use the base image
22

33
# The Docker image built with this Dockerfile:
4-
# Supports at least up to slime commit ID: d4a7741 (Sep 7, 2025) - supported by amd_patch/sglv0.5.0rc0
4+
# Supports at least up to slime commit ID: 2710445 (Oct 9, 2025) - supported by amd_patch/sglv0.5.0rc0
5+
# Still need to update amd_patch
56

67
# You can find the latest pre-built Docker image from here: https://hub.docker.com/r/rlsys/slime/tags
78
# Current latest docker img: `rlsys/slime:slime_ubuntu22.04_rocm6.3.4-patch-numa-patch_sglang0.4.9_megatron-patch_ray2.47.1_apex_torch-memory-saver0.0.8-patch-vim` — manually add the patch to mitigate the checkpoint loading issue. (vim /workspace/Megatron-LM-amd_version/megatron/training/checkpointing.py, lines 1449 ~ 1457 — comment out the if-statement because of a mismatch in the number of dist checkpoints)
@@ -348,6 +349,15 @@ RUN pip install google-generativeai
348349
########################################
349350

350351

352+
########################################
353+
########Additional packages#############
354+
########################################
355+
RUN pip install tensorboard
356+
########################################
357+
########################################
358+
########################################
359+
360+
351361
WORKDIR /workspace/
352362

353363
CMD ["/usr/bin/bash"]

scripts/run-llama3.2-3B-Instruct-amd.sh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ pkill -9 python
1515
set -euxo pipefail
1616

1717
### AMD Support ###
18-
SLIME_DIR="/home/yushensu/projects/slime" # Need to change to your own path
19-
export SLIME_DIR=$SLIME_DIR
18+
SLIME_DIR="${SLIME_DIR:-/home/yushensu/projects/slime}" # Default path if not set in environment
19+
export SLIME_DIR
2020

21-
MODEL_DIR="/home/yushensu/projects/model" # Need to change to your own path
22-
export MODEL_DIR=$MODEL_DIR
21+
MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment
22+
export MODEL_DIR
2323

24-
DATA_DIR="/home/yushensu/projects/data" # Need to change to your own path
25-
export DATA_DIR=$DATA_DIR
24+
DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment
25+
export DATA_DIR
2626

2727
# For AMD GPU
2828
export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1
@@ -148,7 +148,7 @@ ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disab
148148
# Build the runtime environment JSON with proper variable substitution
149149
RUNTIME_ENV_JSON="{
150150
\"env_vars\": {
151-
\"PYTHONPATH\": \"/workspace/Megatron-LM-amd_version/\",
151+
\"PYTHONPATH\": \"/workspace/Megatron-LM/\",
152152
\"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
153153
}
154154
}"

scripts/run-qwen3-4B-amd.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ set -euxo pipefail
1515

1616

1717
### AMD Support ###
18-
SLIME_DIR="/home/yushensu/projects/slime" # Need to change to your own path
19-
export SLIME_DIR=$SLIME_DIR
18+
SLIME_DIR="${SLIME_DIR:-/home/yushensu/projects/slime}" # Default path if not set in environment
19+
export SLIME_DIR
2020

21-
MODEL_DIR="/home/yushensu/projects/model" # Need to change to your own path
22-
export MODEL_DIR=$MODEL_DIR
21+
MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment
22+
export MODEL_DIR
2323

24-
DATA_DIR="/home/yushensu/projects/data" # Need to change to your own path
25-
export DATA_DIR=$DATA_DIR
24+
DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment
25+
export DATA_DIR
2626

2727
# For AMD GPU
2828
export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1

scripts/run-qwen3-8B-amd.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ set -euxo pipefail
1919

2020

2121
### AMD Support ###
22-
SLIME_DIR="/home/yushensu/projects/slime" # Need to change to your own path
23-
export SLIME_DIR=$SLIME_DIR
22+
SLIME_DIR="${SLIME_DIR:-/home/yushensu/projects/slime}" # Default path if not set in environment
23+
export SLIME_DIR
2424

25-
MODEL_DIR="/home/yushensu/projects/model" # Need to change to your own path
26-
export MODEL_DIR=$MODEL_DIR
25+
MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment
26+
export MODEL_DIR
2727

28-
DATA_DIR="/home/yushensu/projects/data" # Need to change to your own path
29-
export DATA_DIR=$DATA_DIR
28+
DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment
29+
export DATA_DIR
3030

3131
# For AMD GPU
3232
export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1

slime/backends/megatron_utils/actor.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,9 @@
99
import ray
1010
import torch
1111
import torch.distributed as dist
12-
from ray.actor import ActorHandle
13-
14-
if torch.version.hip:
15-
from vllm.device_allocator.cumem import CuMemAllocator
16-
else:
17-
from torch_memory_saver import torch_memory_saver
18-
1912
from megatron.core import mpu
13+
from ray.actor import ActorHandle
14+
from torch_memory_saver import torch_memory_saver
2015
from transformers import AutoConfig, AutoTokenizer
2116

2217
from slime.ray.train_actor import TrainRayActor
@@ -164,11 +159,7 @@ def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
164159
if hasattr(mpu, "destroy_process_groups"):
165160
mpu.destroy_process_groups()
166161

167-
if not torch.version.hip:
168-
torch_memory_saver.pause()
169-
else:
170-
allocator = CuMemAllocator.get_instance()
171-
allocator.sleep(offload_tags=tags)
162+
torch_memory_saver.pause()
172163

173164
print_memory("after offload model")
174165

@@ -188,11 +179,7 @@ def wake_up(self, tags: Union[str, Tuple[str, ...]]) -> None:
188179
if isinstance(tags, str):
189180
tags = (tags,)
190181

191-
if not torch.version.hip:
192-
torch_memory_saver.resume()
193-
else:
194-
allocator = CuMemAllocator.get_instance()
195-
allocator.wake_up(tags)
182+
torch_memory_saver.resume()
196183

197184
clear_memory()
198185
if hasattr(mpu, "reload_process_groups"):
@@ -423,7 +410,7 @@ def update_weights(self) -> None:
423410
self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
424411
dist.barrier(group=get_gloo_group())
425412

426-
with torch_memory_saver.disable() if self.args.offload and not torch.version.hip else nullcontext():
413+
with torch_memory_saver.disable() if self.args.offload else nullcontext():
427414
print_memory("before update_weights")
428415
self.weight_updater.update_weights()
429416
print_memory("after update_weights")

slime/backends/megatron_utils/model.py

Lines changed: 57 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import gc
33
import math
44
import os
5-
from contextlib import nullcontext
65
from functools import partial
76

87
import torch
@@ -26,9 +25,6 @@
2625
from .loss import loss_function
2726
from .model_provider import get_model_provider_func
2827

29-
if torch.version.hip:
30-
from vllm.device_allocator.cumem import CuMemAllocator
31-
3228

3329
def get_optimizer_param_scheduler(args, optimizer):
3430
"""Build the learning rate scheduler."""
@@ -80,71 +76,64 @@ def setup_model_and_optimizer(
8076

8177
model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder, wrap_with_ddp=False)
8278

83-
with (
84-
CuMemAllocator.get_instance().use_memory_pool(tag="model")
85-
if args.offload and torch.version.hip
86-
else nullcontext()
87-
):
88-
config = get_model_config(model[0])
89-
90-
kwargs = {}
91-
for f in dataclasses.fields(DistributedDataParallelConfig):
92-
if hasattr(args, f.name):
93-
kwargs[f.name] = getattr(args, f.name)
94-
kwargs["grad_reduce_in_fp32"] = args.accumulate_allreduce_grads_in_fp32
95-
kwargs["check_for_nan_in_grad"] = args.check_for_nan_in_loss_and_grad
96-
kwargs["check_for_large_grads"] = args.check_for_large_grads
97-
kwargs["bucket_size"] = args.ddp_bucket_size
98-
kwargs["pad_buckets_for_high_nccl_busbw"] = args.ddp_pad_buckets_for_high_nccl_busbw
99-
kwargs["average_in_collective"] = args.ddp_average_in_collective
100-
ddp_config = DistributedDataParallelConfig(**kwargs)
101-
102-
# In the custom FSDP and DDP use path, we need to initialize the bucket size.
103-
# If bucket_size is not provided as an input, use sane default.
104-
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
105-
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
106-
# latency-bound.
107-
if ddp_config.bucket_size is None:
108-
ddp_config.bucket_size = max(
109-
40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
110-
)
111-
# Set bucket_size to infinity if overlap_grad_reduce is False.
112-
if not ddp_config.overlap_grad_reduce:
113-
ddp_config.bucket_size = None
114-
115-
model = [
116-
DDP(
117-
config=config,
118-
ddp_config=ddp_config,
119-
module=model_chunk,
120-
# Turn off bucketing for model_chunk 2 onwards, since communication for these
121-
# model chunks is overlapped with compute anyway.
122-
disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step,
123-
)
124-
for (model_chunk_idx, model_chunk) in enumerate(model)
125-
]
126-
127-
# Optimizer
128-
kwargs = {}
129-
for f in dataclasses.fields(OptimizerConfig):
130-
if hasattr(args, f.name):
131-
kwargs[f.name] = getattr(args, f.name)
132-
config = OptimizerConfig(**kwargs)
133-
config.timers = None
134-
135-
optimizer = get_megatron_optimizer(
136-
config,
137-
model,
138-
no_wd_decay_cond,
139-
scale_lr_cond,
140-
lr_mult,
141-
use_gloo_process_groups=args.enable_gloo_process_groups,
79+
config = get_model_config(model[0])
80+
81+
kwargs = {}
82+
for f in dataclasses.fields(DistributedDataParallelConfig):
83+
if hasattr(args, f.name):
84+
kwargs[f.name] = getattr(args, f.name)
85+
kwargs["grad_reduce_in_fp32"] = args.accumulate_allreduce_grads_in_fp32
86+
kwargs["check_for_nan_in_grad"] = args.check_for_nan_in_loss_and_grad
87+
kwargs["check_for_large_grads"] = args.check_for_large_grads
88+
kwargs["bucket_size"] = args.ddp_bucket_size
89+
kwargs["pad_buckets_for_high_nccl_busbw"] = args.ddp_pad_buckets_for_high_nccl_busbw
90+
kwargs["average_in_collective"] = args.ddp_average_in_collective
91+
ddp_config = DistributedDataParallelConfig(**kwargs)
92+
93+
# In the custom FSDP and DDP use path, we need to initialize the bucket size.
94+
# If bucket_size is not provided as an input, use sane default.
95+
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
96+
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
97+
# latency-bound.
98+
if ddp_config.bucket_size is None:
99+
ddp_config.bucket_size = max(40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True))
100+
# Set bucket_size to infinity if overlap_grad_reduce is False.
101+
if not ddp_config.overlap_grad_reduce:
102+
ddp_config.bucket_size = None
103+
104+
model = [
105+
DDP(
106+
config=config,
107+
ddp_config=ddp_config,
108+
module=model_chunk,
109+
# Turn off bucketing for model_chunk 2 onwards, since communication for these
110+
# model chunks is overlapped with compute anyway.
111+
disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step,
142112
)
143-
opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer)
144-
for optimizer in optimizer.chained_optimizers:
145-
if not getattr(optimizer, "init_state_fn", None):
146-
continue
147-
optimizer.init_state_fn(optimizer.optimizer, optimizer.config)
113+
for (model_chunk_idx, model_chunk) in enumerate(model)
114+
]
115+
116+
# Optimizer
117+
kwargs = {}
118+
for f in dataclasses.fields(OptimizerConfig):
119+
if hasattr(args, f.name):
120+
kwargs[f.name] = getattr(args, f.name)
121+
config = OptimizerConfig(**kwargs)
122+
config.timers = None
123+
124+
optimizer = get_megatron_optimizer(
125+
config,
126+
model,
127+
no_wd_decay_cond,
128+
scale_lr_cond,
129+
lr_mult,
130+
use_gloo_process_groups=args.enable_gloo_process_groups,
131+
)
132+
opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer)
133+
for optimizer in optimizer.chained_optimizers:
134+
if not getattr(optimizer, "init_state_fn", None):
135+
continue
136+
optimizer.init_state_fn(optimizer.optimizer, optimizer.config)
148137

149138
return model, optimizer, opt_param_scheduler
150139

slime/ray/actor_group.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Optional
33

44
import ray
5-
import torch
65
from ray.util.placement_group import PlacementGroup
76
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
87

@@ -62,7 +61,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor, wandb_run_id: Optiona
6261
**{name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST},
6362
}
6463

65-
if not torch.version.hip and self.args.offload:
64+
if self.args.offload:
6665
import torch_memory_saver
6766

6867
dynlib_path = os.path.join(

slime/ray/train_actor.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,21 @@ def init(self, args, role, wandb_run_id, with_ref=False):
5757
args.world_size = dist.get_world_size()
5858

5959
try:
60-
import pynvml
60+
if torch.version.hip is not None:
61+
print(f"Detected ROCm/HIP environment, skipping NUMA affinity setup")
62+
# TODO: find the corresponding ROCm API to implement the NUMA affinity setup below
63+
else:
64+
import pynvml
6165

62-
pynvml.nvmlInit()
66+
pynvml.nvmlInit()
6367

64-
local_rank = int(os.environ["RANK"]) % args.num_gpus_per_node
68+
local_rank = int(os.environ["RANK"]) % args.num_gpus_per_node
6569

66-
handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
67-
pynvml.nvmlDeviceSetCpuAffinity(handle)
70+
handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
71+
pynvml.nvmlDeviceSetCpuAffinity(handle)
6872

69-
print(f"Set NUMA affinity for GPU {local_rank}")
70-
pynvml.nvmlShutdown()
73+
print(f"Set NUMA affinity for GPU {local_rank}")
74+
pynvml.nvmlShutdown()
7175

7276
except ImportError:
7377
print(f"Warning: pynvml not available, skipping NUMA affinity setup")

0 commit comments

Comments
 (0)