File tree Expand file tree Collapse file tree 2 files changed +16
-0
lines changed
Expand file tree Collapse file tree 2 files changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -410,6 +410,10 @@ async def generate_rollout_async(
410410
411411 # reset the global state to prevent effects on the next rollout or eval.
412412 state .reset ()
413+ if args .rollout_sample_filter_path is not None :
414+ filter_func = load_function (args .rollout_sample_filter_path )
415+ filter_func (args , data )
416+
413417 return RolloutFnTrainOutput (samples = data , metrics = metric_gatherer .collect ()), aborted_samples
414418
415419
Original file line number Diff line number Diff line change @@ -1094,6 +1094,18 @@ def add_rollout_buffer_arguments(parser):
10941094 default = 128 ,
10951095 help = "Multiplier for data padding size in data processing." ,
10961096 )
1097+ parser .add_argument (
1098+ "--rollout-sample-filter-path" ,
1099+ type = str ,
1100+ default = None ,
1101+ help = (
1102+ "Path to the rollout sample filter function. "
1103+ "This function determines whether a sample will participate in loss calculation. "
1104+ "The function should take args and samples (list[Sample]) as input, and return None. "
1105+ "Please directly modify the remove_sample attribute of Sample. "
1106+ "Note: This attribute does not determine whether the sample participates in advantage normalization."
1107+ ),
1108+ )
10971109 return parser
10981110
10991111 def add_custom_megatron_plugins_arguments (parser ):
You can’t perform that action at this time.
0 commit comments