Low Precision Recipes for Llama3-8B #1178
Conversation
/ok to test 0c67d7b
Signed-off-by: Aditya Vavre <[email protected]>
```python
def llama3_8b_low_precision_pretrain_config(
    mixed_precision_recipe: str, **user_kwargs: Unpack[Llama3CommonKwargs]
) -> ConfigContainer:
```
n00b question: what is the advantage of having a dedicated function for this vs. users passing in or overriding the default precision setting to use the ones listed below?
I wanted to make it clear to users that these recipes are well tested for convergence, and to specify the hyperparams used in long convergence testing. The default params specified for bf16 precision don't seem to work well with low precision; for example, we found that AdamW epsilon has a significant effect on convergence.
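A minimal sketch of why a dedicated function can help here, assuming a dataclass-style optimizer config; every name and value below is illustrative (the PR does not state which epsilon value it uses), not the PR's actual code:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class OptimizerConfig:
    # Stand-in for the repo's optimizer config; fields/defaults are assumptions.
    lr: float = 3e-4
    adam_eps: float = 1e-8  # typical bf16 default

def with_low_precision_overrides(cfg: OptimizerConfig) -> OptimizerConfig:
    # Placeholder value: the PR only states that epsilon matters for
    # low-precision convergence, not which value was used.
    return replace(cfg, adam_eps=1e-16)

print(with_low_precision_overrides(OptimizerConfig()))
```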
What do you think about exposing these as individual recipe configs? e.g. llama3_8b_bf16_mxfp8_mixed_pretrain_config, llama3_8b_bf16_fp8_cs_mixed_pretrain_config, llama3_8b_bf16_nvfp4_mixed_pretrain_config.
This might make it clearer which low-precision recipes have been tested for long convergence, and it would be easier to follow when hyperparams vary across recipes.
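A sketch of this alternative: thin per-recipe wrappers delegating to one shared implementation. The recipe-string values are assumptions, and the shared function (the one from the diff above) is stubbed here for illustration:

```python
def llama3_8b_low_precision_pretrain_config(mixed_precision_recipe: str, **user_kwargs):
    ...  # shared implementation from this PR

def llama3_8b_bf16_mxfp8_mixed_pretrain_config(**user_kwargs):
    return llama3_8b_low_precision_pretrain_config("mxfp8", **user_kwargs)

def llama3_8b_bf16_fp8_cs_mixed_pretrain_config(**user_kwargs):
    return llama3_8b_low_precision_pretrain_config("fp8_cs", **user_kwargs)

def llama3_8b_bf16_nvfp4_mixed_pretrain_config(**user_kwargs):
    return llama3_8b_low_precision_pretrain_config("nvfp4", **user_kwargs)
```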
Apart from the precision configs, there are no other differences in the recipes, hence I decided to define one function for all. To be very clear about which recipes have been tested, I have an assert statement inside the function which checks whether the recipe is one of FP8CS, MXFP8, and NVFP4. @cuichenx / @yaoyu-33 any thoughts?
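A sketch of the guard described above; the recipe identifiers and message text are assumptions:

```python
# Recipes verified for long convergence (identifiers are assumptions).
SUPPORTED_RECIPES = {"fp8_cs", "mxfp8", "nvfp4"}

def llama3_8b_low_precision_pretrain_config(mixed_precision_recipe: str, **user_kwargs):
    assert mixed_precision_recipe in SUPPORTED_RECIPES, (
        f"Invalid recipe: {mixed_precision_recipe!r}. "
        f"Supported recipes: {sorted(SUPPORTED_RECIPES)}"
    )
    ...  # build and return the ConfigContainer
```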
Maybe in the assert statement, instead of saying "Invalid Recipe", I can just say the recipe has not been verified for long convergence?
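Continuing the sketch above, the reworded guard might read as follows (message text is an assumption):

```python
assert mixed_precision_recipe in SUPPORTED_RECIPES, (
    f"Recipe {mixed_precision_recipe!r} has not been verified for long "
    f"convergence; verified recipes: {sorted(SUPPORTED_RECIPES)}"
)
```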
Signed-off-by: Aditya Vavre <[email protected]>
/ok to test 8ab27c8
Signed-off-by: Aditya Vavre <[email protected]>
/ok to test 402563b
This PR adds recommended args for Llama3-8B low-precision recipes. The following recipes are verified (long convergence, 1T tokens): FP8CS, MXFP8, and NVFP4.
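Hypothetical usage of the new entry point; the recipe string and any kwargs below are assumptions:

```python
cfg = llama3_8b_low_precision_pretrain_config(
    mixed_precision_recipe="mxfp8",
    # plus any Llama3CommonKwargs overrides, e.g. paths or batch sizes
)
```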
Attaching a short convergence loss curve for reference:
