Custom modeling for training #801
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you for addressing most of my comments. Waiting for the final version including refactoring and tests to review.
unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}

model.tie_weights()
# TODO: stopped here, start from here tomorrow.
# TODO: stopped here, start from here tomorrow.
Done.
for name, mod in not_initialized_submodules.items():
    if isinstance(mod, GQAQKVColumnParallelLinear):
        # There is a bug in initialization for this module.
        # In any case, we will always have weights for this in the case of `from_pretrained`.
So we will have an issue when training from scratch, right?
No, because we won't call `from_pretrained` in this case. I can investigate this issue, but it does not seem top priority in this specific context.
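For context, a minimal sketch of what such a guard could look like, reusing the names from the snippet above (the import path and the use of `_init_weights` follow the usual transformers initialization loop and are assumptions, not this PR's exact code):

```python
# Import path assumed; GQAQKVColumnParallelLinear comes from neuronx_distributed.
from neuronx_distributed.parallel_layers.layers import GQAQKVColumnParallelLinear

for name, mod in not_initialized_submodules.items():
    if isinstance(mod, GQAQKVColumnParallelLinear):
        # Its self-initialization is buggy, but in the `from_pretrained` path
        # its weights are always loaded from the checkpoint, so skip it.
        continue
    model._init_weights(mod)
```

Training from scratch goes through a different code path that never calls `from_pretrained`, so it does not hit this loop.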
)

@classmethod
def from_pretrained(
Thanks, it is much clearer now. I would personally have dropped more sections of code that correspond to:
- multiple models (unless I missed something, we actually only support training single models on Neuron),
- the less usual model deployment paradigms (depending on how/when weights are loaded), as I am not entirely sure we would support them anyway.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
    if from_tf and os.path.isfile(
Do we support `from_tf` and `from_flax`?
Removed the remaining artifacts for `from_tf` and `from_flax`.
archive_file = os.path.join(
    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
elif use_safetensors is not False and os.path.isfile(
Suggested change:
- elif use_safetensors is not False and os.path.isfile(
+ elif use_safetensors and os.path.isfile(
It's not equivalent because `use_safetensors` can be `None`.
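To spell out the distinction (a minimal sketch, not the PR's code; in transformers' `from_pretrained`, `use_safetensors` defaults to `None`):

```python
# `use_safetensors` is effectively a three-state flag:
#   True  -> require safetensors weights
#   None  -> unspecified; auto-detect and prefer safetensors when available
#   False -> explicitly disable safetensors
for flag in (True, None, False):
    print(flag, flag is not False, bool(flag))
# True  True  True
# None  True  False   <- the two conditions diverge exactly here
# False False False
```

The suggested `elif use_safetensors and ...` would treat the default `None` like an explicit `False` and silently skip the safetensors auto-detection path.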
    kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
    name="Thread-auto_conversion",
).start()
else:
Do we really want to support this?
I have not tested it yet, but since it is supported in Transformers, why not?
I suggest we handle the last point in another PR to avoid making this PR bigger than it already is.
model.tie_weights()
# TODO: stopped here, start from here tomorrow.
if device_map is None:
I thought we did not support `device_map`, so it should be `None`, right?
We support a subset of features: `device_map in [None, "xla", "cpu"]`.
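A minimal sketch of how that restriction could be enforced (the helper name and error message are assumptions, not the PR's actual code):

```python
SUPPORTED_DEVICE_MAPS = (None, "xla", "cpu")

def validate_device_map(device_map):
    # Reject the `device_map` values (e.g. "auto", per-module dicts) that
    # Neuron training does not support.
    if device_map not in SUPPORTED_DEVICE_MAPS:
        raise ValueError(
            f"Unsupported device_map: {device_map!r}. "
            f"Expected one of {SUPPORTED_DEVICE_MAPS}."
        )
```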
I won't be able to review further until I get back, so I'm trusting alvaro's review.
Note that the linear tests now fail. I think that should be fixed before merging this.
LGTM, thanks a lot
What does this PR do?
Custom modeling code for training
Features
This PR adds support for custom modeling code for training.
Each custom modeling code can be added under `optimum/neuron/models/training`. Having a custom modeling allows us to implement Neuron specificities in a cleaner way than using dynamic patching.
It becomes easy to:
- use `GQAQKVColumnParallelLinear`, useful with high TP sizes.

In this PR we provide a first full custom implementation with Llama.
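As an illustration, loading such a custom implementation could look like this (the exact module path and class name are assumptions based on the layout above, not confirmed by this PR):

```python
# Hypothetical usage of a custom training model class living under
# optimum/neuron/models/training.
from optimum.neuron.models.training import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
```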
Model weight transformations
Because a custom modeling code can change the vanilla Transformers implementation, we need a way to make sure that we can load checkpoints from Transformers, and that we can save checkpoints in the original format as well.
To do that we provide an API with the `ModelWeightTransformationSpec` classes. These classes represent the transformation compared to the vanilla Transformers implementation and are directly added in the modules containing these transformations.
For now two exist:
- `FusedLinearsSpec`: represents a transformation where multiple linear layers are fused into a single linear layer (possibly a parallel linear).
- `GQAQKVColumnParallelLinearSpec`: represents the transformation of separate query, key, and value projections into a single `GQAQKVColumnParallelLinear` projection.

Then during loading, saving, and consolidation, we use these specs to make sure every weight matches the Transformers weights.
Known issues
Training example
Specs:
- Model: `meta-llama/Llama-3.2-3B-Instruct`
- Dataset: `databricks/databricks-dolly-15k`
- Trainer: `NeuronSFTTrainer`

Loss curve: [figure not included in this extract]
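For reference, a minimal sketch of what such a run could look like, assuming `NeuronSFTTrainer` follows the usual SFT trainer pattern (the config fields, the batched `formatting_func` contract, and the prompt template are assumptions, not taken from this PR):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer

model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

def format_dolly(examples):
    # Batched formatting: turn each dolly record into a single prompt string
    # (the template below is an assumption).
    output = []
    for i in range(len(examples["instruction"])):
        text = f"### Instruction\n{examples['instruction'][i]}"
        if examples["context"][i]:
            text += f"\n\n### Context\n{examples['context'][i]}"
        text += f"\n\n### Answer\n{examples['response'][i]}"
        output.append(text)
    return output

args = NeuronSFTConfig(output_dir="llama-dolly-sft", max_seq_length=1024)
trainer = NeuronSFTTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=format_dolly,
)
trainer.train()
```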
To be done in later PRs:
- `save_pretrained`, as it was done for `from_pretrained` in this PR.