Skip to content

Commit b372afe

Browse files
committed
fix precommit
1 parent caaee34 commit b372afe

File tree

4 files changed

+29
-17
lines changed

4 files changed

+29
-17
lines changed

slime/backends/megatron_utils/cp_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def sum_of_token(x: torch.Tensor):
9898

9999
return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token
100100

101+
101102
def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: int):
102103
"""
103104
Gather tensors across all ranks in the context parallel group.
@@ -140,4 +141,4 @@ def slice_with_cp(tokens: torch.Tensor, pad_value):
140141
# get 2 chunk for thd cp
141142
start_1, end_1 = chunk_size * cp_rank, chunk_size * (cp_rank + 1)
142143
start_2, end_2 = chunk_size * (2 * cp_size - cp_rank - 1), chunk_size * (2 * cp_size - cp_rank)
143-
return torch.cat([tokens[start_1:end_1], tokens[start_2:end_2]])
144+
return torch.cat([tokens[start_1:end_1], tokens[start_2:end_2]])

slime/backends/megatron_utils/data.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,15 @@ def log_rollout_data(rollout_id, args, rollout_data):
161161
log_dict = {}
162162
response_lengths = rollout_data["response_lengths"]
163163
for key, val in rollout_data.items():
164-
if key == "tokens" or key == "loss_masks" or key == "sample_indices"or key == "rollout_time" or key == "completion_tokens_stats" or key == "partial_samples" or key == "total_off_policy_tokens":
164+
if (
165+
key == "tokens"
166+
or key == "loss_masks"
167+
or key == "sample_indices"
168+
or key == "rollout_time"
169+
or key == "completion_tokens_stats"
170+
or key == "partial_samples"
171+
or key == "total_off_policy_tokens"
172+
):
165173
continue
166174
# Upload per sample mean for each rollout value
167175
# There are the following assumptions:
@@ -235,7 +243,9 @@ def log_partial_rollout_data(rollout_id, args, rollout_data):
235243
total_off_policy_tokens = rollout_data["total_off_policy_tokens"]
236244
if total_off_policy_tokens is not None:
237245
log_dict["total_off_policy_tokens"] = total_off_policy_tokens
238-
log_dict["off_policy_ratio"] = total_off_policy_tokens / (log_dict["total_tokens"] + total_off_policy_tokens)
246+
log_dict["off_policy_ratio"] = total_off_policy_tokens / (
247+
log_dict["total_tokens"] + total_off_policy_tokens
248+
)
239249

240250
response_lengths = rollout_data["response_lengths"]
241251

@@ -285,7 +295,6 @@ def log_partial_rollout_data(rollout_id, args, rollout_data):
285295
dst=mpu.get_data_parallel_src_rank(with_context_parallel=True),
286296
group=mpu.get_data_parallel_group(with_context_parallel=True),
287297
)
288-
289298

290299

291300
def log_multi_turn_data(rollout_id, args, rollout_data):

slime/backends/utils/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,4 @@ def get_partition(val):
167167
if "partial_samples" in data:
168168
rollout_data["partial_samples"] = data["partial_samples"]
169169
if "total_off_policy_tokens" in data:
170-
rollout_data["total_off_policy_tokens"] = data["total_off_policy_tokens"]
170+
rollout_data["total_off_policy_tokens"] = data["total_off_policy_tokens"]

slime/rollout/sglang_example.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ async def abort(args, rollout_id: int, data_buffer):
184184
print(f"Abort request for {url}", flush=True)
185185
# await post(f"{url}/abort_request", {"abort_all": True}, use_http2=False)
186186
# based on https://github.com/THUDM/slime/pull/63/files
187-
await post(f"{url}/abort_request", {"rid":"", "abort_all": True}, use_http2=False)
187+
await post(f"{url}/abort_request", {"rid": "", "abort_all": True}, use_http2=False)
188188

189189
# make sure all the pending tasks are finished
190190
count = 0
@@ -281,26 +281,28 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis
281281

282282
assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
283283
data = sorted(data, key=lambda group: group[0].index)
284-
284+
285285
rollout_time = time.time() - state.rollout_start_time
286286

287287
completion_tokens_stats = {}
288288
if state.completion_tokens_list:
289289
completion_tokens_array = np.array(state.completion_tokens_list)
290290
completion_tokens_stats = {
291-
'total_completion_tokens': np.sum(completion_tokens_array).item(),
292-
'completion_tokens_mean': np.mean(completion_tokens_array).item(),
293-
'completion_tokens_std': np.std(completion_tokens_array).item(),
294-
'completion_tokens_count': len(completion_tokens_array),
291+
"total_completion_tokens": np.sum(completion_tokens_array).item(),
292+
"completion_tokens_mean": np.mean(completion_tokens_array).item(),
293+
"completion_tokens_std": np.std(completion_tokens_array).item(),
294+
"completion_tokens_count": len(completion_tokens_array),
295295
}
296296

297297
if len(data) > 0:
298-
data[0][0].metadata.update({
299-
'rollout_time': rollout_time,
300-
'completion_tokens_stats': completion_tokens_stats,
301-
'partial_samples': state.partial_samples_count,
302-
'total_off_policy_tokens': state.total_off_policy_tokens,
303-
})
298+
data[0][0].metadata.update(
299+
{
300+
"rollout_time": rollout_time,
301+
"completion_tokens_stats": completion_tokens_stats,
302+
"partial_samples": state.partial_samples_count,
303+
"total_off_policy_tokens": state.total_off_policy_tokens,
304+
}
305+
)
304306
if completion_tokens_stats:
305307
print(f"[DEBUG] Rollout {rollout_id}: Completion tokens stats: {completion_tokens_stats}", flush=True)
306308

0 commit comments

Comments
 (0)