|
14 | 14 | from slime.backends.sglang_utils.sglang_engine import SGLangEngine |
15 | 15 | from slime.ray.rollout_data_source import RolloutDataSourceWithBuffer |
16 | 16 | from slime.utils.http_utils import find_available_port, get_host_info, init_http_client |
| 17 | +from slime.utils.metric_checker import MetricChecker |
17 | 18 | from slime.utils.misc import load_function |
18 | 19 | from slime.utils.ray_utils import Box |
19 | 20 | from slime.utils.types import Sample |
@@ -58,13 +59,19 @@ def __init__(self, args, pg, wandb_run_id): |
58 | 59 | self.rollout_engines = self.all_rollout_engines[:: self.nodes_per_engine] |
59 | 60 | self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() |
60 | 61 |
|
| 62 | + self._metric_checker = MetricChecker.maybe_create(args) |
| 63 | + |
61 | 64 | # fault tolerance |
62 | 65 | self._health_monitor_thread = None |
63 | 66 | self._health_monitor_stop_event = None |
64 | 67 | self._health_check_interval = args.rollout_health_check_interval |
65 | 68 | self._health_check_timeout = args.rollout_health_check_timeout |
66 | 69 | self._health_check_first_wait = args.rollout_health_check_first_wait |
67 | 70 |
|
| 71 | + def dispose(self): |
| 72 | + if self._metric_checker is not None: |
| 73 | + self._metric_checker.dispose() |
| 74 | + |
68 | 75 | def get_rollout_engines_and_lock(self): |
69 | 76 | return self.rollout_engines, self.rollout_engine_lock, self.num_new_engines |
70 | 77 |
|
@@ -93,7 +100,9 @@ def eval(self, rollout_id): |
93 | 100 | return |
94 | 101 | # TODO: add fault tolerance to eval |
95 | 102 | data = self.eval_generate_rollout(self.args, rollout_id, self.data_source, evaluation=True) |
96 | | - _log_eval_rollout_data(rollout_id, self.args, data) |
| 103 | + metrics = _log_eval_rollout_data(rollout_id, self.args, data) |
| 104 | + if self._metric_checker is not None: |
| 105 | + self._metric_checker.on_eval(metrics) |
97 | 106 |
|
98 | 107 | def save(self, rollout_id): |
99 | 108 | self.data_source.save(rollout_id) |
@@ -474,6 +483,8 @@ def _log_eval_rollout_data(rollout_id, args, data): |
474 | 483 | ), |
475 | 484 | ) |
476 | 485 |
|
| 486 | + return log_dict |
| 487 | + |
477 | 488 |
|
478 | 489 | def _log_rollout_data(rollout_id, args, samples, rollout_time): |
479 | 490 | if args.load_debug_rollout_data: |
|
0 commit comments