Skip to content

Commit 1ede369

Browse files
authored
[3/N] Enable B ruleset in Ruff (#993)
1 parent f829eb2 commit 1ede369

35 files changed

+133
-88
lines changed

examples/fully_async/fully_async_rollout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ async def continuous_worker_loop(self):
9090

9191
# Add completion callback
9292
def make_callback(gid):
93-
def task_done_callback(task):
94-
result = task.result()
93+
def task_done_callback(done_task):
94+
result = done_task.result()
9595
self.output_queue.put((gid, result))
9696

9797
return task_done_callback

examples/multi_agent/agent_system.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(self):
8585

8686
async def run(self, args, prompt, max_retries: int = 1, key: str = None) -> str:
8787
"""Runs the agent by sending a prompt to the LLM."""
88-
for i in range(max_retries):
88+
for _i in range(max_retries):
8989
try:
9090
response = await generate_response(args, prompt, key=key)
9191
return response
@@ -200,7 +200,7 @@ async def run_agent_system(args, sample):
200200
results = await asyncio.gather(*tasks, return_exceptions=True)
201201

202202
rewards = await batched_async_rm(args, args.results_dict["solver"])
203-
for sample, reward in zip(args.results_dict["solver"], rewards):
203+
for sample, reward in zip(args.results_dict["solver"], rewards, strict=False):
204204
sample.reward = reward
205205

206206
previous_solutions = [item for item in results if isinstance(item, str)]
@@ -223,12 +223,12 @@ def reward_adjustment(samples, reward_weight):
223223

224224
# 处理异常结果
225225
rewrited_solutions = []
226-
for i, result in enumerate(rewrited_solutions_raw):
226+
for _i, result in enumerate(rewrited_solutions_raw):
227227
if isinstance(result, str):
228228
rewrited_solutions.append(result)
229229

230230
rewards = await batched_async_rm(args, args.results_dict["rewriter"])
231-
for sample, reward in zip(args.results_dict["rewriter"], rewards):
231+
for sample, reward in zip(args.results_dict["rewriter"], rewards, strict=False):
232232
sample.reward = reward
233233

234234
if len(rewrited_solutions) == 0:

examples/on_policy_distillation/on_policy_distillation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ def post_process_rewards(args, samples: list[Sample], **kwargs):
3030
for reward in rewards
3131
]
3232
teacher_log_probs = [
33-
t_log_prob[-response_length:] for t_log_prob, response_length in zip(teacher_log_probs, response_lengths)
33+
t_log_prob[-response_length:]
34+
for t_log_prob, response_length in zip(teacher_log_probs, response_lengths, strict=False)
3435
]
3536

36-
for sample, t_log_probs in zip(samples, teacher_log_probs):
37+
for sample, t_log_probs in zip(samples, teacher_log_probs, strict=False):
3738
sample.teacher_log_probs = t_log_probs
3839

3940
return teacher_log_probs, teacher_log_probs

examples/retool/generate_with_retool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
try:
66
from jinja2 import Template
7-
except ImportError:
8-
raise ImportError("Jinja2 is required. Please install it with: pip install jinja2")
7+
except ImportError as e:
8+
raise ImportError("Jinja2 is required. Please install it with: pip install jinja2") from e
99

1010
from slime.rollout.sglang_rollout import GenerateState
1111
from slime.utils.http_utils import post
@@ -14,8 +14,8 @@
1414
# Import reward models
1515
try:
1616
from slime.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score
17-
except ImportError:
18-
raise ImportError("MathDapo is not installed")
17+
except ImportError as e:
18+
raise ImportError("MathDapo is not installed") from e
1919

2020
# Import tool sandbox functionality
2121
from tool_sandbox import SEMAPHORE, TOOL_CONFIGS, tool_registry

examples/search-r1/generate_with_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
157157
loss_mask = []
158158
rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None
159159

160-
for turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
160+
for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
161161
payload = {
162162
"text": prompt + response,
163163
"sampling_params": sampling_params,

examples/search-r1/qa_em_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def is_valid_sequence(text):
7878
state = "start" # start -> think -> search -> information -> think -> ... -> answer -> end
7979

8080
# 3. Check each part
81-
for i, part in enumerate(parts):
81+
for _i, part in enumerate(parts):
8282
# Skip empty parts
8383
if not part.strip():
8484
continue

examples/train_infer_mismatch_helper/mis.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def compute_mis_weights(
145145
len(train_log_probs) == len(rollout_log_probs) == len(loss_masks)
146146
), f"Input lists must have the same number of sequences: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}"
147147

148-
for i, (train, rollout, loss_mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)):
148+
for i, (train, rollout, loss_mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks, strict=False)):
149149
assert (
150150
train.shape == rollout.shape == loss_mask.shape
151151
), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, loss_mask: {loss_mask.shape}"
@@ -164,15 +164,19 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
164164
else:
165165
raise ValueError(f"Invalid level: {level}")
166166

167-
for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks):
167+
for train_log_prob, rollout_log_prob, loss_mask in zip(
168+
train_log_probs, rollout_log_probs, loss_masks, strict=False
169+
):
168170
add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics)
169171

170172
# only calculate mismatch metrics if TIS is not used
171173
if not args.use_tis:
172174
return None, loss_masks, metrics
173175

