Skip to content

Commit a8ad2ec

Browse files
committed
update
1 parent 96a5302 commit a8ad2ec

File tree

1 file changed

+2
-13
lines changed

1 file changed

+2
-13
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -246,20 +246,9 @@ def setup_context_parallelism(self):
246246
print(f"Ring attention rank: {ring_attn_rank}")
247247

248248
def _update_cp_cu_seqlens(self, packed_batch):
249-
cu_seqlens = packed_batch["cu_seqlens"].to(device=torch.cuda.current_device(), dtype=torch.int32)
249+
cu_seqlens = packed_batch["cu_seqlens"]
250250

251-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
252-
logits = self.model(
253-
input_ids=packed_batch["tokens"].unsqueeze(0),
254-
attention_mask=None,
255-
position_ids=packed_batch["position_ids"].unsqueeze(0),
256-
).logits
257-
# Handle packed sequences
258-
log_probs = gather_log_probs_packed(logits, packed_batch["tokens"], packed_batch["cu_seqlens"])
259-
packed_batch["cur_log_probs"] = log_probs
260-
unpacked_batches = unpack_sequences(packed_batch)
261-
262-
# Sync ring flash attention parameters
251+
# Update the ring attention parameters
263252
update_ring_flash_attn_params(cu_seqlens, self.cp_group)
264253

265254

0 commit comments

Comments
 (0)