Skip to content

[Refactor] fused kernel in forward #1624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

mingruimingrui
Copy link
Contributor

Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

Shifts fused_linear_for_ppo into model.forward for FSDP

High-Level Design

Self explaining

Specific Changes

  • Update monkey patch to return log_probs and entropy instead of last_hidden_state.

API

No changes

Usage Example

actor_rollout_ref.model.use_fused_kernels=True

Test

image

Additional Info.

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

@mingruimingrui
Copy link
Contributor Author

Hi @ETOgaosion, I've made a fix for the fused kernel bug that we experienced when using 1 GPU. It's strange that we didn't encounter this when using more than 1.

@ETOgaosion
Copy link
Collaborator

Seems that currently there is no CI enables fused_kernel, could you enable some tests and disable in other tests to test both cases?

@mingruimingrui
Copy link
Contributor Author

@ETOgaosion My bad, just added 😂

@ETOgaosion
Copy link
Collaborator

@mingruimingrui #1634 fixed the CI, could you rebase main and retry please?

@mingruimingrui
Copy link
Contributor Author

Done, trouble you to trigger CI again

@ETOgaosion ETOgaosion merged commit 4779f26 into volcengine:main May 24, 2025
36 of 37 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Running error on PPO training on GSM8K dataset example. "RuntimeError: setStorage:"
2 participants