Skip to content

Commit 69ed0c5

Browse files
Merge pull request #2522 from CIeNET-International:feat/migrate-deepseek-split-batch-to-nnx
PiperOrigin-RevId: 829304686
2 parents 725d15c + b556bcd commit 69ed0c5

File tree

4 files changed

+413
-343
lines changed

4 files changed

+413
-343
lines changed

src/MaxText/layers/decoders.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,15 @@ def get_decoder_layers(self):
395395
return [mixtral.MixtralDecoderLayerToLinen]
396396
case DecoderBlockType.DEEPSEEK:
397397
if self.config.use_batch_split_schedule:
398-
return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
398+
return [
399+
deepseek_batchsplit.DeepSeekDenseLayerToLinen,
400+
deepseek_batchsplit.DeepSeekMoELayerToLinen,
401+
]
399402
else:
400-
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
403+
return [
404+
deepseek.DeepSeekDenseLayerToLinen,
405+
deepseek.DeepSeekMoELayerToLinen,
406+
]
401407
case DecoderBlockType.GEMMA:
402408
return [gemma.GemmaDecoderLayerToLinen]
403409
case DecoderBlockType.GEMMA2:

0 commit comments

Comments
 (0)