22
33import torch
44
5- from slime .backends .training_utils .parallel import ParallelState
6-
75# NOTE:
86# - `compute_mis_weights` is a lightweight, standalone function that is useful to unit-test on CPU.
97# - `compute_mis_weights_with_cp` depends on Megatron context-parallel utilities, which are heavy and may not be
@@ -318,7 +316,6 @@ def compute_mis_weights_with_cp(
318316 loss_masks : list [torch .Tensor ],
319317 total_lengths : list [int ],
320318 response_lengths : list [int ],
321- parallel_state : ParallelState ,
322319 ** kwargs : Any ,
323320) -> tuple [torch .Tensor , list [torch .Tensor ], dict [str , torch .Tensor ]]:
324321 """
@@ -335,17 +332,17 @@ def compute_mis_weights_with_cp(
335332 is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors.
336333 """
337334 # Lazy import to avoid importing Megatron dependencies when only `compute_mis_weights` is used.
338- from slime .backends .training_utils .cp_utils import all_gather_with_cp , slice_log_prob_with_cp
335+ from slime .backends .megatron_utils .cp_utils import all_gather_with_cp , slice_log_prob_with_cp
339336
340337 # Gather each sequence's CP slices from the other CP ranks to rebuild the full log-probs
341338 full_rollout_log_probs = [
342- all_gather_with_cp (log_prob , total_length , response_length , parallel_state )
339+ all_gather_with_cp (log_prob , total_length , response_length )
343340 for log_prob , total_length , response_length in zip (
344341 rollout_log_probs , total_lengths , response_lengths , strict = False
345342 )
346343 ]
347344 full_old_log_probs = [
348- all_gather_with_cp (old_log_prob , total_length , response_length , parallel_state )
345+ all_gather_with_cp (old_log_prob , total_length , response_length )
349346 for old_log_prob , total_length , response_length in zip (
350347 train_log_probs , total_lengths , response_lengths , strict = False
351348 )
@@ -365,7 +362,7 @@ def slice_cp_and_concat(
365362 ) -> torch .Tensor :
366363 values = [
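367364 # TODO: Consider renaming `slice_log_prob_with_cp`; it is applied here to generic per-sequence values, apparently not only log-probs (confirm).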
368- slice_log_prob_with_cp (values [i ], total_lengths [i ], response_lengths [i ], parallel_state )
365+ slice_log_prob_with_cp (values [i ], total_lengths [i ], response_lengths [i ])
369366 for i in range (len (values ))
370367 ]
371368 return torch .cat (values , dim = 0 )
0 commit comments