Skip to content

Commit f32b57f

Browse files
Williamren97 authored and PopSoda2002 committed
Support varlen with CP and datapacking for normal models
1 parent e08791a commit f32b57f

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,9 @@ def compute_log_prob(
260260
for batch in self.prof.iterate_train_log_probs(
261261
tqdm(packed_batches, desc=f"{store_prefix}log_probs", disable=dist.get_rank() != 0)
262262
):
263+
# Update cu_seqlens for CP before forward pass
264+
if self.args.enable_cp:
265+
self._update_cp_cu_seqlens(batch)
263266
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
264267
model_args = {
265268
"input_ids": batch["tokens"].unsqueeze(0),

tests/test_qwen3-0.6B_fsdp_colocated_2xGPU.sh

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,13 @@ FSDP_ARGS=(
9696
--update-weights-buffer-size $((512 * 1024 * 1024)) # 512MB
9797
)
9898

99+
# Context Parallelism Arguments
100+
# Uncomment to enable CP with varlen support and data packing
101+
CP_ARGS=(
102+
# --enable-cp # Enable Context Parallelism
103+
# --ring-flash-atten-type llama3 # Use llama3 ring attention implementation
104+
)
105+
99106
# launch the master node of ray in container
100107
ray start --head --node-ip-address 127.0.0.1 --num-gpus 2 --disable-usage-stats
101108

@@ -115,4 +122,6 @@ ray job submit --address="http://127.0.0.1:8265" \
115122
${OPTIMIZER_ARGS[@]} \
116123
${GRPO_ARGS[@]} \
117124
${SGLANG_ARGS[@]} \
118-
${WANDB_ARGS[@]}
125+
${WANDB_ARGS[@]} \
126+
${FSDP_ARGS[@]} \
127+
${CP_ARGS[@]}

0 commit comments

Comments
 (0)