Skip to content

Commit ff78792

Browse files
author
kip-cxj
committed
adapt new architecture
1 parent 1b49a38 commit ff78792

File tree

5 files changed

+142
-143
lines changed

5 files changed

+142
-143
lines changed

tests/checkpoint_engine/test_correctness_on_gpu.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
RayResourcePool,
2323
split_resource_pool,
2424
)
25+
from verl.utils.device import get_device_name
2526
from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
2627

2728

@@ -127,6 +128,54 @@ async def test_nixl_checkpoint_engine(
127128
ray.shutdown()
128129

129130

131+
@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
@pytest.mark.asyncio
@pytest.mark.parametrize("rebuild_group", [False])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
async def test_kimi_checkpoint_engine(
    rebuild_group,
    num_trainer,
    num_rollout,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    """End-to-end weight-sync test for the kimi_ckpt_engine backend (GPU).

    Splits one resource pool into a trainer group and a rollout group,
    pushes weights through CheckpointEngineManager three times, and has the
    rollout workers verify the received weights each round.
    """
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                "NCCL_IB_HCA": "mlx5",
                "VERL_LOGGING_LEVEL": "DEBUG",
                # NOTE(review): ASCEND_* is an NPU-side knob; it appears
                # copy-pasted from the NPU test into this GPU test — confirm
                # whether it is actually needed here.
                "ASCEND_USE_SHORT_CONNECTION": "1",
            }
        }
    )

    # Initialize configs: the kimi backend reads rebuild_group from
    # engine_kwargs keyed by its backend name.
    checkpoint_engine_config = CheckpointEngineConfig(
        backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}}
    )
    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)

    # Create trainer and rollout worker groups by splitting one pool.
    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
    resource_pool.get_placement_groups(device_name=get_device_name())
    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
    trainer.reset()
    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)

    # Create the checkpoint engine manager and sync weights several rounds,
    # verifying correctness on the rollout side after each update.
    checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas)
    for _ in range(3):
        await checkpoint_manager.update_weights()
        rollout.check_weights()

    ray.shutdown()
177+
178+
130179
if __name__ == "__main__":
131180
test_nccl_checkpoint_engine(
132181
rebuild_group=False,

tests/checkpoint_engine/test_correctness_on_npu.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,56 @@ async def test_hccl_checkpoint_engine(
7474
ray.shutdown()
7575

7676

77+
@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
@pytest.mark.asyncio
@pytest.mark.parametrize("rebuild_group", [False])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
async def test_kimi_checkpoint_engine(
    rebuild_group,
    num_trainer,
    num_rollout,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    """End-to-end weight-sync test for the kimi_ckpt_engine backend (NPU).

    Mirrors the GPU variant: split one resource pool into trainer/rollout
    groups, run three weight updates through CheckpointEngineManager, and
    verify the rollout workers' weights after each update.
    """
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                # HCCL rendezvous knobs for the Ascend fabric.
                "HCCL_CONNECT_TIMEOUT": "1500",
                "HCCL_HOST_SOCKET_PORT_RANGE": "60000-60050",
                "HCCL_NPU_SOCKET_PORT_RANGE": "61000-61050",
                "VERL_LOGGING_LEVEL": "DEBUG",
                "ASCEND_USE_SHORT_CONNECTION": "1",
            }
        }
    )

    # Initialize configs: the kimi backend reads rebuild_group from
    # engine_kwargs keyed by its backend name.
    checkpoint_engine_config = CheckpointEngineConfig(
        backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}}
    )
    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)

    # Create trainer and rollout worker groups by splitting one pool.
    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
    resource_pool.get_placement_groups(device_name=get_device_name())
    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
    trainer.reset()
    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)

    # Create the checkpoint engine manager and sync weights several rounds,
    # verifying correctness on the rollout side after each update.
    checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas)
    for _ in range(3):
        await checkpoint_manager.update_weights()
        rollout.check_weights()

    ray.shutdown()
