File tree Expand file tree Collapse file tree 2 files changed +14
-1
lines changed
slime/backends/fsdp_utils Expand file tree Collapse file tree 2 files changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -160,6 +160,10 @@ def compute_log_prob(
160160 rollout_data = {f"{ store_prefix } log_probs" : []}
161161 with timer (f"{ store_prefix } log_probs" ) and torch .no_grad ():
162162 for batch in packed_batches :
163+ # Update cu_seqlens for CP before forward pass
164+ if self .args .enable_cp :
165+ self ._update_cp_cu_seqlens (batch )
166+
163167 with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
164168 model_args = {
165169 "input_ids" : batch ["tokens" ].unsqueeze (0 ),
Original file line number Diff line number Diff line change @@ -93,6 +93,13 @@ FSDP_ARGS=(
9393 --update-weights-buffer-size $(( 512 * 1024 * 1024 )) # 512MB
9494)
9595
96+ # Context Parallelism Arguments
97+ # Uncomment to enable CP with varlen support and data packing
98+ CP_ARGS=(
99+ # --enable-cp # Enable Context Parallelism
100+ # --ring-flash-atten-type llama3 # Use llama3 ring attention implementation
101+ )
102+
96103# launch the master node of ray in container
97104ray start --head --node-ip-address 127.0.0.1 --num-gpus 2 --disable-usage-stats
98105
@@ -112,4 +119,6 @@ ray job submit --address="http://127.0.0.1:8265" \
112119 ${OPTIMIZER_ARGS[@]} \
113120 ${GRPO_ARGS[@]} \
114121 ${SGLANG_ARGS[@]} \
115- ${WANDB_ARGS[@]}
122+ ${WANDB_ARGS[@]} \
123+ ${FSDP_ARGS[@]} \
124+ ${CP_ARGS[@]}
You can’t perform that action at this time.
0 commit comments