
Commit ab708eb

Author: kip-cxj
Commit message: fix pre-commit
1 parent ff78792 commit ab708eb

File tree

3 files changed: +17, -16 lines


tests/checkpoint_engine/test_correctness_on_gpu.py

Lines changed: 0 additions & 1 deletion
@@ -147,7 +147,6 @@ async def test_kimi_checkpoint_engine(
             "env_vars": {
                 "NCCL_IB_HCA": "mlx5",
                 "VERL_LOGGING_LEVEL": "DEBUG",
-                "ASCEND_USE_SHORT_CONNECTION": "1",
             }
         }
     )

tests/checkpoint_engine/test_correctness_on_npu.py

Lines changed: 4 additions & 7 deletions
@@ -77,25 +77,22 @@ async def test_hccl_checkpoint_engine(
 @pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
 @pytest.mark.asyncio
 @pytest.mark.parametrize("rebuild_group", [False])
-@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
+@pytest.mark.parametrize("num_trainer, num_rollout", [(4, 28)])
 async def test_kimi_checkpoint_engine(
     rebuild_group,
     num_trainer,
     num_rollout,
-    num_nodes=1,
-    num_gpus_per_node=8,
+    num_nodes=2,
+    num_gpus_per_node=16,
     check_allclose=True,
-    model_path="~/models/Qwen/Qwen3-8B-Base",
+    model_path="~/models/Qwen/Qwen3-32B",
 ):
     model_path = os.path.expanduser(model_path)
     ray.init(
         runtime_env={
             "env_vars": {
                 "HCCL_CONNECT_TIMEOUT": "1500",
-                "HCCL_HOST_SOCKET_PORT_RANGE": "60000-60050",
-                "HCCL_NPU_SOCKET_PORT_RANGE": "61000-61050",
                 "VERL_LOGGING_LEVEL": "DEBUG",
-                "ASCEND_USE_SHORT_CONNECTION": "1",
             }
         }
     )
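For context, the new parametrization still fills the target cluster exactly, assuming the test launches one worker per NPU; a quick sanity check of the sizing:

# Sanity check of the new NPU test sizing (assumes one worker per device).
num_nodes, num_gpus_per_node = 2, 16
num_trainer, num_rollout = 4, 28
assert num_trainer + num_rollout == num_nodes * num_gpus_per_node  # 4 + 28 == 32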

verl/checkpoint_engine/kimi_checkpoint_engine.py

Lines changed: 13 additions & 8 deletions
@@ -270,19 +270,25 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada
             "method": ["init_process_group"] * trainer_world_size,
             "rank": list(range(0, trainer_world_size)),
             "trainer_world_size": [trainer_world_size] * trainer_world_size,
-            "rollout_world_size": [rollout_world_size] * rollout_world_size,
+            "rollout_world_size": [rollout_world_size] * trainer_world_size,
             "master_metadata": [metadata[0]] * trainer_world_size,
         }
         rollout_kwargs = {
             "method": ["init_process_group"] * rollout_world_size,
             "rank": list(range(trainer_world_size, trainer_world_size + rollout_world_size)),
-            "trainer_world_size": [trainer_world_size] * trainer_world_size,
+            "trainer_world_size": [trainer_world_size] * rollout_world_size,
             "rollout_world_size": [rollout_world_size] * rollout_world_size,
             "master_metadata": [metadata[0]] * rollout_world_size,
         }
         return trainer_kwargs, rollout_kwargs

-    def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_size :int, master_metadata: MasterMetadata):
+    def init_process_group(
+        self,
+        rank: int,
+        trainer_world_size: int,
+        rollout_world_size: int,
+        master_metadata: MasterMetadata,
+    ):
         """Initialize the ckpt engine process group.

         Args:
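The two one-word fixes in build_topology above restore the invariant that every value list in trainer_kwargs has one entry per trainer rank (and each list in rollout_kwargs one entry per rollout rank), since the lists are fanned out element-wise to the workers. A minimal sketch of that invariant, using made-up sizes and a placeholder for MasterMetadata:

# Hypothetical sizes, only to illustrate the length invariant the fix restores.
trainer_world_size, rollout_world_size = 2, 6
master_metadata = "metadata[0]"  # placeholder for the real MasterMetadata object

trainer_kwargs = {
    "method": ["init_process_group"] * trainer_world_size,
    "rank": list(range(0, trainer_world_size)),
    "trainer_world_size": [trainer_world_size] * trainer_world_size,
    "rollout_world_size": [rollout_world_size] * trainer_world_size,  # fixed: was * rollout_world_size
    "master_metadata": [master_metadata] * trainer_world_size,
}
# Every list fans out one element per trainer rank, so all lengths must match.
assert all(len(v) == trainer_world_size for v in trainer_kwargs.values())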
@@ -293,9 +299,8 @@ def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_s
         self.trainer_world_size = trainer_world_size
         self.rollout_world_size = rollout_world_size
         self.world_size = trainer_world_size + rollout_world_size
-        # unregister_memory in transfer engine is not supported on NPU,
-        # so we have to initialize ParameterServer each time
-        if get_device_name() == "npu" or not self.initialized:
+
+        if not self.initialized:
             self.parameter_server = ParameterServer(
                 rank=rank,
                 world_size=self.world_size,
@@ -304,7 +309,7 @@ def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_s
                 master_port=master_metadata.dist_port,
             )
             self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)
-        if not self.initialized:
+
             dist.use_backend(f"vllm_{get_nccl_backend()}")
             self.parameter_server.init_process_group()

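With the NPU-specific branch removed, init_process_group now builds the ParameterServer only on the first call and reuses it afterwards. A minimal sketch of that init-once guard, with a hypothetical stand-in class rather than the real engine:

# Stand-in sketch (class and attribute names are illustrative, not the real API).
class EngineSketch:
    def __init__(self):
        self.initialized = False
        self.parameter_server = None

    def init_process_group(self):
        if not self.initialized:
            # Heavyweight setup runs once; ParameterServer(...) in the real code.
            self.parameter_server = object()
            self.initialized = True
        return self.parameter_server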
@@ -345,7 +350,7 @@ def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch

         self.parameter_server.register_checkpoint(self.checkpoint_name, named_tensors=named_tensors)
         named_tensors = {}
-        torch.cuda.empty_cache()
+        get_torch_device().empty_cache()
         logger.info(f"Rank {self.rank} offload and register, time cost: {time.time() - start_time:.2f}s")

         self.parameter_server.gather_metas(self.checkpoint_name)
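Swapping torch.cuda.empty_cache() for get_torch_device().empty_cache() makes the cache flush device-agnostic, so the same code path works on CUDA GPUs and Ascend NPUs. A rough sketch of what such a helper might look like; this is an assumption about the idea, not verl's actual get_torch_device implementation:

import torch

def get_torch_device_sketch():
    # Prefer the Ascend NPU backend when the torch_npu plugin is loaded,
    # otherwise fall back to CUDA; both expose an empty_cache() hook.
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.npu
    return torch.cuda

get_torch_device_sketch().empty_cache()  # no-op when the backend is not initialized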
