
Custom modeling for training #801


Merged: 74 commits merged into main from custom_modeling_introduction on May 15, 2025

Conversation

@michaelbenayoun (Member) commented on Mar 3, 2025

What does this PR do?

Custom modeling code for training

Features

This PR adds support for custom modeling code for training.
Each model's custom modeling code lives under optimum/neuron/models/training.

Having custom modeling code allows us to implement Neuron-specific behavior in a cleaner way than dynamic patching.
It becomes easy to:

  • Fuse linear layers together for efficiency (see the sketch below)
  • Use custom linear layers such as GQAQKVColumnParallelLinear, useful with high TP sizes.
  • Use custom kernels, such as the flash attention kernel

In this PR we provide a first full custom implementation with Llama.
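
As an illustration of the fused-linears point above, here is a minimal sketch of what such a fusion can look like in a Llama-style MLP. The module and attribute names are hypothetical and are not taken from this PR:

```python
import torch
from torch import nn


class FusedGateUpMLP(nn.Module):
    """Hypothetical sketch: fuse the Llama gate_proj and up_proj into one linear.

    A single matmul over the concatenated weight replaces two separate ones,
    which is cheaper on Neuron and maps naturally onto a column-parallel
    layout when TP > 1.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One (2 * intermediate_size, hidden_size) weight replaces the two
        # (intermediate_size, hidden_size) weights of gate_proj / up_proj.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(self.act(gate) * up)
```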

Model weight transformations

Because custom modeling code can diverge from the vanilla Transformers implementation, we need a way to make sure that we can load checkpoints from Transformers, and that we can save checkpoints in the original format as well.

To do that, we provide an API built around the ModelWeightTransformationSpec classes.
These classes represent the transformation relative to the vanilla Transformers implementation and are attached directly to the modules that contain these transformations.

For now two exist:

  • FusedLinearsSpec: represents a transformation when multiple linear layers are fused into a single linear layer (possibly a parallel linear)
  • GQAQKVColumnParallelLinearSpec: represents the transformation of separate query, key, and value projections into a single GQAQKVColumnParallelLinear projection.

Then, during loading, saving, and consolidation, we use these specs to make sure every weight matches the corresponding Transformers weight.
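
To make that round trip concrete, here is a rough sketch (not the actual API of this PR) of the kind of mapping a FusedLinearsSpec has to describe so that a fused checkpoint can be written back in the vanilla Transformers layout:

```python
import torch


def split_fused_gate_up(fused_weight: torch.Tensor, intermediate_size: int) -> dict:
    """Recover the original Transformers weights from a fused gate/up projection.

    A FusedLinearsSpec describes this kind of relationship; the real spec also
    covers the reverse direction (loading) and tensor-parallel sharding, which
    are omitted here.
    """
    gate_proj_weight, up_proj_weight = fused_weight.split(intermediate_size, dim=0)
    return {
        "mlp.gate_proj.weight": gate_proj_weight,
        "mlp.up_proj.weight": up_proj_weight,
    }
```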

Known issues

  • There seems to be an issue when saving a checkpoint with DP > 1 during training. After initial investigation, it seems to be a compiler bug, but it will require more work. I suggest handling it in another PR.

Training example

Specs

  • Model: meta-llama/Llama-3.2-3B-Instruct
  • Dataset: databricks/databricks-dolly-15k
  • Trainer: NeuronSFTTrainer (see the launch sketch after this list)
  • DP=4, TP=8
  • Gradient accumulation steps = 16 => Effective batch size = 4 x 16 = 64
  • Sequence length = 2048 with packing = True
  • 3 epochs
  • Learning rate = 5e-4, warmup ratio = 0.3, lr scheduler type = "cosine"
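
A rough sketch of how such a run could be launched is shown below. It assumes the NeuronSFTTrainer / NeuronSFTConfig API from optimum-neuron and a LlamaForCausalLM class exposed by the custom modeling added in this PR; the exact import paths and argument names are assumptions, not quotes from this PR:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# NeuronSFTTrainer is named in the specs above; NeuronSFTConfig and the custom
# LlamaForCausalLM import path are assumptions made for this sketch.
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer
from optimum.neuron.models.training import LlamaForCausalLM

model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# Hypothetical configuration mirroring the specs above (DP=4 comes from
# launching 32 workers with tensor_parallel_size=8).
training_args = NeuronSFTConfig(
    output_dir="llama-3.2-3b-dolly-sft",
    tensor_parallel_size=8,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=3,
    learning_rate=5e-4,
    warmup_ratio=0.3,
    lr_scheduler_type="cosine",
    max_seq_length=2048,
    packing=True,
)

model = LlamaForCausalLM.from_pretrained(model_id)

trainer = NeuronSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```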

Loss curve

[Weights & Biases loss-curve chart, 24/04/2025 16:12]

To be done in later PRs:

  • Support for PP
  • Support for LoRA
  • Refactor save_pretrained as it was done for from_pretrained in this PR.
  • Add a test that checks the model can overfit


@dacorvo (Collaborator) left a comment

Thank you for addressing most of my comments. Waiting for the final version including refactoring and tests to review.

unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}

model.tie_weights()
# TODO: stopped here, start from here tomorrow.
Collaborator:

Suggested change:
- # TODO: stopped here, start from here tomorrow.

Member Author:

Done.

for name, mod in not_initialized_submodules.items():
if isinstance(mod, GQAQKVColumnParallelLinear):
# There is a bug in initialization for this module.
# In any case, we will always have weights for this in the case of `from_pretrained`.
Collaborator:

So we will have an issue when training from scratch, right?

Member Author:

No, because we won't call from_pretrained in this case.
I can investigate this issue, but it does not seem like a top priority in this specific context.

)

@classmethod
def from_pretrained(
Collaborator:

Thanks, it is much clearer now. I would personally have dropped more sections of code that correspond to:

  • multiple models (unless I missed something, we actually only support training single models on Neuron),
  • the less usual model deployment paradigms (depending on how/when weights are loaded), as I am not entirely sure we would support them anyway.

pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
if from_tf and os.path.isfile(
Collaborator:

Do we support from_tf, from_flax?

Member Author:

Removed the remaining artifacts for from_tf and from_flax.

archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
elif use_safetensors is not False and os.path.isfile(
Collaborator:

Suggested change:
- elif use_safetensors is not False and os.path.isfile(
+ elif use_safetensors and os.path.isfile(

Member Author:

It's not equivalent because use_safetensors can be None.
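
For context, a small illustration of the three-state semantics of use_safetensors (described here following the usual transformers convention):

```python
# use_safetensors has three states:
#   True  -> require a safetensors checkpoint, fail otherwise
#   None  -> "auto": prefer safetensors if a .safetensors file exists, else fall back to .bin
#   False -> never use safetensors
use_safetensors = None

if use_safetensors is not False:
    print("reached for True and None (auto)")  # what the original condition checks

if use_safetensors:
    print("reached only for True")  # the suggested change would drop the auto case
```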

kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
name="Thread-auto_conversion",
).start()
else:
Collaborator:

Do we really want to support this?

Member Author:

I have not tested it yet, but since it is supported in transformers, why not?

@michaelbenayoun (Member Author), replying to review feedback:

> Thank you for this pull request: massive contribution. The general organization of the files makes sense and I think the modelization_utils.py and transformation_utils.py files are in particular good starting points. I have requested a few changes that I think are important in this first iteration to:
>
>   • make it clearer what is expected to be provided by someone adding a new model,
>   • keep things simple in this first iteration (I personally think it is easier to add things we support later on instead of keeping placeholders).
>
> Also, I think the basic features must be tested, and it was unclear to me from reading the pull request which parts are actually tested, apart from eager forward inference (and, by the way, we must include a flash_attention test since we support the option).

  • What is expected is simply adding the proper transformation specs and inheriting from CustomModule (see the sketch after this list).
  • In this PR we test:
    • The forward pass: we check that the forward pass from the original implementation and the one from the custom modeling produce the same outputs. We test different settings (eager attention, regular qkv, fused qkv, qkv gqa replication). We do not test flash attention because its outputs do not match exactly for now; I was still able to train a model with flash attention, though.
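
A rough sketch of what a contributor provides, as mentioned in the first bullet above. It uses the CustomModule and FusedLinearsSpec names from this PR, but the import path, inheritance pattern, spec attribute, and constructor arguments are illustrative assumptions, not the actual API:

```python
from torch import nn

# CustomModule and FusedLinearsSpec are the names introduced by this PR; the
# import path and signatures shown here are assumptions made for this sketch.
from optimum.neuron.models.training import CustomModule, FusedLinearsSpec


class MyLlamaMLP(nn.Module, CustomModule):
    """Hypothetical custom MLP whose gate/up projections are fused."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # Declare how this module's weights relate to the vanilla Transformers
        # layout, so that loading, saving, and consolidation can round-trip
        # checkpoints between the two formats.
        self.specs = FusedLinearsSpec(
            fused_linear_name="gate_up_proj",
            linear_names=["gate_proj", "up_proj"],
        )

    # The forward pass is omitted here; it would use gate_up_proj exactly like
    # the fused MLP sketch shown earlier in this conversation.
```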

@michaelbenayoun (Member Author), replying to review feedback:

> A lot of work contributed here, thanks! This will help add new models in a more efficient way. I added a few comments, but it mostly boils down to these requests:
>
>   • Remove the dependency on transformers modeling code
>   • Reuse and merge with existing code; avoid duplication for sharding and tests.
>   • Consider adding a complete test that shows training works (overfitting?)

I suggest we handle the last point in another PR to avoid making this PR bigger than it already is.


model.tie_weights()
# TODO: stopped here, start from here tomorrow.
if device_map is None:
Collaborator:

I thought we did not support device_map, so it should be None, right?

Member Author:

We support a subset of features: device_map in [None, "xla", "cpu"].

@dacorvo dismissed their stale review on April 30, 2025 15:14

I won't be able to review further until I get back, so I'm trusting Alvaro's review.

@tengomucho (Collaborator) left a comment

Note that the linear tests now fail. I think that should be fixed before merging this.

@tengomucho (Collaborator) left a comment

LGTM, thanks a lot

@tengomucho merged commit 66d1977 into main on May 15, 2025
8 of 9 checks passed
@tengomucho deleted the custom_modeling_introduction branch on May 15, 2025 08:32