[JAX API Update] Remove jax_spmd_mode from config #1160

Closed
wants to merge 10 commits into from

Conversation

Steboss
Contributor

@Steboss Steboss commented May 6, 2025

This refers to PR #1136

  • jax_spmd_mode is now obsolete, so it is removed from the config.
  • Performance tests were run on fuji-3B-v3-flash-attention, and the results are in line with the previous implementation:

| Metric | This PR implementation | Previous AXLearn implementation |
| --- | --- | --- |
| Tokens per sec per GPU | 9288 | 8904 |
| Seqs per sec per GPU | 2.26 | 2.17 |
| Average time step | 0.88 | 0.91 |
| TFLOPS per sec per GPU | 218.80 | 209.74 |
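
To make the comparison above easier to read, here is a minimal sketch that computes the relative difference between the two columns. The numbers are copied from the table; the dictionary layout and the helper name `relative_change_pct` are illustrative choices and not part of AXLearn or of this PR.

```python
# Figures copied from the table above, as (this PR, previous AXLearn) pairs.
# The dict layout and helper name are illustrative, not AXLearn code.
metrics = {
    "Tokens per sec per GPU": (9288.0, 8904.0),
    "Seqs per sec per GPU": (2.26, 2.17),
    "Average time step": (0.88, 0.91),
    "TFLOPS per sec per GPU": (218.80, 209.74),
}


def relative_change_pct(new: float, old: float) -> float:
    """Return the percentage change of `new` relative to `old`."""
    return (new - old) / old * 100.0


for name, (new, old) in metrics.items():
    # A positive value means the figure for this PR is higher.
    print(f"{name}: {relative_change_pct(new, old):+.2f}%")
```

With these figures, the three throughput metrics come out roughly 4% higher for this PR and the average time step roughly 3% lower.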

@apghml could you please review this? Thank you.

@Steboss Steboss requested review from ruomingp, markblee and a team as code owners May 6, 2025 08:10
@dmarx

dmarx commented May 19, 2025

Does this PR also need the validator? #1126 (comment)

@Steboss
Contributor Author

Steboss commented May 19, 2025

Hey @dmarx
This PR does not need the validator.
I spotted a few more bugs related to JAX versions, but I haven't opened any PRs for them yet. I will open them for reference for @matthew-e-hopkins.

@apghml
Contributor

apghml commented May 19, 2025

@Steboss Is this PR ready to merge?

@Steboss
Contributor Author

Steboss commented May 19, 2025

@apghml yes it is :)

@apghml apghml added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
@apghml apghml added this pull request to the merge queue May 20, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 20, 2025
@Steboss Steboss closed this May 21, 2025
@Steboss Steboss deleted the sbosisio/remove_spmd branch May 21, 2025 15:55
@Steboss
Contributor Author

Steboss commented May 22, 2025

@apghml @dmarx
Is it possible to get a log showing why this merge failed? Thank you :)

@apghml
Contributor

apghml commented May 22, 2025

I believe they fixed the issue recently. Can you merge the latest main branch and then we can retry?

@apghml
Contributor

apghml commented May 22, 2025

Also, @Steboss I see many of your recent PRs have been closed. Are you moving them somewhere else? If so, can you share a link? Thanks!

@Steboss
Contributor Author

Steboss commented May 22, 2025

@apghml I've mostly consolidated the PRs that dealt with jax.tree into a single one, namely #1207

@apghml
Contributor

apghml commented May 22, 2025

That PR seems to be by a different user?

@Steboss
Contributor Author

Steboss commented May 23, 2025

@apghml Sorry, I meant #1206.

4 participants