Commit e8aba2c

[Doc] Add type annotation in actor.py (#436)
* Add type annotation in actor.py
1 parent fcce96c commit e8aba2c


slime/backends/megatron_utils/actor.py

Lines changed: 42 additions & 21 deletions
```diff
@@ -1,12 +1,15 @@
 import os
 import socket
 import time
+from argparse import Namespace
 from contextlib import nullcontext
 from pathlib import Path
+from typing import Dict, Optional, Tuple, Union
 
 import ray
 import torch
 import torch.distributed as dist
+from ray.actor import ActorHandle
 
 if torch.version.hip:
     from vllm.device_allocator.cumem import CuMemAllocator
@@ -20,20 +23,27 @@
 from slime.utils.data import process_rollout_data
 from slime.utils.distributed_utils import get_gloo_group, init_process_group
 from slime.utils.memory_utils import clear_memory, print_memory
+from slime.utils.ray_utils import Box
 from slime.utils.timer import Timer, timer
 from slime.utils.wandb_utils import init_wandb_secondary
 
 from .checkpoint import load_checkpoint
 from .cp_utils import slice_log_prob_with_cp
-from .data import get_data_iterator, log_perf_data, log_rollout_data, sync_actor_critic_data
+from .data import DataIterator, get_data_iterator, log_perf_data, log_rollout_data, sync_actor_critic_data
 from .initialize import init, is_megatron_main_rank
 from .loss import compute_advantages_and_returns, get_log_probs_and_entropy, get_values
 from .model import forward_only, initialize_model_and_optimizer, save, train
 from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor, named_parameters
 
 
 class MegatronTrainRayActor(TrainRayActor):
-    def init(self, args, role, wandb_run_id, with_ref=False):
+    def init(
+        self,
+        args: Namespace,
+        role: str,
+        wandb_run_id: str,
+        with_ref: bool = False,
+    ) -> Optional[int]:
         super().init(args, role, wandb_run_id, with_ref)
 
         init(args)
```
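Note on the new `init` signature: `args` is typed as `argparse.Namespace` and the return becomes `Optional[int]`, since the method ends by returning `start_rollout_id` (see the next hunk), which may be absent. A toy sketch of calling such a signature; `DummyActor` is an assumed stand-in, not the real `MegatronTrainRayActor`:

```python
# Toy sketch: DummyActor is an assumed stand-in for MegatronTrainRayActor.
from argparse import Namespace
from typing import Optional


class DummyActor:
    def init(
        self,
        args: Namespace,
        role: str,
        wandb_run_id: str,
        with_ref: bool = False,
    ) -> Optional[int]:
        # Optional[int]: a resume point may be absent, in which case None.
        return getattr(args, "start_rollout_id", None)


actor = DummyActor()
print(actor.init(Namespace(start_rollout_id=3), role="actor", wandb_run_id="run-1"))  # 3
```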
```diff
@@ -128,29 +138,29 @@ def init(self, args, role, wandb_run_id, with_ref=False):
         return start_rollout_id
 
     @torch.no_grad()
-    def update_cpu_params_dict(self, params_dict):
+    def update_cpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
         for name, param in named_parameters(self.args, self.model):
             if name not in params_dict:
                 params_dict[name] = torch.empty_like(param, device=torch.device("cpu"), pin_memory=True)
             params_dict[name].copy_(param.detach(), non_blocking=True)
         torch.cuda.synchronize()
 
     @torch.no_grad()
-    def update_gpu_params_dict(self, params_dict):
+    def update_gpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
         for name, param in named_parameters(self.args, self.model):
             assert name in params_dict
             param.copy_(params_dict[name], non_blocking=True)
         torch.cuda.synchronize()
 
     @timer
-    def sleep(self, tags):
+    def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
         assert self.args.offload
         assert "model" in tags
         if isinstance(tags, str):
             tags = (tags,)
 
         clear_memory()
-        print_memory(f"before offload model")
+        print_memory("before offload model")
         if hasattr(mpu, "destroy_process_groups"):
             mpu.destroy_process_groups()
 
```
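`update_cpu_params_dict` fills a `Dict[str, torch.Tensor]` of pinned (page-locked) CPU buffers; pinning is what makes the `non_blocking=True` device-to-host copies asynchronous, and the trailing `torch.cuda.synchronize()` fences them before the buffers are read. A self-contained sketch of that pattern (`snapshot_params` and its arguments are assumed names, not the repo's API):

```python
# Minimal sketch of the pinned-buffer snapshot pattern; names are assumed.
from typing import Dict, Iterable, Tuple

import torch


def snapshot_params(
    named_params: Iterable[Tuple[str, torch.Tensor]],
    params_dict: Dict[str, torch.Tensor],
) -> None:
    use_cuda = torch.cuda.is_available()
    for name, param in named_params:
        if name not in params_dict:
            # Allocate the page-locked CPU buffer once; reuse it on later calls.
            params_dict[name] = torch.empty_like(
                param, device=torch.device("cpu"), pin_memory=use_cuda
            )
        params_dict[name].copy_(param.detach(), non_blocking=True)
    if use_cuda:
        # non_blocking copies are async w.r.t. the GPU; fence before reading.
        torch.cuda.synchronize()


cpu_cache: Dict[str, torch.Tensor] = {}
snapshot_params([("w", torch.randn(2, 2))], cpu_cache)
print(sorted(cpu_cache))  # ['w']
```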
```diff
@@ -160,10 +170,10 @@ def sleep(self, tags):
             allocator = CuMemAllocator.get_instance()
             allocator.sleep(offload_tags=tags)
 
-        print_memory(f"after offload model")
+        print_memory("after offload model")
 
     @timer
-    def wake_up(self, tags):
+    def wake_up(self, tags: Union[str, Tuple[str, ...]]) -> None:
         assert self.args.offload
 
         # there are weird times when sglang is not offloaded immediately, so we wait here.
```
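`sleep` and `wake_up` both take `Union[str, Tuple[str, ...]]` and, as the body of `sleep` above shows, promote a bare string to a one-element tuple. That normalization in isolation (the function name here is invented):

```python
# The str-or-tuple normalization used by sleep/wake_up, in isolation.
from typing import Tuple, Union


def normalize_tags(tags: Union[str, Tuple[str, ...]]) -> Tuple[str, ...]:
    if isinstance(tags, str):
        tags = (tags,)
    return tags


assert normalize_tags("model") == ("model",)
assert normalize_tags(("model", "optimizer")) == ("model", "optimizer")
```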
```diff
@@ -189,7 +199,9 @@ def wake_up(self, tags):
         mpu.reload_process_groups()
         print_memory("after wake_up model")
 
-    def _get_rollout_data(self, rollout_data_ref):
+    def _get_rollout_data(
+        self, rollout_data_ref: Box
+    ) -> Dict[str, list[torch.Tensor] | list[int] | list[float] | list[str]]:
         # Fetch data through ray on CPU, not sure if this will be performance bottleneck.
         # Both first pp stage and the last pp stage will recieve the data.
         rollout_data = process_rollout_data(
```
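The `_get_rollout_data` annotation mixes `typing.Dict` with built-in generics (`list[...]`, PEP 585) and `|` unions (PEP 604). Inline annotations like this also work on older interpreters under `from __future__ import annotations`, but evaluating the type eagerly, e.g. as a module-level alias, needs Python 3.10+. A hypothetical alias (not part of this commit) that would avoid repeating the long annotation in `train_critic` and `train_actor` below:

```python
# Hypothetical alias, not from this commit; eager evaluation needs Python 3.10+.
from typing import Dict

import torch

RolloutData = Dict[str, list[torch.Tensor] | list[int] | list[float] | list[str]]


def describe(rollout_data: RolloutData) -> None:
    # Each key maps to a parallel per-sample list.
    for key, values in rollout_data.items():
        print(key, len(values))


describe({"rewards": [0.5, 1.0], "texts": ["a", "b"]})  # illustrative keys
```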
```diff
@@ -221,11 +233,11 @@ def _get_rollout_data(self, rollout_data_ref):
 
     def compute_log_prob(
         self,
-        model_tag,
-        data_iterator,
-        num_microbatches,
-        store_prefix="",
-    ):
+        model_tag: str,
+        data_iterator: list[DataIterator],
+        num_microbatches: list[int],
+        store_prefix: str = "",
+    ) -> Dict[str, list[torch.Tensor]]:
         self.update_gpu_params_dict(self.weights[model_tag])
 
         with timer(f"{store_prefix}log_probs"):
```
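`compute_log_prob` is annotated to return `Dict[str, list[torch.Tensor]]`; the `f"{store_prefix}log_probs"` timer label suggests the result keys are prefixed the same way. A toy illustration of that shape (all names assumed):

```python
# Toy illustration of a store_prefix-keyed Dict[str, list[torch.Tensor]].
from typing import Dict

import torch


def fake_log_probs(store_prefix: str = "") -> Dict[str, list[torch.Tensor]]:
    # One tensor per sample; lengths can differ across samples.
    return {f"{store_prefix}log_probs": [torch.zeros(3), torch.zeros(5)]}


print(list(fake_log_probs("ref_")))  # ['ref_log_probs']
```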
```diff
@@ -238,7 +250,7 @@ def compute_log_prob(
                 store_prefix=store_prefix,
             )
 
-    def train(self, rollout_id, rollout_data_ref):
+    def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         Timer().end("train_wait")
 
         if self.args.offload:
@@ -256,7 +268,9 @@ def train(self, rollout_id, rollout_data_ref):
         else:
             return self.train_actor(rollout_id, rollout_data)
 
-    def train_critic(self, rollout_id, rollout_data):
+    def train_critic(
+        self, rollout_id: int, rollout_data: Dict[str, list[torch.Tensor] | list[int] | list[float] | list[str]]
+    ) -> None:
         # Create data iterator for log_probs and train.
         data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data)
         rollout_data.update(
@@ -285,7 +299,9 @@ def train_critic(self, rollout_id, rollout_data):
         )
         Timer().start("train_wait")
 
-    def train_actor(self, rollout_id, rollout_data):
+    def train_actor(
+        self, rollout_id: int, rollout_data: Dict[str, list[torch.Tensor] | list[int] | list[float] | list[str]]
+    ) -> None:
         # Create data iterator for log_probs and train.
         data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data)
 
@@ -386,14 +402,14 @@ def train_actor(self, rollout_id, rollout_data):
         log_perf_data(rollout_id, self.args)
         Timer().start("train_wait")
 
-    def save_model(self, iteration):
+    def save_model(self, iteration: int) -> None:
         if self.args.debug_rollout_only:
             return
 
         save(iteration, self.model, self.optimizer, self.opt_param_scheduler)
 
     @timer
-    def update_weights(self):
+    def update_weights(self) -> None:
         if self.args.debug_train_only or self.args.debug_rollout_only:
             return
 
@@ -424,7 +440,7 @@ def update_weights(self):
         if self.args.offload and hasattr(mpu, "destroy_process_groups"):
             mpu.destroy_process_groups()
 
-    def load_other_checkpoint(self, model_tag, path):
+    def load_other_checkpoint(self, model_tag: str, path: str) -> None:
         old_args = self.args.load, self.args.no_load_optim, self.args.no_load_rng, self.args.finetune
         self.args.load = path
         self.args.no_load_optim = True
```
```diff
@@ -450,7 +466,12 @@ def load_other_checkpoint(self, model_tag, path):
         self.weights[model_tag] = {}
         self.update_cpu_params_dict(self.weights[model_tag])
 
-    def connect_actor_critic(self, actor_handle=None, master_address=None, master_port=None):
+    def connect_actor_critic(
+        self,
+        actor_handle: Optional[ActorHandle] = None,
+        master_address: Optional[str] = None,
+        master_port: Optional[int] = None,
+    ) -> None:
         if self.role == "actor":
             master_address = ray.util.get_node_ip_address()
             with socket.socket() as sock:
```
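`connect_actor_critic` now takes `Optional` rendezvous parameters because the actor side fills them in itself; the `with socket.socket() as sock:` line the hunk ends on is the start of the standard free-port lookup. That trick in isolation (an assumption about the lines the hunk truncates; the function name is invented):

```python
# The usual bind-to-port-0 trick for finding a free rendezvous port.
import socket


def find_free_port() -> int:
    with socket.socket() as sock:
        sock.bind(("", 0))  # port 0 lets the OS pick any free port
        return sock.getsockname()[1]


print(find_free_port())
```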
