
[FSDP] Support Context parallelism for FSDP using ring-flash-attn#467

Merged
zhuzilin merged 18 commits into THUDM:main from PopSoda2002:feat/support_normal_cp
Nov 16, 2025
Conversation

PopSoda2002 (Collaborator) commented Oct 12, 2025

Attempts to solve #294 using ring-flash-attn with data packing.

How to use

  1. pip install ring-flash-attn

Detailed Development Tracking

Result

Compared with the main branch ([loss-curve images omitted]), the results almost match.

Script changes:

# Context Parallelism Arguments
# Context Parallelism enables training with longer sequences by splitting sequences across GPUs
# This uses ring-flash-attention library for efficient attention computation
CP_ARGS=(
   --enable-cp                        # Enable Context Parallelism
   --ring-flash-atten-type llama3     # Use llama3 ring attention implementation (recommended for varlen)
   --context-parallel-size 2
)
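Under llama3-style context parallelism, the packed sequence is cut into contiguous chunks, one per CP rank, and ring-flash-attn exchanges K/V between ranks so each chunk still attends over the full sequence. A minimal sketch of the per-rank slicing (`cp_local_chunk` is a hypothetical name, not this PR's actual API):

```python
def cp_local_chunk(input_ids, cp_rank: int, cp_size: int):
    """Return this CP rank's contiguous slice of a packed sequence.

    Hypothetical helper sketching llama3-style context parallelism:
    the packed sequence is split into cp_size equal contiguous chunks,
    one per rank. With cp_size == 1 the full sequence is returned,
    so the same code path also covers the non-CP case.
    """
    seq_len = len(input_ids)
    assert seq_len % cp_size == 0, "pad the packed sequence to a multiple of cp_size"
    chunk = seq_len // cp_size
    return input_ids[cp_rank * chunk : (cp_rank + 1) * chunk]
```

With `--context-parallel-size 2`, rank 0 would hold the first half of each packed batch and rank 1 the second half; attention across the cut is handled by the ring K/V exchange inside ring-flash-attn.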

Williamren97 force-pushed the feat/support_normal_cp branch 2 times, most recently from a8ad2ec to 4948643 (October 12, 2025)
PopSoda2002 force-pushed the feat/support_normal_cp branch from 88afbd4 to f32b57f (November 1, 2025)
PopSoda2002 closed this Nov 2, 2025
PopSoda2002 force-pushed the feat/support_normal_cp branch from 653578b to 6d01709 (November 2, 2025)
PopSoda2002 reopened this Nov 2, 2025
PopSoda2002 marked this pull request as ready for review November 2, 2025
PopSoda2002 force-pushed the feat/support_normal_cp branch from 35c2578 to 346abce (November 5, 2025)
zhaochenyang20 (Collaborator):

Really great progress, and I hope you learned a lot from this process. Shall we post a blog on the journey of CP in awesome-ml-sys? Also, your job opportunity shall always come first. Really glad to see your resolve and great improvement. Hope for the best.

zhaochenyang20 (Collaborator):

🐂🍺

PopSoda2002 force-pushed the feat/support_normal_cp branch from a964ed1 to a41f7c1 (November 15, 2025)
world_size = dist.get_world_size()
rank = dist.get_rank()

if self.args.enable_cp:
Contributor:

We can use self.args.context_parallel_size directly, and we don't need a separate mesh init for cp size > 1.

PopSoda2002 (Author):

Fixed, thanks.
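The review comment above amounts to deriving one (dp, cp) mesh shape for every configuration, since cp_size == 1 degenerates to a pure data-parallel mesh. A sketch under that assumption (`mesh_shape` is a hypothetical helper name):

```python
def mesh_shape(world_size: int, cp_size: int) -> tuple:
    """Derive (dp, cp) device-mesh dimensions from the CP size.

    Hypothetical helper: with cp_size == 1 this returns a pure
    data-parallel shape, so no separate init path is needed for
    the non-CP case.
    """
    assert world_size % cp_size == 0, "world size must be divisible by cp size"
    return (world_size // cp_size, cp_size)
```

The resulting shape can then be handed to torch.distributed.device_mesh.init_device_mesh with mesh_dim_names=("dp", "cp"), feeding self.args.context_parallel_size in as cp_size directly.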

)
logits = self.model(**model_args).logits.squeeze(0)
if self.args.enable_cp:
    log_probs_result, entropy_result = get_chunked_logp_and_entropy(
Contributor:

Please merge the with-CP and without-CP implementations into one.

PopSoda2002 (Author):

Fixed.

rank = dist.get_rank()

rollout_data = process_rollout_data(self.args, rollout_data_ref, rank, world_size)
dp_rank = self.dp_rank if self.args.enable_cp else rank
Contributor:

We can always use dp_rank.

PopSoda2002 (Author):

Fixed.

).logits.squeeze(0)

# Gather logits from all CP ranks if CP is enabled (with gradient support)
if self.args.enable_cp:
Contributor:

Similar to the comment above, please merge the two branches into one.

PopSoda2002 (Author):

Fixed.

if tokens[idx].item() == 0:
    pad_length += 1
else:
    break
Contributor:

I think we can recalculate the pad length directly instead of using a for loop.

PopSoda2002 (Author):

Fixed.
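One loop-free way to count the leading pads is an argmax over the non-pad mask. A sketch (`leading_pad_length` is a hypothetical name; pad id 0 follows the snippet above):

```python
import torch

def leading_pad_length(tokens: torch.Tensor, pad_id: int = 0) -> int:
    """Count leading pad tokens without a per-token Python loop.

    Hypothetical helper: argmax over the non-pad mask returns the index
    of the first non-pad token, which equals the number of leading pads.
    The all-pad case is handled separately, since argmax would return 0.
    """
    nonpad = (tokens != pad_id).int()
    if int(nonpad.sum()) == 0:  # every position is padding
        return tokens.numel()
    return int(nonpad.argmax())
```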

PopSoda2002 (Author):

New results also match between cp 1 and cp 2. [comparison image omitted]

zhuzilin (Contributor):

Thank you so much for this!

zhuzilin merged commit 6d3b33f into THUDM:main Nov 16, 2025
zhaochenyang20 (Collaborator):

牛逼! (Awesome!)

llltttwww pushed a commit to llltttwww/slime that referenced this pull request Nov 30, 2025
Yangruipis pushed a commit to rednote-ai/slime that referenced this pull request Feb 28, 2026


4 participants