Skip to content

Commit 4f69141

Browse files
fzyzcjyzhuzilin
andauthored
Tiny add response length metrics (#649)
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
1 parent 0b7b75a commit 4f69141

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

slime/ray/rollout.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from slime.utils.http_utils import find_available_port, get_host_info, init_http_client
1818
from slime.utils.iter_utils import group_by
1919
from slime.utils.metric_checker import MetricChecker
20-
from slime.utils.metric_utils import compute_pass_rate, dict_add_prefix
20+
from slime.utils.metric_utils import compute_pass_rate, compute_statistics, dict_add_prefix
2121
from slime.utils.misc import load_function
2222
from slime.utils.ray_utils import Box
2323
from slime.utils.types import Sample
@@ -492,6 +492,7 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_
492492
if args.rollout_num_gpus:
493493
log_dict["perf/tokens_per_gpu_per_sec"] = sum(response_lengths) / rollout_time / args.rollout_num_gpus
494494
log_dict["perf/longest_sample_tokens_per_sec"] = max(response_lengths) / rollout_time
495+
log_dict |= dict_add_prefix(compute_statistics(response_lengths), f"rollout/response_len/")
495496
log_dict |= _compute_zero_std_metrics(args, samples)
496497
log_dict |= _compute_spec_metrics(args, samples)
497498
log_dict |= dict_add_prefix(_compute_reward_cat_metrics(args, samples), f"rollout/")

slime/utils/metric_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ def estimator(n, c, k):
5353
return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples, num_correct)])
5454

5555

56+
def compute_statistics(values: List[float]) -> Dict[str, float]:
57+
values = np.array(values)
58+
return {
59+
"mean": np.mean(values).item(),
60+
"median": np.median(values).item(),
61+
}
62+
63+
5664
def compression_ratio(
5765
data: Union[str, bytes],
5866
*,

0 commit comments

Comments
 (0)