Skip to content

Commit a9e5cb0

Browse files
committed
refactor code related to wandb run id
1 parent 9ebc230 commit a9e5cb0

File tree

3 files changed

+11
-15
lines changed

3 files changed

+11
-15
lines changed

slime/ray/placement_group.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,32 +109,30 @@ def create_placement_groups(args):
109109
}
110110

111111

112-
def allocate_train_group(args, num_nodes, num_gpus_per_node, pg, wandb_run_id):
112+
def allocate_train_group(args, num_nodes, num_gpus_per_node, pg):
113113
return RayTrainGroup(
114114
args=args,
115115
num_nodes=num_nodes,
116116
num_gpus_per_node=num_gpus_per_node,
117117
pg=pg,
118-
wandb_run_id=wandb_run_id,
118+
wandb_run_id=args.wandb_run_id,
119119
num_gpus_per_actor=0.4,
120120
)
121121

122122

123-
def create_training_models(args, pgs, rollout_manager, wandb_run_id):
123+
def create_training_models(args, pgs, rollout_manager):
124124
actor_model = allocate_train_group(
125125
args=args,
126126
num_nodes=args.actor_num_nodes,
127127
num_gpus_per_node=args.actor_num_gpus_per_node,
128128
pg=pgs["actor"],
129-
wandb_run_id=wandb_run_id,
130129
)
131130
if args.use_critic:
132131
critic_model = allocate_train_group(
133132
args=args,
134133
num_nodes=args.critic_num_nodes,
135134
num_gpus_per_node=args.critic_num_gpus_per_node,
136135
pg=pgs["critic"],
137-
wandb_run_id=wandb_run_id,
138136
)
139137
critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False)
140138
else:
@@ -159,11 +157,11 @@ def create_training_models(args, pgs, rollout_manager, wandb_run_id):
159157
return actor_model, critic_model
160158

161159

162-
def create_rollout_manager(args, pg, wandb_run_id):
160+
def create_rollout_manager(args, pg):
163161
rollout_manager = RolloutManager.options(
164162
num_cpus=1,
165163
num_gpus=0,
166-
).remote(args, pg, wandb_run_id=wandb_run_id)
164+
).remote(args, pg, wandb_run_id=args.wandb_run_id)
167165

168166
# calculate num_rollout from num_epoch
169167
num_rollout_per_epoch = None

train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@
1414
def train(args):
1515
# allocate the GPUs
1616
pgs = create_placement_groups(args)
17-
wandb_run_id = init_wandb_primary(args)
18-
args.wandb_run_id = wandb_run_id
17+
args.wandb_run_id = init_wandb_primary(args)
1918

2019
# create the rollout manager, with sglang engines inside.
2120
# need to initialize rollout manager first to calculate num_rollout
22-
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
21+
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"])
2322

2423
# create the actor and critic models
25-
actor_model, critic_model = create_training_models(args, pgs, rollout_manager, wandb_run_id=wandb_run_id)
24+
actor_model, critic_model = create_training_models(args, pgs, rollout_manager)
2625

2726
if args.offload_rollout:
2827
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))

train_async.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ def train(args):
99
assert not args.colocate, "Colocation is not supported for async training."
1010
# allocate the GPUs
1111
pgs = create_placement_groups(args)
12-
wandb_run_id = init_wandb_primary(args)
13-
args.wandb_run_id = wandb_run_id
12+
args.wandb_run_id = init_wandb_primary(args)
1413

1514
# create the rollout manager, with sglang engines inside.
1615
# need to initialize rollout manager first to calculate num_rollout
17-
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
16+
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"])
1817

1918
# create the actor and critic models
20-
actor_model, critic_model = create_training_models(args, pgs, rollout_manager, wandb_run_id=wandb_run_id)
19+
actor_model, critic_model = create_training_models(args, pgs, rollout_manager)
2120

2221
# always update weight first so that sglang has the loaded weights from training.
2322
actor_model.update_weights()

0 commit comments

Comments
 (0)