
Commit 94e228b

fix pre_commit

kip-cxj committed
1 parent 4755ae3 commit 94e228b

File tree

8 files changed: +34 additions, -31 deletions


verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -272,6 +272,7 @@ actor_rollout_ref:
     skip_dump_dir: /tmp/rollout_dump
     skip_tokenizer_init: true
     enable_rollout_routing_replay: false
+    enable_checkpoint_engine: false
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
       tool: ${oc.select:global_profiler.tool,null}

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -261,6 +261,7 @@ actor_rollout_ref:
     skip_dump_dir: /tmp/rollout_dump
     skip_tokenizer_init: true
     enable_rollout_routing_replay: false
+    enable_checkpoint_engine: false
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
       tool: ${oc.select:global_profiler.tool,null}
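
Both generated trainer configs pick up the same enable_checkpoint_engine: false default next to the other rollout flags. A minimal sketch of how the flag is consumed downstream, mirroring the self.config.rollout.enable_checkpoint_engine checks visible in the fsdp_workers.py diff below (the OmegaConf snippet is illustrative, not verl's actual config loading):

# Illustrative only: a minimal OmegaConf config carrying the new flag.
from omegaconf import OmegaConf

config = OmegaConf.create({"rollout": {"enable_checkpoint_engine": False}})
if config.rollout.enable_checkpoint_engine:
    print("weights are pushed through the checkpoint engine")
else:
    print("weights go through rollout.update_weights as before")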

verl/workers/fsdp_workers.py

Lines changed: 7 additions & 6 deletions
@@ -581,22 +581,21 @@ def _build_model_optimizer(
     def update_weighs_by_checkpoint_engine(
         self,
         weights: Generator[tuple[str, torch.Tensor], None, None],
-        req_func: Callable[[list[tuple[str, str]]], None]
+        req_func: Callable[[list[tuple[str, str]]], None],
     ):
         named_tensors = {}
         for tensor_idx, (name, tensor) in enumerate(weights):
             if tensor_idx % self.world_size == self.rank:
                 named_tensors[name] = tensor

-        checkpoint_name = f"checkpoint_engine"
+        checkpoint_name = "checkpoint_engine"
         self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
         named_tensors = {}
         dist.barrier()
         self.parameter_server.gather_metas(checkpoint_name)
         self.parameter_server.update(checkpoint_name, req_func)
         self.parameter_server.unregister_checkpoint(checkpoint_name)

-
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh

@@ -744,16 +743,18 @@ async def rollout_mode(self):
             )
             if self.config.rollout.enable_checkpoint_engine:
                 req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
-                self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
+                self.update_weighs_by_checkpoint_engine(per_tensor_base_params, req_func)
             else:
                 await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
             del base_model_params, per_tensor_base_params

         if self.config.rollout.enable_checkpoint_engine:
             req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
             self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
         else:
-            await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
+            await self.rollout.update_weights(
+                per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done
+            )
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
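
The update_weighs_by_checkpoint_engine method shards the weight stream round-robin: rank r keeps the tensors whose index satisfies idx % world_size == r, registers them with the parameter server under a fixed checkpoint name, and hands req_func to parameter_server.update. A self-contained sketch of just the sharding step (the parameter server and the real weight generator are not reproduced):

# Sketch of the round-robin sharding in update_weighs_by_checkpoint_engine;
# `weights` stands in for the real (name, tensor) generator.
import torch

def shard_round_robin(weights, rank: int, world_size: int) -> dict[str, torch.Tensor]:
    """Each rank keeps every world_size-th tensor, offset by its rank."""
    return {
        name: tensor
        for idx, (name, tensor) in enumerate(weights)
        if idx % world_size == rank
    }

fake_weights = [(f"layer{i}.weight", torch.zeros(1)) for i in range(8)]
assert set(shard_round_robin(fake_weights, rank=1, world_size=4)) == {"layer1.weight", "layer5.weight"}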

verl/workers/megatron_workers.py

Lines changed: 3 additions & 3 deletions
@@ -19,12 +19,12 @@
 import logging
 import os
 import time
+from collections.abc import Callable
 from typing import Any, Generator, Optional

 import psutil
 import torch
 import torch.distributed
-from collections.abc import Callable
 from codetiming import Timer
 from omegaconf import DictConfig, OmegaConf

@@ -487,14 +487,14 @@ def _build_model_optimizer(
     def update_weighs_by_checkpoint_engine(
         self,
         weights: Generator[tuple[str, torch.Tensor], None, None],
-        req_func: Callable[[list[tuple[str, str]]], None]
+        req_func: Callable[[list[tuple[str, str]]], None],
     ):
         named_tensors = {}
         for tensor_idx, (name, tensor) in enumerate(weights):
             if tensor_idx % self.world_size == self.rank:
                 named_tensors[name] = tensor.to("cpu", non_blocking=True)

-        checkpoint_name = f"checkpoint_engine"
+        checkpoint_name = "checkpoint_engine"
         self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
         named_tensors = {}
         torch.distributed.barrier()
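
The Megatron worker runs the same sharding loop, except each kept tensor is staged to host memory before registration; the torch.distributed.barrier() that follows ensures every rank has published its shard before checkpoint metadata is gathered. A hedged sketch of that variant:

# Sketch: the Megatron variant stages each kept tensor on the CPU.
import torch

def shard_to_cpu(weights, rank: int, world_size: int) -> dict[str, torch.Tensor]:
    return {
        name: tensor.to("cpu", non_blocking=True)
        for idx, (name, tensor) in enumerate(weights)
        if idx % world_size == rank
    }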

verl/workers/rollout/sglang_rollout/async_sglang_server.py

Lines changed: 11 additions & 3 deletions
@@ -77,8 +77,10 @@ def __init__(
         cuda_visible_devices: str,
     ):
         print(f"SGLang http server: {rollout_mode=}, {replica_rank=}, {node_rank=}, {nnodes=}, {cuda_visible_devices=}")
-        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
-        assert torch.cuda.is_available(), "SGLang http server should run on GPU node"
+        os.environ["CUDA_VISIBLE_DEVICES" if torch.cuda.is_available() else "ASCEND_RT_VISIBLE_DEVICES"] = (
+            cuda_visible_devices
+        )
+        assert torch.cuda.is_available() or torch.npu.is_available(), "SGLang http server should run on GPU/NPU node"

         self.config: RolloutConfig = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)

@@ -337,7 +339,13 @@ async def launch_servers(self):
                 node_id=node_id,
                 soft=False,
             ),
-            runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
+            runtime_env={
+                "env_vars": {
+                    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"
+                    if torch.cuda.is_available()
+                    else "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1"
+                }
+            },
             name=name,
         ).remote(
             config=self.config,
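
The rewritten lines choose the device-visibility variable by platform: CUDA_VISIBLE_DEVICES on NVIDIA hosts, ASCEND_RT_VISIBLE_DEVICES on Ascend NPU hosts (and likewise for the Ray RAY_EXPERIMENTAL_NOSET_* variables). A factored-out sketch of the selection; torch.npu assumes the torch_npu plugin is installed:

# Sketch of the platform-dependent visibility selection used above.
import os
import torch

def set_visible_devices(devices: str) -> None:
    if torch.cuda.is_available():
        os.environ["CUDA_VISIBLE_DEVICES"] = devices
    else:
        # Ascend NPU path; torch.npu is provided by the torch_npu plugin.
        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = devices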

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 3 additions & 3 deletions
@@ -18,12 +18,13 @@
 import logging
 import multiprocessing as mp
 import os
+from collections.abc import Callable
 from typing import Generator

 import ray
 import sglang.srt.entrypoints.engine
 import torch
-from collections.abc import Callable
+from sglang.srt.checkpoint_engine.update import req_inference
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     assert_pkg_version,

@@ -32,7 +33,6 @@
     set_ulimit,
 )
 from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
-from sglang.srt.checkpoint_engine.update import req_inference
 from torch.distributed.device_mesh import DeviceMesh

 from verl.workers.config import HFModelConfig, RolloutConfig

@@ -203,4 +203,4 @@ async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:

         req_func = req_inference(endpoint=endpoint, inference_parallel_size=inference_parallel_size)

-        return req_func
\ No newline at end of file
+        return req_func
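
checkpoint_engine_req_func builds its request function from SGLang's req_inference helper, bound to the server endpoint; the worker then passes the returned callable into parameter_server.update(...), as the worker diffs above show. A self-contained sketch of that handoff with stand-in objects (not verl's real APIs):

# Stand-in sketch of the rollout -> parameter-server handoff.
import asyncio

class StubRollout:
    async def checkpoint_engine_req_func(self, inference_parallel_size: int):
        def req_func(socket_paths: list[tuple[str, str]]) -> None:
            print(f"would notify inference engines about {len(socket_paths)} sockets")
        return req_func

async def main() -> None:
    rollout = StubRollout()
    req_func = await rollout.checkpoint_engine_req_func(inference_parallel_size=8)
    req_func([("rank0", "/tmp/ckpt0.sock")])  # parameter_server.update would invoke this

asyncio.run(main())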

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 6 additions & 2 deletions
@@ -290,7 +290,9 @@ async def launch_server(self, master_address: str = None, master_port: int = None
             "override_generation_config": json.dumps(override_generation_config),
             "quantization": quantization,
             "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None,
-            "worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension" if self.config.enable_checkpoint_engine else None,
+            "worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension"
+            if self.config.enable_checkpoint_engine
+            else None,
             **engine_kwargs,
         }

@@ -691,7 +693,9 @@ async def launch_servers(self):
                 soft=False,
             ),
             name=name,
-            runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}} if self.config.enable_checkpoint_engine else None,
+            runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}}
+            if self.config.enable_checkpoint_engine
+            else None,
         ).remote(
             config=self.config,
             model_config=self.model_config,
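
VLLM_SERVER_DEV_MODE=1 is set only when the checkpoint engine is enabled; it exposes vLLM's development endpoints, including the /collective_rpc route that the removed vllm_rollout code below targets. A small helper capturing the conditional (names are illustrative):

# Illustrative helper for the conditional Ray runtime_env used above.
# Returning None leaves the actor's environment untouched when the flag is off.
def server_runtime_env(enable_checkpoint_engine: bool) -> dict | None:
    if enable_checkpoint_engine:
        # VLLM_SERVER_DEV_MODE=1 exposes vLLM's dev endpoints such as /collective_rpc.
        return {"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}}
    return None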

verl/workers/rollout/vllm_rollout/vllm_rollout.py

Lines changed: 2 additions & 14 deletions
@@ -30,8 +30,8 @@
 import getpass
 import logging
 import os
-from dataclasses import asdict
 from collections.abc import Callable
+from dataclasses import asdict
 from types import MethodType
 from typing import Any, Generator

@@ -271,19 +271,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]
         model.load_weights(weights)

     async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
-        from checkpoint_engine.ps import request_inference_to_update
-        rank = int(os.getenv("RANK", None))
-        src = rank // inference_parallel_size * inference_parallel_size
-
-        server_actor = ray.get_actor(f"vllm_server_{self.replica_rank}_{self.node_rank}")
-        server_address, server_port = await server_actor.get_server_address.remote()
-        def req_func(socket_paths: list[tuple[str, str]]):
-            if rank == src:
-                request_inference_to_update(
-                    f"http://{server_address}:{server_port}/collective_rpc",
-                    dict(socket_paths[src : src + inference_parallel_size]),
-                )
-        return req_func
+        raise NotImplementedError

     def generate_sequences(self, prompts: DataProto) -> DataProto:
         """Batch generate sequences in sync mode."""
