Skip to content

Commit b43d34d

Browse files
committed
fix pre-commit
1 parent 48bf8aa commit b43d34d

File tree

3 files changed

+26
-14
lines changed

3 files changed

+26
-14
lines changed

slime/backends/megatron_utils/data.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,15 @@ def log_rollout_data(rollout_id, args, rollout_data):
174174
log_dict = {}
175175
response_lengths = rollout_data["response_lengths"]
176176
for key, val in rollout_data.items():
177-
if key == "tokens" or key == "loss_masks" or key == "sample_indices"or key == "rollout_time" or key == "completion_tokens_stats" or key == "partial_samples" or key == "total_off_policy_tokens":
177+
if (
178+
key == "tokens"
179+
or key == "loss_masks"
180+
or key == "sample_indices"
181+
or key == "rollout_time"
182+
or key == "completion_tokens_stats"
183+
or key == "partial_samples"
184+
or key == "total_off_policy_tokens"
185+
):
178186
continue
179187
# Upload per sample mean for each rollout value
180188
# There are the following assumptions:
@@ -248,7 +256,9 @@ def log_partial_rollout_data(rollout_id, args, rollout_data):
248256
total_off_policy_tokens = rollout_data["total_off_policy_tokens"]
249257
if total_off_policy_tokens is not None:
250258
log_dict["total_off_policy_tokens"] = total_off_policy_tokens
251-
log_dict["off_policy_ratio"] = total_off_policy_tokens / (log_dict["total_tokens"] + total_off_policy_tokens)
259+
log_dict["off_policy_ratio"] = total_off_policy_tokens / (
260+
log_dict["total_tokens"] + total_off_policy_tokens
261+
)
252262

253263
response_lengths = rollout_data["response_lengths"]
254264

@@ -298,7 +308,7 @@ def log_partial_rollout_data(rollout_id, args, rollout_data):
298308
dst=mpu.get_data_parallel_src_rank(with_context_parallel=True),
299309
group=mpu.get_data_parallel_group(with_context_parallel=True),
300310
)
301-
311+
302312

303313
def log_multi_turn_data(rollout_id, args, rollout_data):
304314
if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage():

slime/rollout/sglang_example.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -288,19 +288,21 @@ async def generate_rollout_async(args, rollout_id: int, data_source) -> list[lis
288288
if state.completion_tokens_list:
289289
completion_tokens_array = np.array(state.completion_tokens_list)
290290
completion_tokens_stats = {
291-
'total_completion_tokens': np.sum(completion_tokens_array).item(),
292-
'completion_tokens_mean': np.mean(completion_tokens_array).item(),
293-
'completion_tokens_std': np.std(completion_tokens_array).item(),
294-
'completion_tokens_count': len(completion_tokens_array),
291+
"total_completion_tokens": np.sum(completion_tokens_array).item(),
292+
"completion_tokens_mean": np.mean(completion_tokens_array).item(),
293+
"completion_tokens_std": np.std(completion_tokens_array).item(),
294+
"completion_tokens_count": len(completion_tokens_array),
295295
}
296296

297297
if len(data) > 0:
298-
data[0][0].metadata.update({
299-
'rollout_time': rollout_time,
300-
'completion_tokens_stats': completion_tokens_stats,
301-
'partial_samples': state.partial_samples_count,
302-
'total_off_policy_tokens': state.total_off_policy_tokens,
303-
})
298+
data[0][0].metadata.update(
299+
{
300+
"rollout_time": rollout_time,
301+
"completion_tokens_stats": completion_tokens_stats,
302+
"partial_samples": state.partial_samples_count,
303+
"total_off_policy_tokens": state.total_off_policy_tokens,
304+
}
305+
)
304306
if completion_tokens_stats:
305307
print(f"[DEBUG] Rollout {rollout_id}: Completion tokens stats: {completion_tokens_stats}", flush=True)
306308

slime/utils/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Sample:
2020
reward: Optional[Union[float, dict[str, Any]]] = None
2121
loss_mask: Optional[list[int]] = None
2222
completion_tokens: Optional[int] = None
23-
23+
2424
class Status(Enum):
2525
PENDING = "pending"
2626
COMPLETED = "completed"

0 commit comments

Comments
 (0)