Skip to content

Commit d58effa

Browse files
authored
Allow checking accuracy correctness programmatically (THUDM#453)
1 parent 492215c commit d58effa

File tree

6 files changed

+57
-2
lines changed

6 files changed

+57
-2
lines changed

slime/backends/megatron_utils/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def train(rollout_id, model, optimizer, opt_param_scheduler, data_iterator, num_
475475
tb = _TensorboardAdapter(args)
476476
tb.log(data=log_dict, step=accumulated_step_id)
477477

478-
if args.ci_test:
478+
if args.ci_test and not args.ci_disable_kl_checker:
479479
if step_id == 0 and "train/ppo_kl" in log_dict and "train/pg_clipfrac" in log_dict:
480480
assert log_dict["train/ppo_kl"] == 0.0 and log_dict["train/pg_clipfrac"] == 0.0
481481
if accumulated_step_id == 0 and "train/kl_loss" in log_dict:

slime/ray/rollout.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from slime.backends.sglang_utils.sglang_engine import SGLangEngine
1515
from slime.ray.rollout_data_source import RolloutDataSourceWithBuffer
1616
from slime.utils.http_utils import find_available_port, get_host_info, init_http_client
17+
from slime.utils.metric_checker import MetricChecker
1718
from slime.utils.misc import load_function
1819
from slime.utils.ray_utils import Box
1920
from slime.utils.types import Sample
@@ -58,13 +59,19 @@ def __init__(self, args, pg, wandb_run_id):
5859
self.rollout_engines = self.all_rollout_engines[:: self.nodes_per_engine]
5960
self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote()
6061

62+
self._metric_checker = MetricChecker.maybe_create(args)
63+
6164
# fault tolerance
6265
self._health_monitor_thread = None
6366
self._health_monitor_stop_event = None
6467
self._health_check_interval = args.rollout_health_check_interval
6568
self._health_check_timeout = args.rollout_health_check_timeout
6669
self._health_check_first_wait = args.rollout_health_check_first_wait
6770

71+
def dispose(self):
72+
if self._metric_checker is not None:
73+
self._metric_checker.dispose()
74+
6875
def get_rollout_engines_and_lock(self):
6976
return self.rollout_engines, self.rollout_engine_lock, self.num_new_engines
7077

@@ -93,7 +100,9 @@ def eval(self, rollout_id):
93100
return
94101
# TODO: add fault tolerance to eval
95102
data = self.eval_generate_rollout(self.args, rollout_id, self.data_source, evaluation=True)
96-
_log_eval_rollout_data(rollout_id, self.args, data)
103+
metrics = _log_eval_rollout_data(rollout_id, self.args, data)
104+
if self._metric_checker is not None:
105+
self._metric_checker.on_eval(metrics)
97106

98107
def save(self, rollout_id):
99108
self.data_source.save(rollout_id)
@@ -474,6 +483,8 @@ def _log_eval_rollout_data(rollout_id, args, data):
474483
),
475484
)
476485

486+
return log_dict
487+
477488

478489
def _log_rollout_data(rollout_id, args, samples, rollout_time):
479490
if args.load_debug_rollout_data:

slime/utils/arguments.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,20 @@ def add_ci_arguments(parser):
933933
"--ci-test",
934934
action="store_true",
935935
)
936+
parser.add_argument(
937+
"--ci-disable-kl-checker",
938+
action="store_true",
939+
)
940+
parser.add_argument(
941+
"--ci-metric-checker-key",
942+
type=str,
943+
default=None,
944+
)
945+
parser.add_argument(
946+
"--ci-metric-checker-threshold",
947+
type=float,
948+
default=None,
949+
)
936950
return parser
937951

938952
# Add custom arguments in front to prevent overwriting some slime arguments.

slime/utils/metric_checker.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Dict
2+
3+
4+
class MetricChecker:
    """CI helper that verifies an eval metric reached a threshold at least once.

    The checker reads three attributes from ``args``: ``ci_test``,
    ``ci_metric_checker_key`` (name of the metric to look up in the eval
    metrics dict) and ``ci_metric_checker_threshold`` (minimum accepted value).
    Call :meth:`on_eval` after every evaluation and :meth:`dispose` once at the
    end of the run; ``dispose`` fails unless at least one evaluation met the
    threshold.
    """

    @staticmethod
    def maybe_create(args):
        """Return a ``MetricChecker`` if CI metric checking is requested, else ``None``.

        A checker is created only when ``args.ci_test`` is truthy and
        ``args.ci_metric_checker_key`` is not ``None``.

        Raises:
            ValueError: if a key is configured without a threshold.
        """
        if args.ci_test and (args.ci_metric_checker_key is not None):
            return MetricChecker(args)
        return None

    def __init__(self, args):
        """Store ``args`` and reset the pass/fail state.

        Raises:
            ValueError: if ``args.ci_metric_checker_threshold`` is ``None``.
        """
        # Fail fast with a clear message: without this, a missing threshold only
        # surfaces later as an opaque TypeError from the ``>=`` comparison.
        if args.ci_metric_checker_threshold is None:
            raise ValueError(
                "[MetricChecker] --ci-metric-checker-threshold must be set "
                "when --ci-metric-checker-key is given"
            )
        self.args = args
        # True once any evaluation has met the threshold.
        self._exists_check_success = False
        # Number of evaluations observed; only used for the failure message.
        self._num_evals = 0

    def on_eval(self, metrics: Dict[str, float]):
        """Record whether the configured metric meets the threshold for this eval.

        Args:
            metrics: mapping from metric name to value for one evaluation.

        Raises:
            AssertionError: if the configured key is absent from ``metrics``
                (or maps to ``None``).
        """
        actual_value = metrics.get(self.args.ci_metric_checker_key)
        # Raise explicitly rather than ``assert`` so the check survives ``python -O``.
        if actual_value is None:
            raise AssertionError(f"{metrics=} {self.args.ci_metric_checker_key=}")

        self._num_evals += 1
        check_success = actual_value >= self.args.ci_metric_checker_threshold
        print(f"[MetricChecker] {check_success=} {actual_value=} {self.args.ci_metric_checker_threshold=}")

        # A single passing evaluation is enough; later failures do not reset it.
        self._exists_check_success |= check_success

    def dispose(self):
        """Fail the run unless at least one evaluation met the threshold.

        Raises:
            AssertionError: if no call to :meth:`on_eval` observed a value
                greater than or equal to the threshold.
        """
        # Explicit raise (not ``assert``) so the CI gate cannot be optimized away.
        if not self._exists_check_success:
            raise AssertionError(
                "[MetricChecker] accuracy check failed: "
                f"{self.args.ci_metric_checker_key!r} never reached "
                f"{self.args.ci_metric_checker_threshold} "
                f"in {self._num_evals} eval(s)"
            )
        print("[MetricChecker] pass dispose check", flush=True)

train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def train(args):
8484
):
8585
ray.get(rollout_manager.eval.remote(rollout_id))
8686

87+
ray.get(rollout_manager.dispose.remote())
88+
8789

8890
if __name__ == "__main__":
8991
args = parse_args()

train_async.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def train(args):
6363
):
6464
ray.get(rollout_manager.eval.remote(rollout_id))
6565

66+
ray.get(rollout_manager.dispose.remote())
67+
6668

6769
if __name__ == "__main__":
6870
args = parse_args()

0 commit comments

Comments
 (0)