Migrate Gpt3 to NNX. #2062
base: main
Conversation
Awesome to see this @hsuan-lun-chiang. Could you please add before/after logs for the training command shown in your test section? It would be great to see that perf is the same before/after here.
Force-pushed from 1df6836 to b1f7bd8.
Updated the description with the before/after logs, thank you.
@cgarciae can you please take a look at this also?
Force-pushed from b1f7bd8 to f779923.
Force-pushed from cacd42b to 4d57ee8.
Force-pushed from 875a552 to 55ee6ef.
Thanks @hsuan-lun-chiang. Could you please run train (you already have this), decode, and then maxengine/jetstream (with profiles collected for maxengine/jetstream)? Similar to #2088. I can help with profile collection offline if you want; just let me know.
Results for Train and Jetstream

Test Environment
- Machine Type:

Train
- Executed Command:
- Results:

Maxengine / Jetstream
- Step 1: Launch Maxengine
- Step 2: Execute Jetstream
- Step 3: Output Collection
- Results:

Decode
- Executed Command:
- For both before and after migration, we received the same error:
Force-pushed from 691cf2f to edfe0a9.
src/MaxText/layers/gpt3.py
Outdated
```python
self.kv_cache_layer = kvcache.KVCache(
    max_prefill_length=self.max_prefill_predict_length,
    max_target_length=self.max_target_length,
    batch=feature_dim[0],
    key_seq_len=feature_dim[1],
    value_seq_len=feature_dim[1],
    key_heads=self.num_heads,
    value_heads=self.num_heads,
    key_head_size=self.head_dim,
    value_head_size=self.head_dim,
    dtype=self.dtype,
    kv_quant=self.kv_quant,
    prefill_cache_axis_order=prefill_cache_axis_order,
    ar_cache_axis_order=ar_cache_axis_order,
    model_mode=model_mode,
    rngs=self.rngs,
)
```
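For readers unfamiliar with the pattern, the cache above stores keys and values from past decode steps so attention does not recompute them on every step. Below is a minimal, self-contained NumPy sketch of the idea; `SimpleKVCache` and its method names are hypothetical illustrations, not the MaxText `kvcache.KVCache` API.

```python
import numpy as np

class SimpleKVCache:
    """Toy autoregressive KV cache (hypothetical, for illustration only).

    Stores one key/value entry per decode step so attention over the
    history does not recompute past projections."""

    def __init__(self, max_target_length, num_heads, head_dim, batch=1):
        shape = (batch, max_target_length, num_heads, head_dim)
        self.keys = np.zeros(shape, dtype=np.float32)
        self.values = np.zeros(shape, dtype=np.float32)
        self.step = 0

    def update(self, new_key, new_value):
        # new_key/new_value: (batch, 1, num_heads, head_dim) for one decode step.
        self.keys[:, self.step] = new_key[:, 0]
        self.values[:, self.step] = new_value[:, 0]
        self.step += 1
        # Return all cached entries up to (and including) the current step.
        return self.keys[:, : self.step], self.values[:, : self.step]

cache = SimpleKVCache(max_target_length=8, num_heads=2, head_dim=4)
k = np.ones((1, 1, 2, 4), dtype=np.float32)
v = np.ones((1, 1, 2, 4), dtype=np.float32)
ks, vs = cache.update(k, v)
print(ks.shape)  # (1, 1, 2, 4)
```

The real layer additionally handles prefill vs. autoregressive cache axis orders, quantization, and sharding, which this sketch omits.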
Hi @bvandermoon,
During testing, we discovered that the decode function in Gpt3 failed with an AssertionError (`assert prefill_kv_cache`) on the main branch, as indicated by the logs shared by @ecnal-cienet earlier. To address this, I've patched the code by adding the KVCache to Gpt3MultiHeadAttention. I've also updated the PR description to reflect these changes.
However, since we currently lack a reference model, we're unable to verify the results with certainty.
It would be great if you could review these changes and provide any feedback. Thank you!
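One common way to gain confidence without a reference model is a numerical parity check: run the pre-migration (Linen) and post-migration (NNX) models with identical weights and inputs and compare the outputs. The sketch below is hypothetical; `check_parity` and the stand-in logit arrays are illustrations, not part of the MaxText codebase.

```python
import numpy as np

def check_parity(before_logits, after_logits, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: verify the migrated model reproduces the
    original model's outputs for the same parameters and inputs."""
    np.testing.assert_allclose(before_logits, after_logits, rtol=rtol, atol=atol)
    return True

# Stand-in logits; in practice these would come from the Linen and NNX
# forward passes with identical weights and the same input batch.
before = np.array([[0.12, -0.40, 1.08]], dtype=np.float32)
after = before.copy()
assert check_parity(before, after)
print("parity check passed")
```

Matching training-loss curves (as in the logs above) provide weaker but still useful evidence of the same property.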
Thanks @hsuan-lun-chiang. Does the error show up on main as well? I am wondering whether decode was simply never supported; if so, we don't need to add the KVCache here.
Yes, the error also shows up on the main branch.
Thanks @hsuan-lun-chiang. Can we remove the KVCache portion from this PR? That way we can just focus on ensuring the before/after match for the migration. It would be good to add it back as a follow-up though
@hsuan-lun-chiang do you have train profiles for this run? I see some issues with the Jetstream profiles, maybe they were started before the actual requests started happening?
Hi @bvandermoon, command:
Thanks @hsuan-lun-chiang. The profiles look good. I will take one more pass on the PR tomorrow.
Force-pushed from edfe0a9 to efaaa20.
Force-pushed from 3db6d5d to 51449d6.
Thank you! I also rebased the code to the latest version.
Thanks @hsuan-lun-chiang. Looking good, just have a few comments
Force-pushed from d12fdf9 to 350b7ec.
Force-pushed from 350b7ec to 966ef0a.
Description
This PR migrates the Gpt3 model from Linen to NNX, including the following classes:
- Casts `decoder_positions` to `int32`: the trainable position embedding layer (the `Embed` layer) requires integer indices for its lookup, but `decoder_positions` was passed as a float. Casting it to `int32` prevents the `ValueError`.

Tests
Ran the train command to train gpt3-6b for 10 steps:
Logs:
Linen, before migration
NNX, after migration
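The `decoder_positions` cast described in the description can be illustrated with a small NumPy sketch; the `table` below is a stand-in for the trainable position embedding, not the MaxText `Embed` layer.

```python
import numpy as np

# Stand-in for the trainable position embedding table:
# rows are looked up by integer position index.
table = np.arange(12, dtype=np.float32).reshape(4, 3)  # (num_positions, emb_dim)

decoder_positions = np.array([0.0, 1.0, 2.0])  # float positions, as in the bug

# A float index array is rejected by the lookup:
try:
    _ = table[decoder_positions]
except IndexError:
    print("float indices rejected")

# The fix: cast to int32 before the lookup.
emb = table[decoder_positions.astype(np.int32)]
print(emb.shape)  # (3, 3)
```

NumPy raises the error at index time; in the JAX/Flax setting the same requirement surfaces as the `ValueError` mentioned above, and the same `astype`-style cast resolves it.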
Checklist
Before submitting this PR, please make sure (put X in square brackets):