125+
126+
77127
if __name__ == "__main__":
78128
test_hccl_checkpoint_engine(
79129
rebuild_group=False,

tests/checkpoint_engine/test_kimi_checkpoint_engine.py

Lines changed: 0 additions & 121 deletions
This file was deleted.

verl/checkpoint_engine/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
except ImportError:
4545
HCCLCheckpointEngine = None
4646

47-
4847
try:
4948
from .nixl_checkpoint_engine import NIXLCheckpointEngine
5049

verl/checkpoint_engine/kimi_checkpoint_engine.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,10 @@ async def receive_tensor(
174174

175175
@dataclass
176176
class MasterMetadata:
177-
ip: str
178-
port: int
177+
zmq_ip: str
178+
zmq_port: int
179+
dist_ip: str
180+
dist_port: int
179181

180182

181183
class BroadcastOperation:
@@ -231,17 +233,11 @@ class KIMICheckpointEngine(CheckpointEngine):
231233

232234
def __init__(
233235
self,
234-
train_world_size: int,
235-
rollout_world_size: int,
236236
bucket_size: int,
237237
rebuild_group: bool = False,
238238
is_master: bool = False,
239239
rollout_dtype: torch.dtype = torch.bfloat16,
240240
) -> None:
241-
self.train_world_size = train_world_size
242-
self.rollout_world_size = rollout_world_size
243-
self.world_size = train_world_size + rollout_world_size
244-
245241
self.bucket_size = bucket_size
246242
self.rebuild_group = rebuild_group
247243
self.rollout_dtype = rollout_dtype
@@ -254,39 +250,65 @@ def prepare(self) -> MasterMetadata:
254250
self.ip = ray.util.get_node_ip_address().strip("[]")
255251
self.listen_port, _ = get_free_port(self.ip)
256252

257-
return MasterMetadata(ip=self.ip, port=self.listen_port) if self.is_master else None
253+
return (
254+
MasterMetadata(zmq_ip=None, zmq_port=None, dist_ip=self.ip, dist_port=self.listen_port)
255+
if self.is_master
256+
else None
257+
)
258258

def finalize(self):
    """Tear down the ckpt engine process group when rebuild_group is set.

    NOTE(review): reconstructed from a diff view — the reset statements are
    assumed to live inside the rebuild_group branch (matching the docstring
    "if rebuild_group is True"); confirm against the original file.
    """
    if self.rebuild_group:
        dist.destroy_process_group()
        self.rank = None
        self.world_size = None
        self.initialized = False
266266

267-
def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata):
267+
@classmethod
def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]):
    """Build per-worker-group kwargs for dispatching init_process_group.

    Each value is a list with one entry per worker in that group, so every
    list in a kwargs dict must use that group's own size as the replication
    factor.

    Args:
        trainer_world_size: number of trainer workers (global ranks 0..T-1).
        rollout_world_size: number of rollout workers (global ranks T..T+R-1).
        metadata: per-rank metadata from prepare(); only the master's entry
            (metadata[0]) is distributed to every worker.

    Returns:
        Tuple of (trainer_kwargs, rollout_kwargs) argument-list dicts.
    """
    trainer_kwargs = {
        "method": ["init_process_group"] * trainer_world_size,
        "rank": list(range(0, trainer_world_size)),
        "trainer_world_size": [trainer_world_size] * trainer_world_size,
        # BUG FIX: was replicated rollout_world_size times, breaking the
        # one-entry-per-trainer-worker invariant of this dict.
        "rollout_world_size": [rollout_world_size] * trainer_world_size,
        "master_metadata": [metadata[0]] * trainer_world_size,
    }
    rollout_kwargs = {
        "method": ["init_process_group"] * rollout_world_size,
        "rank": list(range(trainer_world_size, trainer_world_size + rollout_world_size)),
        # BUG FIX: was replicated trainer_world_size times (copy-paste from
        # trainer_kwargs); each rollout worker needs its own entry.
        "trainer_world_size": [trainer_world_size] * rollout_world_size,
        "rollout_world_size": [rollout_world_size] * rollout_world_size,
        "master_metadata": [metadata[0]] * rollout_world_size,
    }
    return trainer_kwargs, rollout_kwargs
284+
285+
def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_size: int, master_metadata: MasterMetadata):
    """Initialize the ckpt engine process group.

    Fixes vs. diffed version: PEP 8 spacing of the rollout_world_size
    annotation, and the docstring no longer documents the removed
    world_size parameter (it is now derived from the two group sizes).

    Args:
        rank (int): Global rank of the current process.
        trainer_world_size (int): Number of trainer processes (ranks 0..T-1).
        rollout_world_size (int): Number of rollout processes (ranks T..T+R-1).
        master_metadata (MasterMetadata): Rendezvous address (dist_ip/dist_port)
            published by the master rank's prepare().
    """
    self.rank = rank
    self.trainer_world_size = trainer_world_size
    self.rollout_world_size = rollout_world_size
    self.world_size = trainer_world_size + rollout_world_size
    # unregister_memory in transfer engine is not supported on NPU,
    # so we have to initialize ParameterServer each time
    if get_device_name() == "npu" or not self.initialized:
        self.parameter_server = ParameterServer(
            rank=rank,
            world_size=self.world_size,
            auto_pg=False,
            master_addr=master_metadata.dist_ip,
            master_port=master_metadata.dist_port,
        )
        # Patch in the module-level receive_tensor implementation as a bound method.
        self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)
    if not self.initialized:
        dist.use_backend(f"vllm_{get_nccl_backend()}")
        self.parameter_server.init_process_group()

    # Rollout ranks follow all trainer ranks in the global numbering.
    self.rollout_ranks = list(range(self.trainer_world_size, self.world_size))
    self.rollout_group = dist.new_group(self.rollout_ranks)
    self.initialized = True
292314

@@ -304,7 +326,7 @@ def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch
304326
start_time = time.time()
305327
named_tensors = {}
306328
for named_tensors_gpu in ckpt_get_named_tensor_buckets(
307-
weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
329+
weights, self.bucket_size, self.trainer_world_size, self.rank, self.rollout_dtype
308330
):
309331
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
310332
futures = [

0 commit comments

Comments
 (0)