fix: modify deepseek quantization and unit test #2098
base: main
Thanks @mesakhcienet. Could you please run train (you already have this), decode, and then maxengine/jetstream (with profiles collected for maxengine/jetstream)? Similar to #2088. I can help with profile collection offline if you want; just let me know.
Results for Train and Decode.
Train
Cluster: v6e-32
Command:
Before: Train_log
Decode
V6e-8 Lite
Command:
Before: Decode_log
Note: found an error with "After Decode". Error message as below:
JetStream (Deepseek2-16b)
V6e-8 Lite
Command:
In case of Decoder error
The output looks reasonable to me. Do you know what changed for it to start working?
These profiles LGTM, thanks @mesakhcienet
Generally LGTM. Can you please check the PR test failures?
Thank you, I just fixed the error.
Thanks @mesakhcienet, generally looking good.
@RissyRan could you please also take a look?
src/MaxText/layers/deepseek.py
Outdated
from flax import linen as nn
from flax import nnx

from MaxText.layers import initializers, linears, moe, nnx_wrappers, quantizations
Nit: is layers needed here?
deepseek.DeepSeekDenseLayer(config, mesh=self._mesh, quant=self.quant),
deepseek.DeepSeekMoELayer(config, mesh=self._mesh, quant=self.quant),
deepseek.DeepSeekDenseLayerToLinen(config, mesh=self._mesh, quant=self.quant, model_mode=model_mode, rngs=rngs),
deepseek.DeepSeekMoELayerToLinen(config, mesh=self._mesh, quant=self.quant, model_mode=model_mode, rngs=rngs),
Why does this need to be changed here?
If we use the nnx modules DeepSeekDenseLayer or DeepSeekMoELayer instead of the Linen-converted DeepSeekDenseLayerToLinen or DeepSeekMoELayerToLinen, these parts will fail. Apologies, I didn't screenshot the error; I remember that it comes from the unit test.
Also, when this function is called, I believe the model should be a Linen module rather than an nnx module, since .apply is not a method of an nnx module.

Description
Updated : 2025-11-07
The previous PR and CL merged src/MaxText/layers/deepseek.py changes into the main branch. We need to update the deepseek.py-related unit tests and usage. We also fix some lint issues.
Migrate deepseek to use nnx module.
Tests
We use xpk to create the TPU cluster and assign the workload.
Environment
Cluster
TPU type : v6e-32
Number of slices : 4
GKE version : 1.31.11-gke.1036000
Base Image : us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.0-rev1
Image
Build Image command :
Test command
Run Xpk command :
Log
(steps argument set from 15 to 50): link
Checklist
Before submitting this PR, please make sure (put X in square brackets):