Skip to content

Commit 62d1698

Browse files
authored
refactor: extract duplicated checkpoint interval logic into reusable helper (THUDM#1027)
1 parent 0612652 commit 62d1698

File tree

3 files changed

+26
-17
lines changed

3 files changed

+26
-17
lines changed

slime/utils/misc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,23 @@ def get_free_port(start_port=10000, consecutive=1):
6565
while not all(is_port_available(port + i) for i in range(consecutive)):
6666
port += 1
6767
return port
68+
69+
70+
def should_run_periodic_action(
71+
rollout_id: int,
72+
interval: int | None,
73+
num_rollout_per_epoch: int | None = None,
74+
) -> bool:
75+
"""
76+
Return True when a periodic action (eval/save/checkpoint) should run.
77+
78+
Args:
79+
rollout_id: The current rollout index (0-based).
80+
interval: Desired cadence; disables checks when None.
81+
num_rollout_per_epoch: Optional epoch boundary to treat as a trigger.
82+
"""
83+
if interval is None:
84+
return False
85+
86+
step = rollout_id + 1
87+
return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0)

train.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
1010
from slime.utils.arguments import parse_args
1111
from slime.utils.logging_utils import configure_logger
12+
from slime.utils.misc import should_run_periodic_action
1213
from slime.utils.tracking_utils import init_tracking
1314

1415

@@ -61,7 +62,6 @@ def onload_rollout():
6162
# train loop.
6263
# note that for async training, one can change the position of the sync operation(ray.get).
6364
for rollout_id in range(args.start_rollout_id, args.num_rollout):
64-
# TODO extract the duplicated eval logic
6565
if args.eval_interval is not None and rollout_id == 0:
6666
ray.get(rollout_manager.eval.remote(rollout_id))
6767

@@ -78,10 +78,7 @@ def onload_rollout():
7878
else:
7979
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
8080

81-
if args.save_interval is not None and (
82-
(rollout_id + 1) % args.save_interval == 0
83-
or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
84-
):
81+
if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch):
8582
if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps):
8683
actor_model.save_model(rollout_id)
8784
if args.use_critic:
@@ -98,10 +95,7 @@ def onload_rollout():
9895
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
9996
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
10097

101-
if args.eval_interval is not None and (
102-
(rollout_id + 1) % args.eval_interval == 0
103-
or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
104-
):
98+
if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
10599
ray.get(rollout_manager.eval.remote(rollout_id))
106100

107101
ray.get(rollout_manager.dispose.remote())

train_async.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
44
from slime.utils.arguments import parse_args
55
from slime.utils.logging_utils import configure_logger
6+
from slime.utils.misc import should_run_periodic_action
67
from slime.utils.tracking_utils import init_tracking
78

89

@@ -46,10 +47,7 @@ def train(args):
4647
else:
4748
ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
4849

49-
if args.save_interval is not None and (
50-
(rollout_id + 1) % args.save_interval == 0
51-
or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
52-
):
50+
if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch):
5351
actor_model.save_model(rollout_id)
5452
if args.use_critic:
5553
critic_model.save_model(rollout_id)
@@ -62,10 +60,7 @@ def train(args):
6260
rollout_data_next_future = None
6361
actor_model.update_weights()
6462

65-
if args.eval_interval is not None and (
66-
(rollout_id + 1) % args.eval_interval == 0
67-
or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
68-
):
63+
if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
6964
ray.get(rollout_manager.eval.remote(rollout_id))
7065

7166
ray.get(rollout_manager.dispose.remote())

0 commit comments

Comments
 (0)