Skip to content

Commit 491f252

Browse files
authored
[Feature] Tiny fix for wandb run id (THUDM#730)
1 parent dea6ec7 commit 491f252

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

slime/ray/placement_group.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,32 +109,30 @@ def create_placement_groups(args):
109109
}
110110

111111

112-
def allocate_train_group(args, num_nodes, num_gpus_per_node, pg, wandb_run_id):
112+
def allocate_train_group(args, num_nodes, num_gpus_per_node, pg):
113113
return RayTrainGroup(
114114
args=args,
115115
num_nodes=num_nodes,
116116
num_gpus_per_node=num_gpus_per_node,
117117
pg=pg,
118-
wandb_run_id=wandb_run_id,
118+
wandb_run_id=args.wandb_run_id,
119119
num_gpus_per_actor=0.4,
120120
)
121121

122122

123-
def create_training_models(args, pgs, rollout_manager, wandb_run_id):
123+
def create_training_models(args, pgs, rollout_manager):
124124
actor_model = allocate_train_group(
125125
args=args,
126126
num_nodes=args.actor_num_nodes,
127127
num_gpus_per_node=args.actor_num_gpus_per_node,
128128
pg=pgs["actor"],
129-
wandb_run_id=wandb_run_id,
130129
)
131130
if args.use_critic:
132131
critic_model = allocate_train_group(
133132
args=args,
134133
num_nodes=args.critic_num_nodes,
135134
num_gpus_per_node=args.critic_num_gpus_per_node,
136135
pg=pgs["critic"],
137-
wandb_run_id=wandb_run_id,
138136
)
139137
critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False)
140138
else:
@@ -159,11 +157,11 @@ def create_training_models(args, pgs, rollout_manager, wandb_run_id):
159157
return actor_model, critic_model
160158

161159

162-
def create_rollout_manager(args, pg, wandb_run_id):
160+
def create_rollout_manager(args, pg):
163161
rollout_manager = RolloutManager.options(
164162
num_cpus=1,
165163
num_gpus=0,
166-
).remote(args, pg, wandb_run_id=wandb_run_id)
164+
).remote(args, pg, wandb_run_id=args.wandb_run_id)
167165

168166
# calculate num_rollout from num_epoch
169167
num_rollout_per_epoch = None

slime/utils/wandb_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def _is_offline_mode(args) -> bool:
1818

1919
def init_wandb_primary(args):
2020
if not args.use_wandb:
21-
return None
21+
args.wandb_run_id = None
22+
return
2223

2324
# Set W&B mode if specified (overrides WANDB_MODE env var)
2425
if args.wandb_mode:
@@ -71,7 +72,8 @@ def init_wandb_primary(args):
7172

7273
_init_wandb_common()
7374

74-
return wandb.run.id
75+
# Set wandb_run_id in args for easy access throughout the training process
76+
args.wandb_run_id = wandb.run.id
7577

7678

7779
def _compute_config_for_logging(args):

0 commit comments

Comments
 (0)