1- from typing import Union
1+ from typing import Callable , Union
22
33import torch
44import torch .distributed as dist
@@ -45,26 +45,26 @@ def get_logits_and_tokens_offset_with_cp(
4545
4646
4747def get_sum_of_sample_mean (
48- total_lengths ,
49- response_lengths ,
50- loss_masks ,
48+ total_lengths : list [ int ] ,
49+ response_lengths : list [ int ] ,
50+ loss_masks : list [ torch . Tensor ] ,
5151 calculate_per_token_loss : bool = False ,
52- ):
52+ ) -> Callable [[ torch . Tensor ], torch . Tensor ] :
5353 """
5454 Calculate correct sample mean for CP
5555 """
5656 cp_size = mpu .get_context_parallel_world_size ()
5757 if cp_size == 1 :
5858
59- def sum_of_sample_mean (x : torch .Tensor ):
59+ def sum_of_sample_mean (x : torch .Tensor ) -> torch . Tensor :
6060 return sum (
6161 [
6262 (x_i * loss_mask_i ).sum () / torch .clamp_min (loss_mask_i .sum (), 1 )
6363 for x_i , loss_mask_i in zip (x .split (response_lengths , dim = 0 ), loss_masks )
6464 ]
6565 )
6666
67- def sum_of_token (x : torch .Tensor ):
67+ def sum_of_token (x : torch .Tensor ) -> torch . Tensor :
6868 return sum (
6969 [(x_i * loss_mask_i ).sum () for x_i , loss_mask_i in zip (x .split (response_lengths , dim = 0 ), loss_masks )]
7070 )
@@ -82,7 +82,7 @@ def sum_of_token(x: torch.Tensor):
8282 chunked_loss_masks .append (torch .cat ([loss_mask_0 , loss_mask_1 ], dim = 0 ))
8383 cp_chunk_lengths .append (chunked_loss_masks [i ].size (0 ))
8484
85- def sum_of_sample_mean (x ) :
85+ def sum_of_sample_mean (x : torch . Tensor ) -> torch . Tensor :
8686 return sum (
8787 [
8888 (x_i * chunked_loss_mask ).sum () / torch .clamp_min (loss_mask .sum (), 1 )
@@ -92,7 +92,7 @@ def sum_of_sample_mean(x):
9292 ]
9393 )
9494
95- def sum_of_token (x : torch .Tensor ):
95+ def sum_of_token (x : torch .Tensor ) -> torch . Tensor :
9696 return sum (
9797 [
9898 (x_i * chunked_loss_mask ).sum ()
@@ -103,7 +103,7 @@ def sum_of_token(x: torch.Tensor):
103103 return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token
104104
105105
106- def all_gather_with_cp (tensor : torch .Tensor , total_length : int , response_length : int ):
106+ def all_gather_with_cp (tensor : torch .Tensor , total_length : int , response_length : int ) -> torch . Tensor :
107107 """
108108 Gather tensors across all ranks in the context parallel group.
109109 The first dimension of the output tensor will be the `response_length`.
@@ -122,7 +122,7 @@ def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length:
122122 chunk_1 = tensor [logits_offset [0 ][1 ] - logits_offset [0 ][0 ] :]
123123 assert chunk_1 .shape [0 ] == logits_offset [1 ][1 ] - logits_offset [1 ][0 ]
124124
125- def zero (len ) :
125+ def zero (len : int ) -> torch . Tensor :
126126 return torch .zeros (
127127 [len ] + list (tensor .shape [1 :]),
128128 dtype = tensor .dtype ,
@@ -155,7 +155,7 @@ def zero(len):
155155 return full_tensor
156156
157157
158- def slice_with_cp (tokens : torch .Tensor , pad_value ) :
158+ def slice_with_cp (tokens : torch .Tensor , pad_value : int ) -> torch . Tensor :
159159 cp_rank = mpu .get_context_parallel_rank ()
160160 cp_size = mpu .get_context_parallel_world_size ()
161161
@@ -172,7 +172,11 @@ def slice_with_cp(tokens: torch.Tensor, pad_value):
172172 return torch .cat ([tokens [start_1 :end_1 ], tokens [start_2 :end_2 ]])
173173
174174
175- def slice_log_prob_with_cp (log_prob : Union [list [float ], torch .Tensor ], total_length : int , response_length : int ):
175+ def slice_log_prob_with_cp (
176+ log_prob : Union [list [float ], torch .Tensor ],
177+ total_length : int ,
178+ response_length : int ,
179+ ) -> Union [list [float ], torch .Tensor ]:
176180 assert len (log_prob ) == response_length
177181
178182 cp_size = mpu .get_context_parallel_world_size ()
0 commit comments