Skip to content

Commit 6a08ae9

Browse files
tourzhao authored and claude committed
style: fix ruff and black formatting issues
- Add strict=False to zip() calls (ruff B905) - Apply black formatting fixes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f603b24 commit 6a08ae9

File tree

3 files changed

+34
-14
lines changed

3 files changed

+34
-14
lines changed

examples/knowledge_distillation/kd_loss.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@ def store_topk_data(samples):
1919

2020

2121
def _get_topk_data(tokens):
22-
key = tuple(tokens[:20].tolist() if hasattr(tokens, 'tolist') else tokens[:20])
22+
key = tuple(tokens[:20].tolist() if hasattr(tokens, "tolist") else tokens[:20])
2323
return _topk_data_store.get(key)
2424

2525

2626
def sampled_kl_loss(args, batch, logits, sum_of_sample_mean):
2727
"""Forward KL on teacher-sampled tokens (KD_TOP_K=0)."""
2828
_, log_probs_result = get_log_probs_and_entropy(
29-
logits, args=args,
29+
logits,
30+
args=args,
3031
unconcat_tokens=batch["unconcat_tokens"],
3132
total_lengths=batch["total_lengths"],
3233
response_lengths=batch["response_lengths"],
@@ -37,7 +38,7 @@ def sampled_kl_loss(args, batch, logits, sum_of_sample_mean):
3738
entropy = log_probs_result.get("entropy", [])
3839

3940
kl_terms = []
40-
for s_lp, t_lp in zip(student_lps, batch["teacher_log_probs"]):
41+
for s_lp, t_lp in zip(student_lps, batch["teacher_log_probs"], strict=False):
4142
kl_terms.append(t_lp.to(s_lp) - s_lp)
4243

4344
loss = sum_of_sample_mean(torch.cat(kl_terms))
@@ -66,11 +67,16 @@ def _extract_response_log_probs(logits, unconcat_tokens, total_lengths, response
6667
def topk_kl_loss(args, batch, logits, sum_of_sample_mean):
6768
"""Forward KL on teacher's top-K tokens with temperature scaling."""
6869
student_full_lps = _extract_response_log_probs(
69-
logits, batch["unconcat_tokens"], batch["total_lengths"], batch["response_lengths"],
70+
logits,
71+
batch["unconcat_tokens"],
72+
batch["total_lengths"],
73+
batch["response_lengths"],
7074
)
7175

7276
topk_data_list = [_get_topk_data(tokens) for tokens in batch["unconcat_tokens"]]
73-
valid_data = [(s_lp, data) for s_lp, data in zip(student_full_lps, topk_data_list) if data is not None]
77+
valid_data = [
78+
(s_lp, data) for s_lp, data in zip(student_full_lps, topk_data_list, strict=False) if data is not None
79+
]
7480

7581
if not valid_data:
7682
return sampled_kl_loss(args, batch, logits, sum_of_sample_mean)
@@ -84,7 +90,7 @@ def topk_kl_loss(args, batch, logits, sum_of_sample_mean):
8490
s_topk = s_lp.gather(1, t_ids)
8591
t_renorm = torch.log_softmax(t_lps / tau, dim=-1)
8692
s_renorm = torch.log_softmax(s_topk / tau, dim=-1)
87-
kl_terms.append((tau ** 2) * (t_renorm.exp() * (t_renorm - s_renorm)).sum(dim=-1))
93+
kl_terms.append((tau**2) * (t_renorm.exp() * (t_renorm - s_renorm)).sum(dim=-1))
8894

8995
loss = sum_of_sample_mean(torch.cat(kl_terms))
9096
return loss, {"kd/loss": loss.detach()}

examples/knowledge_distillation/knowledge_distillation.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,19 @@ async def _generate_rollout_async(args, data_source):
110110
semaphore = asyncio.Semaphore(max(getattr(args, "sglang_server_concurrency", 64), 1))
111111

112112
async with aiohttp.ClientSession() as session:
113-
generated_groups = await asyncio.gather(*(
114-
asyncio.gather(*(_generate_sample(args, s, sampling_params, tokenizer, session, semaphore) for s in group))
115-
for group in samples
116-
))
113+
generated_groups = await asyncio.gather(
114+
*(
115+
asyncio.gather(
116+
*(_generate_sample(args, s, sampling_params, tokenizer, session, semaphore) for s in group)
117+
)
118+
for group in samples
119+
)
120+
)
117121

118122
first = generated_groups[0][0]
119-
logger.info(f"KD rollout: prompt={first.prompt[:80]!r}, response={first.response[:80]!r}, len={first.response_length}")
123+
logger.info(
124+
f"KD rollout: prompt={first.prompt[:80]!r}, response={first.response[:80]!r}, len={first.response_length}"
125+
)
120126

121127
token_count = sum(s.response_length for g in generated_groups for s in g)
122128
return RolloutFnTrainOutput(samples=generated_groups, metrics={"kd/token_count": token_count})
@@ -132,6 +138,7 @@ def generate_rollout(args, rollout_id, data_source, evaluation=False):
132138
# Store top-k data for loss function access
133139
if KD_TOP_K > 0:
134140
from examples.knowledge_distillation.kd_loss import store_topk_data
141+
135142
store_topk_data(result.samples)
136143

137144
if KD_SAVE_PATH:

examples/knowledge_distillation/offline_kd.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@ def _load_from_jsonl(load_path, rollout_id, batch_size, num_rollouts_per_prompt)
3939
)
4040

4141
if "teacher_top_k_ids" in record and "teacher_top_k_logprobs" in record:
42-
assert len(record["teacher_top_k_ids"]) == resp_len, f"Sample {len(samples)}: top_k_ids length mismatch"
43-
assert len(record["teacher_top_k_logprobs"]) == resp_len, f"Sample {len(samples)}: top_k_logprobs length mismatch"
42+
assert (
43+
len(record["teacher_top_k_ids"]) == resp_len
44+
), f"Sample {len(samples)}: top_k_ids length mismatch"
45+
assert (
46+
len(record["teacher_top_k_logprobs"]) == resp_len
47+
), f"Sample {len(samples)}: top_k_logprobs length mismatch"
4448
sample.train_metadata = {
4549
"teacher_top_k_ids": record["teacher_top_k_ids"],
4650
"teacher_top_k_logprobs": record["teacher_top_k_logprobs"],
@@ -80,10 +84,13 @@ def generate_rollout(args, rollout_id, data_source, evaluation=False):
8084
# Store top-k data for loss function access
8185
if KD_TOP_K > 0:
8286
from examples.knowledge_distillation.kd_loss import store_topk_data
87+
8388
store_topk_data(samples)
8489

8590
first = samples[0][0]
86-
logger.info(f"Offline KD: prompt={first.prompt[:80]!r}, response={first.response[:80]!r}, len={first.response_length}")
91+
logger.info(
92+
f"Offline KD: prompt={first.prompt[:80]!r}, response={first.response[:80]!r}, len={first.response_length}"
93+
)
8794

8895
token_count = sum(s.response_length for g in samples for s in g)
8996
return RolloutFnTrainOutput(samples=samples, metrics={"kd/token_count": token_count})

0 commit comments

Comments (0)