Skip to content

Commit 758ed24

Browse files
authored
[refactor] Add actor registry (THUDM#359)
1 parent 843c94e commit 758ed24

File tree

11 files changed

+96
-52
lines changed

11 files changed

+96
-52
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
1010

1111
import wandb
12+
from slime.ray.registry import get_actors
1213
from slime.ray.train_actor import TrainRayActor
1314
from slime.utils.data import process_rollout_data
1415
from slime.utils.distributed_utils import get_gloo_group
@@ -95,6 +96,7 @@ def init(self, args, role, wandb_run_id, with_ref: bool = False): # type: ignor
9596
self.update_cpu_params_dict(self.weights["actor"])
9697

9798
self.weight_updator = UpdateWeightFromTensor(self.args, self.model)
99+
self.connected = False
98100

99101
if self.args.offload:
100102
self.sleep(("model"))
@@ -122,15 +124,6 @@ def save_model(self, iteration):
122124

123125
raise NotImplementedError()
124126

125-
def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
126-
self.rollout_engines = rollout_engines
127-
128-
if self.args.debug_train_only or self.args.debug_rollout_only:
129-
return
130-
131-
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
132-
dist.barrier(group=get_gloo_group())
133-
134127
def compute_log_prob(
135128
self,
136129
model_tag,
@@ -392,6 +385,13 @@ def update_weights(self): # type: ignore[override]
392385
if self.args.debug_train_only or self.args.debug_rollout_only:
393386
return
394387

388+
if not self.connected:
389+
self.connected = True
390+
rollout_engines = get_actors("rollout")
391+
rollout_engine_lock = get_actors("rollout_lock", 0)
392+
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
393+
dist.barrier(group=get_gloo_group())
394+
395395
if self.args.offload:
396396
# TODO: don't wake up here
397397
self.wake_up(("model"))

slime/backends/megatron_utils/actor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from megatron.core import mpu
1515
from transformers import AutoConfig, AutoTokenizer
1616

17+
from slime.ray.registry import get_actors
1718
from slime.ray.train_actor import TrainRayActor
1819
from slime.utils.data import process_rollout_data
1920
from slime.utils.distributed_utils import get_gloo_group, init_process_group
@@ -88,6 +89,7 @@ def init(self, args, role, wandb_run_id, with_ref=False):
8889
quantization_config=getattr(self.hf_config, "quantization_config", None),
8990
vocab_size=self.tokenizer.vocab_size if self.args.vocab_size is None else self.args.vocab_size,
9091
)
92+
self.connected = False
9193

9294
# empty cache after initialization
9395
clear_memory()
@@ -384,15 +386,6 @@ def save_model(self, iteration):
384386

385387
save(iteration, self.model, self.optimizer, self.opt_param_scheduler)
386388

387-
def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
388-
self.rollout_engines = rollout_engines
389-
390-
if self.args.debug_train_only or self.args.debug_rollout_only:
391-
return
392-
393-
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
394-
dist.barrier(group=get_gloo_group())
395-
396389
@timer
397390
def update_weights(self):
398391
if self.args.debug_train_only or self.args.debug_rollout_only:
@@ -401,6 +394,13 @@ def update_weights(self):
401394
if self.args.offload and hasattr(mpu, "reload_process_groups"):
402395
mpu.reload_process_groups()
403396

397+
if not self.connected:
398+
self.connected = True
399+
rollout_engines = get_actors("rollout")
400+
rollout_engine_lock = get_actors("rollout_lock", 0)
401+
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
402+
dist.barrier(group=get_gloo_group())
403+
404404
with torch_memory_saver.disable() if self.args.offload and not torch.version.hip else nullcontext():
405405
print_memory("before update_weights")
406406
self.weight_updator.update_weights()

slime/backends/xtuner_utils/actor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from xtuner.v1.model import get_model_config_from_hf
1111

1212
import wandb
13+
from slime.ray.registry import get_actors
1314
from slime.ray.train_actor import TrainRayActor
1415
from slime.utils.data import process_rollout_data
1516
from slime.utils.distributed_utils import get_gloo_group
@@ -67,6 +68,7 @@ def init(self, args, role, wandb_run_id, with_ref: bool = False):
6768
self.sp_mesh = self.data_mesh["sp"]
6869

6970
self.weight_updator = UpdateWeightFromDistributed(args, self.model)
71+
self.connected = False
7072

7173
def sleep(self, tags):
7274
if not getattr(self.args, "offload", False):
@@ -87,15 +89,6 @@ def save_model(self, iteration):
8789
path = f"{self.args.save}/iter_{iteration:07}/hf"
8890
self.model.save_hf(path, save_dtype=torch.bfloat16)
8991

90-
def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
91-
self.rollout_engines = rollout_engines
92-
93-
if self.args.debug_train_only or self.args.debug_rollout_only:
94-
return
95-
96-
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
97-
dist.barrier(group=get_gloo_group())
98-
9992
def get_rollout_data(self, rollout_data_ref):
10093
dp_rank = dist.get_rank() // self.args.sp_size
10194
dp_size = dist.get_world_size() // self.args.sp_size
@@ -267,6 +260,13 @@ def update_weights(self): # type: ignore[override]
267260
if self.args.debug_train_only or self.args.debug_rollout_only:
268261
return
269262

263+
if not self.connected:
264+
self.connected = True
265+
rollout_engines = get_actors("rollout")
266+
rollout_engine_lock = get_actors("rollout_lock", 0)
267+
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
268+
dist.barrier(group=get_gloo_group())
269+
270270
if self.args.offload:
271271
# TODO: don't wake up here
272272
self.wake_up(("model"))

slime/ray/actor_group.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ def __init__(
3737
pg: tuple[PlacementGroup, list[int]],
3838
wandb_run_id: Optional[str] = None,
3939
num_gpus_per_actor: float = 1,
40+
role: str = "actor",
4041
) -> None:
4142
self.args = args
4243
self._num_nodes = num_nodes
4344
self._num_gpus_per_node = num_gpus_per_node
4445
self._wandb_run_id = wandb_run_id
46+
self.role = role
4547

4648
# Allocate the GPUs for actors w/o instantiating them
4749
self._allocate_gpus_for_actor(pg, num_gpus_per_actor, wandb_run_id=wandb_run_id)
@@ -113,18 +115,6 @@ def async_init(self, args, role, with_ref=False):
113115
self.args = args
114116
return [actor.init.remote(args, role, self._wandb_run_id, with_ref=with_ref) for actor in self._actor_handlers]
115117

116-
def async_init_weight_update_connections(self, rollout):
117-
"""
118-
Connect rollout engines and actors, e.g. initialize the process group between them
119-
to update weights after each training stage.
120-
"""
121-
self.rollout = rollout
122-
rollout_engines, rollout_engine_lock = ray.get(rollout.get_rollout_engines_and_lock.remote())
123-
return [
124-
actor.connect_rollout_engines.remote(rollout_engines, rollout_engine_lock)
125-
for actor in self._actor_handlers
126-
]
127-
128118
def async_train(self, rollout_id, rollout_data_ref):
129119
"""Do one rollout training"""
130120
return [actor.train.remote(rollout_id, rollout_data_ref) for actor in self._actor_handlers]

slime/ray/registry.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import ray
2+
3+
4+
@ray.remote
class Registry:
    """Ray actor holding handles in a two-level mapping: role -> key -> actor."""

    def __init__(self):
        # Outer dict is keyed by role; each value is a dict of key -> actor.
        self.actors = {}

    def set(self, role, key, actor):
        """Store ``actor`` under ``key`` for ``role``.

        The per-role mapping is created on first use; an existing entry for
        the same ``(role, key)`` is overwritten.
        """
        self.actors.setdefault(role, {})[key] = actor

    def get(self, role: str, key=None):
        """Return the entries stored for ``role``.

        Args:
            role: Role name used when the entries were stored.
            key: When ``None``, every entry of the role is returned as a list
                in insertion order; otherwise only the entry for that key.

        Raises:
            KeyError: If ``role`` has no entries, or ``key`` is not stored
                under ``role``.
        """
        by_key = self.actors[role]
        if key is not None:
            return by_key[key]
        return list(by_key.values())
19+
20+
21+
# Handle to the Registry actor, assigned by register_actor() when it has to
# create the actor itself; stays None in a process that never creates it.
# NOTE(review): presumably kept at module level so the named actor remains
# referenced for the lifetime of this process — confirm against Ray's
# actor-lifetime rules.
REGISTRY = None
22+
23+
24+
def register_actor(role, key, actor):
25+
try:
26+
registry = ray.get_actor("slime_actor_registry")
27+
except ValueError:
28+
global REGISTRY
29+
REGISTRY = Registry.options(name="slime_actor_registry").remote()
30+
registry = REGISTRY
31+
registry.set.remote(role, key, actor)
32+
33+
34+
def get_actors(role, key=None):
    """Fetch registered entries from the ``"slime_actor_registry"`` actor.

    Args:
        role: Role name the entries were registered under.
        key: ``None`` to get all entries of the role as a list, otherwise the
            key of the single entry to return.

    Returns:
        Whatever ``Registry.get(role, key)`` returns, resolved with ``ray.get``.
    """
    pending = ray.get_actor("slime_actor_registry").get.remote(role, key)
    return ray.get(pending)

slime/ray/rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def _create_rollout_engines(args, pg):
230230
"env_vars": {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST}
231231
| {
232232
"SGL_JIT_DEEPGEMM_PRECOMPILE": "false",
233+
"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
233234
}
234235
},
235236
).remote(args, rank=i)

