
Conversation

@mesakhcienet (Contributor) commented Aug 7, 2025

Description

Updated : 2025-11-07
The previous PR and CL already merged the src/MaxText/layers/deepseek.py changes into the main branch. We now need to update the deepseek.py-related unit tests and usage accordingly. We also fix some code lint.

Migrate DeepSeek to use the NNX module.
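
For context, a minimal generic sketch of what moving a layer from Linen to NNX looks like. The class names and feature sizes below are hypothetical illustrations, not the actual deepseek.py code:

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

# Before: Linen creates parameters lazily inside a @nn.compact __call__.
class MlpLinen(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    return jax.nn.relu(nn.Dense(self.features)(x))

# After: NNX creates parameters eagerly in __init__ and threads RNGs explicitly.
class MlpNnx(nnx.Module):
  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    self.dense = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return jax.nn.relu(self.dense(x))

x = jnp.ones((2, 8))

# Linen: initialize a params pytree, then pass it to apply().
linen_layer = MlpLinen(features=16)
params = linen_layer.init(jax.random.PRNGKey(0), x)
y_linen = linen_layer.apply(params, x)

# NNX: the module owns its state, so it is called directly.
nnx_layer = MlpNnx(8, 16, rngs=nnx.Rngs(0))
y_nnx = nnx_layer(x)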

Tests

We use XPK to create a TPU cluster and submit the workload.

Environment

Cluster

TPU type : v6e-32
Number of slices : 4
GKE version : 1.31.11-gke.1036000
Base Image : us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.0-rev1

Image

Build Image command :

bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.0-rev1

Test command

Run Xpk command :

python xpk.py workload create --cluster $CLUSTER_NAME \
  --base-docker-image mesa_maxtext_base_image \
  --workload=$WORKLOAD \
  --tpu-type=${TPU_TYPE} --num-slices=${NUM_SLICES} --max-restarts=10 \
  --on-demand \
  --script-dir=$MAXTEXT_SCRIPT_DIR --command \
  "python3 -m MaxText.train MaxText/configs/base.yml \
  run_name=runner_direct_${idx} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  model_name=${MODEL_NAME} \
  dataset_type=synthetic \
  async_checkpointing=false \
  per_device_batch_size=1 \
  metrics_file='metrics.txt' \
  steps=15"

Log

  • Before migration (from main branch): link
  • After migration (train from scratch): link
  • After migration (train loading the previous model from the main branch, i.e. before migration, with the steps argument changed from 15 to 50): link

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch 5 times, most recently from 13ee02d to b85f426 Compare August 11, 2025 08:02
@mesakhcienet mesakhcienet marked this pull request as ready for review August 11, 2025 08:40
@mesakhcienet mesakhcienet requested a review from RissyRan August 11, 2025 08:44
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch 2 times, most recently from 3f6c31d to fea3b8e Compare August 19, 2025 06:25
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch 2 times, most recently from dfcbb23 to c6497df Compare September 10, 2025 02:12
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch from c6497df to 28961fb Compare September 23, 2025 01:27
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch from 28961fb to 34a8492 Compare September 23, 2025 01:49
@bvandermoon (Collaborator) left a comment


Thanks @mesakhcienet. Could you please run train (you already have this), decode, and then maxengine/jetstream (with profiles collected for maxengine/jetstream)? Similar to #2088. I can help with profile collection offline if you want, just let me know

@Shuang-cnt commented Oct 8, 2025

Results for Train and Decode.

Train

Cluster: v6e-32
Num_slices: 4

Command:

python3 ~/xpk/xpk.py workload create --cluster $CLUSTER_NAME \
 --base-docker-image=maxtext_base_image \
 --workload=$WORKLOAD \
 --tpu-type=${TPU_TYPE} --num-slices=${NUM_SLICES} --max-restarts=10 \
   --on-demand \
  --script-dir=$MAXTEXT_SCRIPT_DIR --command \
   "python3 -m MaxText.train MaxText/configs/base.yml \
   run_name=runner_direct_${idx}  \
   base_output_directory=${BASE_OUTPUT_DIRECTORY}  \
   model_name=${MODEL_NAME} \
   dataset_type=synthetic  \
   async_checkpointing=false  \
   per_device_batch_size=1 \
   metrics_file='metrics.txt'  \
   steps=15"

Train diff
Restore diff

Before Train_log
After Train_log
Before restore_log
After restore_log

Decode

V6e-8 Lite

Command:

python3 -m MaxText.decode MaxText/configs/base.yml \
    model_name=deepseek3-test \
    tokenizer_path=assets/tokenizer_llama3.tiktoken \
    tokenizer_type=tiktoken \
    scan_layers=false \
    per_device_batch_size=1 \
    ici_fsdp_parallelism=1 \
    ici_autoregressive_parallelism=-1 \
    max_prefill_predict_length=128 \
    max_target_length=256 \
    prompt="I love to" \
    attention=dot_product \
    mla_naive_kvcache=False

Diff

Before: Using (GB) 136.71 / 1417.33 (9.645601%) --> Available:1280.62
After: Using (GB) 128.16 / 1417.33 (9.042354%) --> Available:1289.16

Before Decode_log
After Decode_log

Note: found an error with "After Decode". Error message below:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/yusharon_google_com/maxtext/src/MaxText/decode.py", line 208, in <module>
    app.run(main)
  File "/home/yusharon_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/home/yusharon_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/yusharon_google_com/maxtext/src/MaxText/decode.py", line 186, in main
    output = tokenizer_model.decode(results)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yusharon_google_com/maxtext_venv/lib/python3.12/site-packages/jetstream/engine/token_utils.py", line 492, in decode
    return self.tokenizer.decode(token_ids)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yusharon_google_com/maxtext_venv/lib/python3.12/site-packages/jetstream/external_tokenizers/llama3/llama3_tokenizer.py", line 177, in decode
    return self.model.decode(cast(List[int], t))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yusharon_google_com/maxtext_venv/lib/python3.12/site-packages/tiktoken/core.py", line 284, in decode
    return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'Invalid token for decoding: 128268'

@Shuang-cnt commented Oct 8, 2025

JetStream (Deepseek2-16b)

V6e-8 Lite

Command:

#run maxengine
python -m MaxText.maxengine_server \
      MaxText/configs/base.yml \
      mla_naive_kvcache=false max_prefill_predict_length=1024 per_device_batch_size=1 \
      model_name=deepseek2-16b async_checkpointing=false ici_tensor_parallelism=4 \
      max_target_length=2048 ici_fsdp_parallelism=1 ici_autoregressive_parallelism=2 \
      megablox=False sparse_matmul=False scan_layers=False attention=dot_product \
      tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V2-Lite \
      load_parameters_path=gs://agagik-us/deepseek/maxtext_checkpoints/ds_16b_unscanned_new_3/unscanned_chkpt/checkpoints/0/items \
      hf_access_token=$HF_TOKEN enable_jax_profiler=True
#run benchmark_serving.py from jetstream
JAX_PLATFORMS=tpu python benchmarks/benchmark_serving.py \
  --tokenizer=deepseek-ai/DeepSeek-V2-Lite --num-prompts 5000 \
  --dataset mmlu --dataset-path mmlu/data/test/ --request-rate 0 \
  --warmup-mode sampled --save-request-outputs --run-eval True --use-hf-tokenizer True
## capture port 9999, for 6000ms
 python -m jax.collect_profile 9999 6000 --log_dir=$log_dir --no_perfetto_link

before Jetstream_log
after Jetstream_log
Jetstream Diff

before xprof
after xprof

@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch 2 times, most recently from 06b58cf to 985394d Compare October 14, 2025 08:26
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch from 40a946d to 16857b2 Compare October 16, 2025 03:28
@mesakhcienet (Contributor, Author) commented Oct 16, 2025

> [Quoted: the "Results for Train and Decode" content above — train and decode commands, diffs, logs, and the "After Decode" KeyError traceback.]

Regarding the decode error KeyError: 'Invalid token for decoding: 128268': I rebased and reran the test, and the problem no longer appears. Can you help us validate my output? Thank you.

https://paste.googleplex.com/6495477373730816
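
To help validate decode output against this kind of error, here is a small hypothetical helper (not part of MaxText) that reports which generated ids a tiktoken encoding cannot decode. It assumes you can get hold of the tiktoken.Encoding object behind the tokenizer in use; the encoding in the demo line is a stock one, whereas the decode run above uses the llama3 tiktoken vocabulary from assets/tokenizer_llama3.tiktoken:

import tiktoken

def find_undecodable_ids(token_ids, enc: tiktoken.Encoding):
  """Return the ids that the encoding cannot map back to bytes (e.g. 128268 above)."""
  bad = []
  for t in token_ids:
    try:
      enc.decode_single_token_bytes(t)
    except KeyError:
      bad.append(t)
  return bad

# Demo with a stock encoding (downloads its vocab file on first use).
enc = tiktoken.get_encoding("cl100k_base")
print(find_undecodable_ids([15339, 10**9], enc))  # -> [1000000000]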

@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch 5 times, most recently from 7f470ed to f39635e Compare October 21, 2025 02:03
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch from f39635e to 12f3b28 Compare November 3, 2025 01:14
@bvandermoon (Collaborator) commented:

> [Quoted: the "Results for Train and Decode" comment above, including the decode KeyError follow-up and the paste link.]

The output looks reasonable to me. Do you know what changed for it to start working?

@bvandermoon (Collaborator) commented:

> [Quoted: the "JetStream (Deepseek2-16b)" comment above — maxengine_server and benchmark_serving commands, logs, diff, and xprof links.]

These profiles LGTM, thanks @mesakhcienet.

@bvandermoon (Collaborator) left a comment


Generally LGTM. Can you please check the PR test failures?

@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch from 12f3b28 to d486bdb Compare November 4, 2025 03:46
@mesakhcienet (Contributor, Author) commented:

Generally LGTM. Can you please check the PR test failures?

Thank you, I just fixed the error.

@bvandermoon (Collaborator) left a comment


Thanks @mesakhcienet, generally looking good.

@RissyRan could you please also take a look?

from flax import linen as nn
from flax import nnx

from MaxText.layers import initializers, linears, moe, nnx_wrappers, quantizations
A collaborator commented:

Nit: is layers needed here?

@mesakhcienet (Contributor, Author) replied on Nov 7, 2025:

Actually, I followed the existing pattern for importing these modules. The imports are necessary because the modules live in the src/MaxText/layers subfolder and are not exposed through the src/MaxText/__init__.py file.

I will fix the import order, thank you.


Comment on lines -97 to +101
deepseek.DeepSeekDenseLayer(config, mesh=self._mesh, quant=self.quant),
deepseek.DeepSeekMoELayer(config, mesh=self._mesh, quant=self.quant),
deepseek.DeepSeekDenseLayerToLinen(config, mesh=self._mesh, quant=self.quant, model_mode=model_mode, rngs=rngs),
deepseek.DeepSeekMoELayerToLinen(config, mesh=self._mesh, quant=self.quant, model_mode=model_mode, rngs=rngs),
A collaborator commented:

Why does this need to be changed here?

@mesakhcienet (Contributor, Author) replied on Nov 7, 2025:

If we use the NNX modules DeepSeekDenseLayer or DeepSeekMoELayer instead of the Linen-converted DeepSeekDenseLayerToLinen or DeepSeekMoELayerToLinen, these parts fail. Apologies, I didn't screenshot the error; I recall it came from the unit test.

Also, when calling this function, I believe the model should be a Linen module rather than an NNX one, since .apply is not a method of NNX modules.
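
To illustrate the .apply point, here is a generic sketch using Flax's upstream nnx.bridge, which I believe is analogous to (but not identical to) what MaxText's nnx_wrappers-based *ToLinen classes do; TinyBlock below is a hypothetical stand-in, not the actual DeepSeek layer code:

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

class TinyBlock(nnx.Module):
  """Hypothetical NNX layer standing in for DeepSeekDenseLayer / DeepSeekMoELayer."""

  def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
    self.proj = nnx.Linear(din, dout, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)

x = jnp.ones((2, 8))

# NNX style: the module holds its own state and is called directly; there is no .apply().
nnx_layer = TinyBlock(8, 16, rngs=nnx.Rngs(0))
y_nnx = nnx_layer(x)

# Linen-compatible wrapper: callers that expect init()/apply(), like the code under
# review here, can keep using that interface on top of the NNX implementation.
linen_layer = bridge.to_linen(TinyBlock, 8, 16)
variables = linen_layer.init(jax.random.PRNGKey(0), x)
y_linen = linen_layer.apply(variables, x)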

@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch from 8a3fa7e to 244d205 Compare November 7, 2025 02:41
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch 2 times, most recently from 53f7584 to 7a68e07 Compare November 7, 2025 08:34
@mesakhcienet mesakhcienet changed the title from "feat: migrate deepseek to nnx" to "core: update deepseek unit test" Nov 7, 2025
@mesakhcienet mesakhcienet changed the title from "core: update deepseek unit test" to "fix: modify deepseek quantization and unit test" Nov 7, 2025
@mesakhcienet mesakhcienet force-pushed the feat/migrate-deepseek-to-nnx branch from 458a0e8 to bb0fff6 Compare November 7, 2025 10:45