Skip to content

Commit bbb67c0

Browse files
authored
Add --rollout-sample-filter-path (#961)
1 parent c752927 commit bbb67c0

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

slime/rollout/sglang_rollout.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,10 @@ async def generate_rollout_async(
410410

411411
# reset the global state to prevent effects on the next rollout or eval.
412412
state.reset()
413+
if args.rollout_sample_filter_path is not None:
414+
filter_func = load_function(args.rollout_sample_filter_path)
415+
filter_func(args, data)
416+
413417
return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples
414418

415419

slime/utils/arguments.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,18 @@ def add_rollout_buffer_arguments(parser):
10941094
default=128,
10951095
help="Multiplier for data padding size in data processing.",
10961096
)
1097+
parser.add_argument(
1098+
"--rollout-sample-filter-path",
1099+
type=str,
1100+
default=None,
1101+
help=(
1102+
"Path to the rollout sample filter function. "
1103+
"This function determines whether a sample will participate in loss calculation. "
1104+
"The function should take args and samples (list[Sample]) as input, and return None. "
1105+
"Please directly modify the remove_sample attribute of Sample. "
1106+
"Note: This attribute does not determine whether the sample participates in advantage normalization."
1107+
),
1108+
)
10971109
return parser
10981110

10991111
def add_custom_megatron_plugins_arguments(parser):

0 commit comments

Comments
 (0)