slime/ray/train_actor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,6 @@ def sleep(self, tags):
8282
def wake_up(self, tags):
8383
raise NotImplementedError
8484

85-
@abc.abstractmethod
86-
def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
87-
raise NotImplementedError
88-
8985
@abc.abstractmethod
9086
def train(self, rollout_id, rollout_data_ref):
9187
raise NotImplementedError

slime/ray/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import ray
55
import torch
6+
from slime.ray.ray_actor import RayActor
7+
68

79
# Refer to
810
# https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96
@@ -34,7 +36,7 @@ def get_physical_gpu_id():
3436

3537

3638
@ray.remote
37-
class Lock:
39+
class Lock(RayActor):
3840
def __init__(self):
3941
self._locked = False # False: unlocked, True: locked
4042

slime/utils/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
# TODO: this is slow.
7474
if max_length is not None:
7575
if not multimodal_keys:
76-
if len(tokenizer(data[prompt_key])["input_ids"]) > max_length:
76+
if len(prompt) > max_length:
7777
continue
7878

7979
self.origin_samples.append(

train.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
33

44
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_group
5+
from slime.ray.registry import register_actor
56
from slime.utils.arguments import parse_args
67
from slime.utils.wandb_utils import init_wandb_primary
78

@@ -18,6 +19,17 @@ def train(args):
1819
# create the rollout manager, with sglang engines inside.
1920
rollout_manager = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
2021

22+
# TODO: extract this to single function
23+
rollout_engines, rollout_engine_lock = ray.get(rollout_manager.get_rollout_engines_and_lock.remote())
24+
for i, rollout_engine in enumerate(rollout_engines):
25+
register_actor("rollout", i, rollout_engine)
26+
register_actor("rollout_lock", 0, rollout_engine_lock)
27+
for i, actor in enumerate(actor_model._actor_handlers):
28+
register_actor("actor", i, actor)
29+
if args.use_critic:
30+
for i, critic in enumerate(critic_model._actor_handlers):
31+
register_actor("critic", i, critic)
32+
2133
# calculate num_rollout from num_epoch
2234
num_rollout_per_epoch = None
2335
if args.num_rollout is None:
@@ -32,17 +44,13 @@ def train(args):
3244
start_rollout_ids = ray.get(
3345
actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
3446
)
35-
3647
assert len(set(start_rollout_ids)) == 1
3748
if args.start_rollout_id is None:
3849
args.start_rollout_id = start_rollout_ids[0]
3950

4051
if args.rollout_global_dataset:
4152
ray.get(rollout_manager.load.remote(args.start_rollout_id - 1))
4253

43-
# initialize the connection for weight update during training
44-
ray.get(actor_model.async_init_weight_update_connections(rollout_manager))
45-
4654
if args.use_critic:
4755
ray.get(critic_init_handle)
4856
ray.get(actor_model.async_connect(critic_model))

0 commit comments

Comments
 (0)