Commit 9aa1c20

Author: x1314aq
support megatron and sglang
1 parent a61202b commit 9aa1c20

File tree

recipe/one_step_off_policy/ckpt_engine_worker.py
recipe/one_step_off_policy/fsdp_workers.py
recipe/one_step_off_policy/megatron_workers.py
recipe/one_step_off_policy/ray_trainer.py

4 files changed: +107 -18 lines

recipe/one_step_off_policy/ckpt_engine_worker.py

Lines changed: 49 additions & 7 deletions
@@ -37,19 +37,19 @@
 
 
 class CkptEngineWorker(Worker):
-    def __init__(self, rank_offset, ps_world_size, inference_parallel_size):
+    def __init__(self, rank_offset, ps_world_size, inference_parallel_size, rollout_name):
         super().__init__()
         rank = self.rank + rank_offset
         self.ps_rank = rank
         self.ps_rank_offset = rank_offset
         self.ps_world_size = ps_world_size
         self.inference_parallel_size = inference_parallel_size
+        self.rollout_name = rollout_name
         self.ps = ParameterServer(rank=rank, world_size=ps_world_size)
         self.index = 0
 
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    def init_process_group(self):
-        os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020-61050"
+    def _init_process_group(self):
+        os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
         self.ps.init_process_group(device_index=0, master_port=60010)
         del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]
 
@@ -70,6 +70,27 @@ def check_vllm_ready(self, uds: str | None = None):
                     logger.warning(f"fail to check vllm ready, retry {retry_num} times, error: {e}")
                     time.sleep(5)
 
+    def check_sglang_ready(self, uds: str | None = None):
+        if self.ps_rank != self.ps_rank // self.inference_parallel_size * self.inference_parallel_size:
+            return
+        retry_num = 0
+        transport = None
+        if uds is not None:
+            transport = httpx.HTTPTransport(uds=uds)
+        with httpx.Client(transport=transport) as client:
+            while True:
+                try:
+                    response = client.get(f"{self.endpoint}/ping", timeout=10)
+                    response.raise_for_status()
+                    break
+                except (httpx.ConnectError, httpx.HTTPStatusError) as e:
+                    if retry_num % 10 == 0:
+                        logger.warning(
+                            f"fail to check sglang ready, retry {retry_num} times, error: {e}"
+                        )
+                    retry_num += 1
+                    time.sleep(0.1)
+
     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
     def set_server_addresses(self, server_addresses: list[str]):
         # todo support multiple api server
@@ -81,15 +102,36 @@ def sync_rollout_weights_by_ckpt_engine(self):
         rank = self.rank
         src = rank // self.inference_parallel_size * self.inference_parallel_size
 
-        def req_func(socket_paths: list[tuple[str, str]]) -> None:
+        def vllm_req_func(socket_paths: list[tuple[str, str]]) -> None:
             if rank == src:
                 request_inference_to_update(
                     url=f"{self.endpoint}/collective_rpc",
                     socket_paths=dict(socket_paths),
                 )
 
+        def sglang_req_func(socket_paths: list[tuple[str, str]]) -> None:
+            if rank == src:
+                with httpx.Client(transport=httpx.HTTPTransport()) as client:
+                    resp = client.post(
+                        f"{self.endpoint}/update_weights_from_ipc",
+                        json={
+                            "zmq_handles": dict(socket_paths),
+                            "flush_cache": True,
+                            "weight_version": None,
+                        },
+                        timeout=300.0,
+                    )
+                    resp.raise_for_status()
+
+        if self.rollout_name == "sglang":
+            req_func = sglang_req_func
+        elif self.rollout_name == "vllm":
+            req_func = vllm_req_func
+
+        self._init_process_group()
         checkpoint_name = f"sync_{self.index}"
         self.ps.register_checkpoint(checkpoint_name=checkpoint_name)
         self.ps.gather_metas(checkpoint_name)
-        ranks = list(range(self.ps_rank_offset, self.ps_world_size))
-        self.ps.update(checkpoint_name, req_func, ranks=ranks)
+        self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))
+        self.index += 1
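With this change, sync_rollout_weights_by_ckpt_engine defines one request callback per rollout backend and selects it from rollout_name: the vLLM path keeps the existing request_inference_to_update helper against /collective_rpc, while the SGLang path posts the checkpoint engine's ZMQ socket paths to /update_weights_from_ipc. Below is a minimal, free-standing sketch of that SGLang request so the payload contract is visible outside the worker; the function name and the endpoint argument are illustrative, the route and JSON body come from the diff above.

import httpx


def sglang_update_request(endpoint: str, socket_paths: list[tuple[str, str]]) -> None:
    # In the worker, only the leader rank of each inference group sends this request;
    # the ZMQ socket paths opened by the checkpoint engine are forwarded as
    # "zmq_handles" so the SGLang server can pull the new weights over IPC.
    with httpx.Client() as client:
        resp = client.post(
            f"{endpoint}/update_weights_from_ipc",
            json={
                "zmq_handles": dict(socket_paths),
                "flush_cache": True,
                "weight_version": None,
            },
            timeout=300.0,
        )
        resp.raise_for_status()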

recipe/one_step_off_policy/fsdp_workers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,8 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
123123
self.ps_rank_offset = kwargs.get("rank_offset", self.rank)
124124
self.ps_world_size = kwargs.get("ps_world_size", self.world_size)
125125
self.ps = ParameterServer(rank=self.rank, world_size=self.ps_world_size)
126-
127126
self.index = 0
128127

129-
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
130128
def init_process_group(self):
131129
os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
132130
self.ps.init_process_group(device_index=0, master_port=60010)
@@ -161,15 +159,15 @@ def sync_rollout_weights_by_ckpt_engine(self):
161159
def req_func(socket_paths: list[tuple[str, str]]):
162160
return
163161

162+
self.init_process_group()
164163
named_tensors = self.split_tensors()
165-
166164
checkpoint_name = f"sync_{self.index}"
167165

168166
self.ps.register_checkpoint(checkpoint_name=checkpoint_name, named_tensors=named_tensors)
169167
self.ps.gather_metas(checkpoint_name)
170-
ranks = list(range(self.ps_rank_offset, self.ps_world_size))
168+
self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))
171169

172-
self.ps.update(checkpoint_name, req_func, ranks=ranks)
170+
self.index += 1
173171

174172
def _get_actor_params(self):
175173
assert self._is_actor
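Since the @register decorator is removed, init_process_group is now an ordinary method called inline from sync_rollout_weights_by_ckpt_engine rather than dispatched by the trainer. The HCCL_NPU_SOCKET_PORT_RANGE variable is still set only for the duration of the call and removed afterwards; here is a small sketch of that set/init/cleanup pattern as a reusable context manager (a refactoring idea, not code from this commit).

import os
from contextlib import contextmanager


@contextmanager
def scoped_env(name: str, value: str):
    # Set an environment variable only for the duration of the block, restoring or
    # deleting it afterwards, like the workers do around ps.init_process_group().
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            del os.environ[name]
        else:
            os.environ[name] = previous


# Hypothetical usage, with ps standing in for the worker's ParameterServer:
# with scoped_env("HCCL_NPU_SOCKET_PORT_RANGE", "61020"):
#     ps.init_process_group(device_index=0, master_port=60010)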

recipe/one_step_off_policy/megatron_workers.py

Lines changed: 54 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 import torch
 import torch.distributed
+from checkpoint_engine.ps import ParameterServer
 from omegaconf import DictConfig
 from ray.util.collective import collective
 
@@ -120,6 +121,58 @@ async def update_weights(self, inference_engine, params):
 
 
 class DetachActorWorker(DetachSync):
+    def __init__(self, config: DictConfig, role: str, **kwargs):
+        ActorRolloutRefWorker.__init__(self, config, role)
+
+        if role == "actor":
+            self.ps_rank_offset = kwargs.get("rank_offset", self.rank)
+            self.ps_world_size = kwargs.get("ps_world_size", self.world_size)
+            self.ps = ParameterServer(rank=self.rank, world_size=self.ps_world_size)
+            self.index = 0
+
+    def init_process_group(self):
+        os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
+        self.ps.init_process_group(device_index=0, master_port=60010)
+        del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]
+
+    def split_tensors(self) -> dict[str, torch.Tensor]:
+        assert self._is_actor and not self.config.hybrid_engine
+        assert hasattr(self, "_weights_info") and self._weights_info is not None
+
+        params_generator = self._get_actor_params_generator() if self._is_actor else None
+
+        if self._is_actor and self._is_offload_param:
+            load_megatron_model_to_gpu(self.actor_module)
+
+        named_tensors = {}
+
+        world_size = self.world_size
+        rank = self.rank
+
+        weights_per_rank = (len(self._weights_info) + world_size - 1) // world_size
+        for index, (key, tensor) in enumerate(params_generator):
+            if index >= rank * weights_per_rank and index < (rank + 1) * weights_per_rank:
+                named_tensors[key] = tensor.to("cpu", non_blocking=True)
+
+        get_torch_device().synchronize()
+
+        return named_tensors
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+    def sync_rollout_weights_by_ckpt_engine(self):
+        def req_func(socket_paths: list[tuple[str, str]]):
+            return
+
+        self.init_process_group()
+        named_tensors = self.split_tensors()
+        checkpoint_name = f"sync_{self.index}"
+
+        self.ps.register_checkpoint(checkpoint_name=checkpoint_name, named_tensors=named_tensors)
+        self.ps.gather_metas(checkpoint_name)
+        self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))
+
+        self.index += 1
+
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def _get_actor_params_generator(self):
         assert self._is_actor
@@ -160,7 +213,7 @@ def get_actor_weights_info(self):
 
 
 class DetachAsyncRolloutWorker(DetachSync):
-    def __init__(self, config: DictConfig, role: str):
+    def __init__(self, config: DictConfig, role: str, **kwargs):
         print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
         ActorRolloutRefWorker.__init__(self, config, role)
 
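split_tensors shards the actor weights across training ranks by index: every rank keeps one contiguous block of roughly len(_weights_info) / world_size named tensors and stages them on CPU before handing them to the ParameterServer. A self-contained sketch of that slicing rule with toy tensors follows; the function and variable names here are illustrative.

import torch


def shard_named_params(named_params, rank: int, world_size: int) -> dict[str, torch.Tensor]:
    named_params = list(named_params)
    per_rank = (len(named_params) + world_size - 1) // world_size  # ceil division, as in split_tensors
    start, stop = rank * per_rank, (rank + 1) * per_rank
    # keep only this rank's contiguous slice and stage it on CPU
    return {key: t.to("cpu") for i, (key, t) in enumerate(named_params) if start <= i < stop}


# toy usage: 10 parameters over 4 ranks -> rank 1 holds indices 3, 4, 5
weights = [(f"layer.{i}.weight", torch.randn(4, 4)) for i in range(10)]
print(list(shard_named_params(weights, rank=1, world_size=4)))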

recipe/one_step_off_policy/ray_trainer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def _create_ckpt_engine_class(self):
173173
rank_offset=self.rank_offset,
174174
ps_world_size=self.ps_world_size,
175175
inference_parallel_size=self.config.actor_rollout_ref.rollout.tensor_model_parallel_size,
176+
rollout_name=self.config.actor_rollout_ref.rollout.name
176177
)
177178
self.resource_pool_to_cls[resource_pool][str(Role.CkptEngine)] = ckpt_engine_cls
178179

@@ -272,11 +273,6 @@ def _init_models(self):
272273
self.actor_rollout_wg = self.actor_wg
273274
weights_info = self.actor_wg.get_actor_weights_info()[0]
274275
self.rollout_wg.set_actor_weights_info(weights_info)
275-
self._create_weight_sync_group()
276-
277-
def _create_weight_sync_group(self):
278-
self.actor_wg.init_process_group()
279-
ray.get(self.ckpt_engine_wg.init_process_group())
280276

281277
def _init_async_rollout_manager(self):
282278
# create async rollout manager and request scheduler
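On the trainer side, the only new input is the rollout backend name, read from the same config subtree as the tensor parallel size and forwarded to CkptEngineWorker. Because each worker now initializes its process group lazily inside sync_rollout_weights_by_ckpt_engine, the explicit _create_weight_sync_group step is no longer needed and is deleted here. A small sketch of the config lookup with OmegaConf; the inline values are illustrative, only the two key paths appear in the diff.

from omegaconf import OmegaConf

config = OmegaConf.create(
    {"actor_rollout_ref": {"rollout": {"name": "sglang", "tensor_model_parallel_size": 2}}}
)

rollout_name = config.actor_rollout_ref.rollout.name                   # "sglang" or "vllm"
tp_size = config.actor_rollout_ref.rollout.tensor_model_parallel_size  # passed as inference_parallel_size
print(rollout_name, tp_size)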