174176
# handle each sequence independently
175-
for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks):
177+
for train_log_prob, rollout_log_prob, loss_mask in zip(
178+
train_log_probs, rollout_log_probs, loss_masks, strict=False
179+
):
176180
loss_mask = loss_mask.float()
177181
raw_log_ratio_diff = train_log_prob - rollout_log_prob
178182
modified_mask = loss_mask.clone().float()
@@ -228,14 +232,14 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
228232
tis_level = args.tis_level if args.use_tis else "token"
229233
if tis_level == "token":
230234
# Token-level: normalize over all token weights
231-
total_weights_sum = sum(masked_sum(w, m) for w, m in zip(all_weights, loss_masks))
235+
total_weights_sum = sum(masked_sum(w, m) for w, m in zip(all_weights, loss_masks, strict=False))
232236
total_mask_count = sum(m.sum() for m in loss_masks)
233237
weights_mean = total_weights_sum / torch.clamp_min(total_mask_count, 1)
234238
elif tis_level == "sequence":
235239
# Sequence-level: normalize over sequence weights (one weight per sequence)
236240
# For each sequence, compute mean over valid tokens (they all have the same weight)
237241
# then average across sequences
238-
seq_weights_means = [masked_mean(w, m) for w, m in zip(all_weights, loss_masks)]
242+
seq_weights_means = [masked_mean(w, m) for w, m in zip(all_weights, loss_masks, strict=False)]
239243
weights_mean = sum(seq_weights_means) / len(seq_weights_means)
240244
else:
241245
raise ValueError(f"Unsupported tis_level: {tis_level}")
@@ -279,11 +283,15 @@ def compute_mis_weights_with_cp(
279283
# Gather cp slice from other cp ranks
280284
full_rollout_log_probs = [
281285
all_gather_with_cp(log_prob, total_length, response_length)
282-
for log_prob, total_length, response_length in zip(rollout_log_probs, total_lengths, response_lengths)
286+
for log_prob, total_length, response_length in zip(
287+
rollout_log_probs, total_lengths, response_lengths, strict=False
288+
)
283289
]
284290
full_old_log_probs = [
285291
all_gather_with_cp(old_log_prob, total_length, response_length)
286-
for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths)
292+
for old_log_prob, total_length, response_length in zip(
293+
train_log_probs, total_lengths, response_lengths, strict=False
294+
)
287295
]
288296

289297
# Main logic for is (decoupled)

pyproject.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,16 @@ src_paths = ["slime", "slime_plugins"]
2323
line_length = 119
2424

2525
[tool.ruff]
26-
line-length = 119
26+
line-length = 320 # TODO
27+
select = [
28+
"E", # Pycodestyle Errors (Structural/Fundamental Errors like bad indentation)
29+
"F", # Pyflakes (Core Errors: Unused imports, undefined names)
30+
"B", # Flake8-Bugbear (Logic Bugs: Variable shadowing, dangerous default arguments)
31+
# "UP", # pyupgrade (Modernization and compatibility issues) # TODO
32+
]
2733
ignore = [
28-
"E402",
34+
"E402", # module-import-not-at-top-of-file
35+
"E501", # Line too long # TODO handle it later
2936
]
3037

3138
[tool.pytest.ini_options]

slime/backends/fsdp_utils/actor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _fsdp2_load_full_state_dict(self, model, full_state, device_mesh, cpu_offloa
256256
set_model_state_dict(model, full_state, options=options)
257257

258258
# set_model_state_dict will not broadcast buffers, so we need to broadcast them manually.
259-
for name, buf in model.named_buffers():
259+
for _name, buf in model.named_buffers():
260260
dist.broadcast(buf, src=0)
261261

262262
if is_cpu_offload:
@@ -476,7 +476,7 @@ def _log_rollout_data(self, rollout_id: int, rollout_data, packed_batches):
476476
if metric_key not in packed_batches[0]:
477477
continue
478478
val = torch.tensor([0.0], device=torch.cuda.current_device())
479-
for mbs_id, batches in enumerate(packed_batches):
479+
for _mbs_id, batches in enumerate(packed_batches):
480480
unpacked_batches = unpack_sequences(batches)
481481
for unpacked_batch in unpacked_batches:
482482
if isinstance(unpacked_batch[metric_key], torch.Tensor):
@@ -598,11 +598,11 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
598598

599599
seq_kls = [
600600
((log_ratio_i * mask_i).sum() / mask_i.sum().clamp_min(1))
601-
for log_ratio_i, mask_i in zip(log_ratio_splits, loss_masks)
601+
for log_ratio_i, mask_i in zip(log_ratio_splits, loss_masks, strict=False)
602602
]
603603

604604
ppo_kl_list = []
605-
for seq_kl, length in zip(seq_kls, response_lengths):
605+
for seq_kl, length in zip(seq_kls, response_lengths, strict=False):
606606
ppo_kl_list.append(seq_kl.expand(length))
607607

608608
ppo_kl = torch.cat(ppo_kl_list)
@@ -976,7 +976,7 @@ def sum_of_sample_mean(x: torch.Tensor, response_lengths: list[int], loss_masks:
976976
return sum(
977977
[
978978
(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
979-
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks)
979+
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
980980
]
981981
)
982982

slime/backends/fsdp_utils/update_weight_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def update_bucket_weights(self, named_tensors) -> None:
128128

129129
# Create flattened bucket for each dtype group
130130
serialized_tensors = []
131-
for dtype, named_tensors in named_tensors_by_dtypes.items():
131+
for _dtype, named_tensors in named_tensors_by_dtypes.items():
132132
flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=named_tensors)
133133
metadata = flattened_tensor_bucket.get_metadata()
134134
flattened_tensor_data = {
@@ -241,7 +241,7 @@ def update_bucket_weights(self, named_tensors) -> None:
241241

242242
handles = []
243243
# Broadcast parameters one by one with memory management
244-
for name, param in named_tensors:
244+
for _name, param in named_tensors:
245245
torch.cuda.empty_cache()
246246
# Ensure tensor is contiguous and on the right device
247247
param_data = param.data.contiguous()

0 commit comments

Comments
 (0)