@@ -19,14 +19,15 @@ def store_topk_data(samples):
1919
2020
2121def _get_topk_data (tokens ):
22- key = tuple (tokens [:20 ].tolist () if hasattr (tokens , ' tolist' ) else tokens [:20 ])
22+ key = tuple (tokens [:20 ].tolist () if hasattr (tokens , " tolist" ) else tokens [:20 ])
2323 return _topk_data_store .get (key )
2424
2525
2626def sampled_kl_loss (args , batch , logits , sum_of_sample_mean ):
2727 """Forward KL on teacher-sampled tokens (KD_TOP_K=0)."""
2828 _ , log_probs_result = get_log_probs_and_entropy (
29- logits , args = args ,
29+ logits ,
30+ args = args ,
3031 unconcat_tokens = batch ["unconcat_tokens" ],
3132 total_lengths = batch ["total_lengths" ],
3233 response_lengths = batch ["response_lengths" ],
@@ -37,7 +38,7 @@ def sampled_kl_loss(args, batch, logits, sum_of_sample_mean):
3738 entropy = log_probs_result .get ("entropy" , [])
3839
3940 kl_terms = []
40- for s_lp , t_lp in zip (student_lps , batch ["teacher_log_probs" ]):
41+ for s_lp , t_lp in zip (student_lps , batch ["teacher_log_probs" ], strict = False ):
4142 kl_terms .append (t_lp .to (s_lp ) - s_lp )
4243
4344 loss = sum_of_sample_mean (torch .cat (kl_terms ))
@@ -66,11 +67,16 @@ def _extract_response_log_probs(logits, unconcat_tokens, total_lengths, response
6667def topk_kl_loss (args , batch , logits , sum_of_sample_mean ):
6768 """Forward KL on teacher's top-K tokens with temperature scaling."""
6869 student_full_lps = _extract_response_log_probs (
69- logits , batch ["unconcat_tokens" ], batch ["total_lengths" ], batch ["response_lengths" ],
70+ logits ,
71+ batch ["unconcat_tokens" ],
72+ batch ["total_lengths" ],
73+ batch ["response_lengths" ],
7074 )
7175
7276 topk_data_list = [_get_topk_data (tokens ) for tokens in batch ["unconcat_tokens" ]]
73- valid_data = [(s_lp , data ) for s_lp , data in zip (student_full_lps , topk_data_list ) if data is not None ]
77+ valid_data = [
78+ (s_lp , data ) for s_lp , data in zip (student_full_lps , topk_data_list , strict = False ) if data is not None
79+ ]
7480
7581 if not valid_data :
7682 return sampled_kl_loss (args , batch , logits , sum_of_sample_mean )
@@ -84,7 +90,7 @@ def topk_kl_loss(args, batch, logits, sum_of_sample_mean):
8490 s_topk = s_lp .gather (1 , t_ids )
8591 t_renorm = torch .log_softmax (t_lps / tau , dim = - 1 )
8692 s_renorm = torch .log_softmax (s_topk / tau , dim = - 1 )
87- kl_terms .append ((tau ** 2 ) * (t_renorm .exp () * (t_renorm - s_renorm )).sum (dim = - 1 ))
93+ kl_terms .append ((tau ** 2 ) * (t_renorm .exp () * (t_renorm - s_renorm )).sum (dim = - 1 ))
8894
8995 loss = sum_of_sample_mean (torch .cat (kl_terms ))
9096 return loss , {"kd/loss" : loss .detach ()}
0 commit comments