Skip to content

Commit 88afbd4

Browse files
committed
Support varlen with CP and data packing for normal models
1 parent c7cb165 commit 88afbd4

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ def compute_log_prob(
160160
rollout_data = {f"{store_prefix}log_probs": []}
161161
with timer(f"{store_prefix}log_probs") and torch.no_grad():
162162
for batch in packed_batches:
163+
# Update cu_seqlens for CP before forward pass
164+
if self.args.enable_cp:
165+
self._update_cp_cu_seqlens(batch)
166+
163167
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
164168
model_args = {
165169
"input_ids": batch["tokens"].unsqueeze(0),

tests/test_qwen3-0.6B_fsdp_colocated_2xGPU.sh

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ FSDP_ARGS=(
9393
--update-weights-buffer-size $((512 * 1024 * 1024)) # 512MB
9494
)
9595

96+
# Context Parallelism Arguments
97+
# Uncomment to enable CP with varlen support and data packing
98+
CP_ARGS=(
99+
# --enable-cp # Enable Context Parallelism
100+
# --ring-flash-atten-type llama3 # Use llama3 ring attention implementation
101+
)
102+
96103
# launch the master node of ray in container
97104
ray start --head --node-ip-address 127.0.0.1 --num-gpus 2 --disable-usage-stats
98105

@@ -112,4 +119,6 @@ ray job submit --address="http://127.0.0.1:8265" \
112119
${OPTIMIZER_ARGS[@]} \
113120
${GRPO_ARGS[@]} \
114121
${SGLANG_ARGS[@]} \
115-
${WANDB_ARGS[@]}
122+
${WANDB_ARGS[@]} \
123+
${FSDP_ARGS[@]} \
124+
${CP_ARGS[@]}

0 commit comments

Comments
 (0)