diff --git a/examples/deepseek_v3/config.py b/examples/deepseek_v3/config.py new file mode 100644 index 000000000..bb58573b2 --- /dev/null +++ b/examples/deepseek_v3/config.py @@ -0,0 +1,174 @@ +import dataclasses +from typing import TypedDict +from transformers.utils import logging +from collections.abc import Callable + +from flax.core.frozen_dict import FrozenDict + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + +class RopeScalingConf(TypedDict): + type: str + factor: int + mscale: float + beta_fast: int + beta_slow: int + mscale_all_dim: float + original_max_position_embeddings: int + + +@dataclasses.dataclass(unsafe_hash=True, eq=True) +class Config: + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = 'deepseek_v3' + keys_to_ignore_at_inference: tuple[str, ...] = ('past_key_values',) + vocab_size: int = 129280 + hidden_size: int = 7168 + intermediate_size: int = 18432 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 61 + num_nextn_predict_layers: int = 1 + num_attention_heads: int = 128 + num_key_value_heads: int = 128 + n_shared_experts: int = 1 + n_routed_experts: int = 256 + ep_size: int = 1 + routed_scaling_factor: float = 2.5 + kv_lora_rank: int = 512 + q_lora_rank: int | None = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = 'noaux_tc' + n_group: int = 8 + topk_group: int = 4 + num_experts_per_tok: int = 8 + moe_layer_freq: int = 1 + first_k_dense_replace: int = 3 + norm_topk_prob: bool = True + scoring_func: str = 'sigmoid' + aux_loss_alpha: float = 0.001 + seq_aux: bool = True + hidden_act: str | Callable = 'silu' + max_position_embeddings: int = 4096 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int | None = None + bos_token_id: int = 0 + eos_token_id: int = 1 + pretraining_tp: int = 1 + tie_word_embeddings: bool = False + rope_theta: float = 10000.0 + rope_scaling: RopeScalingConf | None = None + attention_bias: bool = False + attention_dropout: float = 0.0 + + def __post_init__(self): + # for backward compatibility + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling is not None: + self.rope_scaling = FrozenDict(self.rope_scaling) diff --git a/examples/deepseek_v3/configs/base.py b/examples/deepseek_v3/configs/base.py new file mode 100644 index 000000000..e63b2aafc --- /dev/null +++ b/examples/deepseek_v3/configs/base.py @@ -0,0 +1,102 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base configuration for the model.""" + +import dataclasses +from typing import Literal + + +@dataclasses.dataclass +class ModelArgs: + """Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE + routing. + route_scale (float): Scaling factor for routing scores. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without + positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary + embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + mscale (float): Scaling factor for extended attention. + world_size (int): World size. + rank (int): Rank. + block_size (int): Block size. + gemm_impl (Literal["bf16", "fp8"] | None): Implementation for GEMM + operations. + attn_impl (Literal["naive", "absorb"]): Implementation for attention + operations. + """ + + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + # moe + n_routed_experts: int = 64 + n_shared_experts: int = 2 + n_activated_experts: int = 6 + n_expert_groups: int = 1 + n_limited_groups: int = 1 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 1.0 + # mla + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1.0 + # misc + world_size: int = 1 + rank: int = 0 + block_size: int = 128 + gemm_impl: Literal["bf16", "fp8"] | None = "bf16" + attn_impl: Literal["naive", "absorb"] = "absorb" diff --git a/examples/deepseek_v3/configs/config_16B.py b/examples/deepseek_v3/configs/config_16B.py new file mode 100644 index 000000000..dd46c6bf8 --- /dev/null +++ b/examples/deepseek_v3/configs/config_16B.py @@ -0,0 +1,40 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Configuration for the 16B model.""" + +from configs.base import ModelArgs + + +def get_config(): + """Returns the configuration for the model.""" + return ModelArgs( + vocab_size=102400, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + n_dense_layers=1, + n_heads=16, + n_routed_experts=64, + n_shared_experts=2, + n_activated_experts=6, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.707, + ) diff --git a/examples/deepseek_v3/configs/config_236B.py b/examples/deepseek_v3/configs/config_236B.py new file mode 100644 index 000000000..977602ca2 --- /dev/null +++ b/examples/deepseek_v3/configs/config_236B.py @@ -0,0 +1,41 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Configuration for the 236B model.""" + +from configs.base import ModelArgs + + +def get_config(): + """Returns the configuration for the model.""" + return ModelArgs( + vocab_size=102400, + dim=5120, + inter_dim=12288, + moe_inter_dim=1536, + n_layers=60, + n_dense_layers=1, + n_heads=128, + n_routed_experts=160, + n_shared_experts=2, + n_activated_experts=6, + n_expert_groups=8, + n_limited_groups=3, + route_scale=16.0, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) diff --git a/examples/deepseek_v3/configs/config_671B.py b/examples/deepseek_v3/configs/config_671B.py new file mode 100644 index 000000000..ef6d9290b --- /dev/null +++ b/examples/deepseek_v3/configs/config_671B.py @@ -0,0 +1,43 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Configuration for the 671B model.""" + +from configs.base import ModelArgs + + +def get_config(): + """Returns the configuration for the model.""" + return ModelArgs( + vocab_size=129280, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=61, + n_dense_layers=3, + n_heads=128, + n_routed_experts=256, + n_shared_experts=1, + n_activated_experts=8, + n_expert_groups=8, + n_limited_groups=4, + route_scale=2.5, + score_func="sigmoid", + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + dtype="fp8", + ) diff --git a/examples/deepseek_v3/kernel.py b/examples/deepseek_v3/kernel.py new file mode 100644 index 000000000..ff5346ff6 --- /dev/null +++ b/examples/deepseek_v3/kernel.py @@ -0,0 +1,258 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Triton kernels for PyTorch.""" + +from typing import Tuple + +import jax +import jax.numpy as jnp +import torch +import triton +import triton.language as tl + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized + values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors + will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each + program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + +def act_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and + its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for + quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous() + assert x.size(-1) % block_size == 0 + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +def act_quant_jax(x: jax.Array): + s = jnp.max(jnp.abs(x)) / 448.0 + y = x / s + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 +) -> torch.Tensor: + """Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. + Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions + are not 2. + """ + assert x.is_contiguous() and s.is_contiguous() + assert x.dim() == 2 and s.dim() == 2 + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: ( + triton.cdiv(M, meta['BLOCK_SIZE']), + triton.cdiv(N, meta['BLOCK_SIZE']), + ) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +def weight_dequant_jax(x: jax.Array, s: jax.Array): + y = x * s + return y + + +fp8_gemm_configs = [ + triton.Config( + {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [16, 32, 64] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] +] + + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def fp8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """Performs a matrix multiplication operation on FP8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def fp8_gemm( + a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor +): + """Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be + contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must + be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c + + +def fp8_gemm_jax(a: jax.Array, a_s: jax.Array, b: jax.Array, b_s: jax.Array): + a = a.astype(jnp.float32) + b = b.astype(jnp.float32) + a_s = a_s.astype(jnp.float32) + b_s = b_s.astype(jnp.float32) + a = a / a_s + b = b / b_s + c = jnp.dot(a, b) + return c diff --git a/examples/deepseek_v3/kernel_pytorch.py b/examples/deepseek_v3/kernel_pytorch.py new file mode 100644 index 000000000..253c8882d --- /dev/null +++ b/examples/deepseek_v3/kernel_pytorch.py @@ -0,0 +1,235 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Triton kernels for PyTorch.""" + +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized + values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors + will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each + program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def act_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and + its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for + quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous() + assert x.size(-1) % block_size == 0 + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 +) -> torch.Tensor: + """Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. + Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions + are not 2. + """ + assert x.is_contiguous() and s.is_contiguous() + assert x.dim() == 2 and s.dim() == 2 + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: ( + triton.cdiv(M, meta['BLOCK_SIZE']), + triton.cdiv(N, meta['BLOCK_SIZE']), + ) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +fp8_gemm_configs = [ + triton.Config( + {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [16, 32, 64] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] +] + + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def fp8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """Performs a matrix multiplication operation on FP8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def fp8_gemm( + a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor +): + """Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be + contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must + be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c diff --git a/examples/deepseek_v3/model.py b/examples/deepseek_v3/model.py new file mode 100644 index 000000000..066390411 --- /dev/null +++ b/examples/deepseek_v3/model.py @@ -0,0 +1,1347 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""JAX DeepSeek model.""" + +import math +import warnings +from collections.abc import Callable + +import jax +import jax.numpy as jnp +from flax import nnx +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from config import Config +from einop import einop +import einx + + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap( + _prepare_4d_causal_attention_mask + ) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'Config' + + +class RMSNorm(nnx.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nnx.Param(jnp.ones((hidden_size,))) + self.variance_epsilon = eps + + def __call__(self, hidden_states: jax.Array): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.astype(jnp.float32) + variance = jnp.pow(hidden_states, 2).mean(-1, keepdims=True) + hidden_states = hidden_states * jax.lax.rsqrt( + variance + self.variance_epsilon + ) + return self.weight.value * hidden_states.astype(input_dtype) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +): + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func: jax.Array = (jnp.arange(dim, dtype=jnp.float32) - min) / ( + max - min + ) + ramp_func = jnp.clip(linear_func, 0, 1) + return ramp_func + +class Buffer(nnx.Variable[nnx.A]): + pass + +class YarnRotaryEmbedding(nnx.Module): + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: float = 10000, + scaling_factor: float = 1.0, + original_max_position_embeddings: int = 4096, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: int = 0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq: jax.Array = 1.0 / ( + self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim) + ) + self.inv_freq = Buffer(inv_freq) + + # Build here to make `torch.jit.trace` work. + seq_len = max_position_embeddings + dtype = jnp.float32 + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).astype( + dtype=jnp.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.inv_freq = Buffer(inv_freq) + + t = jnp.arange(seq_len, dtype=jnp.float32) + + freqs = jnp.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = jnp.concatenate((freqs, freqs), axis=-1) + self.cos_cached = Buffer((jnp.cos(emb) * _mscale).astype(dtype)) + self.sin_cached = Buffer((jnp.sin(emb) * _mscale).astype(dtype)) + self.max_seq_len_cached = None + + def __call__(self, x: jax.Array, seq_len: int | None = None): + return ( + self.cos_cached.value[:seq_len].astype(dtype=x.dtype), + self.sin_cached.value[:seq_len].astype(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return jnp.concatenate((-x2, x1), axis=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MLP(nnx.Module): + def __init__( + self, + config: Config, + hidden_size: int | None = None, + intermediate_size: int | None = None, + *, + rngs: nnx.Rngs, + ): + super().__init__() + self.config = config + self.hidden_size = ( + config.hidden_size if hidden_size is None else hidden_size + ) + self.intermediate_size = ( + config.intermediate_size + if intermediate_size is None + else intermediate_size + ) + + self.gate_proj = nnx.Linear( + self.hidden_size, self.intermediate_size, use_bias=False, rngs=rngs + ) + self.up_proj = nnx.Linear( + self.hidden_size, self.intermediate_size, use_bias=False, rngs=rngs + ) + self.down_proj = nnx.Linear( + self.intermediate_size, self.hidden_size, use_bias=False, rngs=rngs + ) + self.act_fn: Callable[[jax.Array], jax.Array] = ( + getattr(nnx, config.hidden_act) + if isinstance(config.hidden_act, str) + else config.hidden_act + ) + + def __call__(self, x: jax.Array): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def linear(x: jax.Array, weight: jax.Array, bias: jax.Array | None = None): + # y = jnp.dot(x, weight) + y = einx.dot('... j, j k -> ... k', x, weight) + if bias is not None: + y = einx.add('... k, k -> ... k', y, bias) + return y + + +class MoEGate(nnx.Module): + def __init__(self, config: Config, *, rngs: nnx.Rngs): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nnx.Param(jnp.empty((self.gating_dim, self.n_routed_experts))) + if self.topk_method != 'noaux_tc': + raise NotImplementedError( + f'insupportable TopK function for MoE gating: {self.topk_method}' + ) + if self.scoring_func != 'sigmoid': + raise NotImplementedError( + f'insupportable scoring function for MoE gating: {self.scoring_func}' + ) + self.e_score_correction_bias = nnx.Param( + jnp.empty((self.n_routed_experts,)) + ) + self.weight.value = nnx.initializers.kaiming_uniform()( + rngs.params(), self.weight.value.shape, self.weight.value.dtype + ) + + def __call__(self, hidden_states: jax.Array): + # [b, n, h] + bsz, seq_len, h = hidden_states.shape + ### compute gating score + # [b * n, h] + # hidden_states = hidden_states.reshape(-1, h) + hidden_states = einop(hidden_states, '... h -> (...) h') + # [b * n, e] + logits = linear( + hidden_states.astype(jnp.float32), + self.weight.value.astype(jnp.float32), + None, + ) + # [b * n, e] + scores: jax.Array = nnx.sigmoid(logits) + ### select top-k experts + # [b * n, e] + scores_for_choice: jax.Array = scores + self.e_score_correction_bias[None] + # [b * n, n_group, e_g] + # group_features = scores_for_choice.reshape(bsz * seq_len, self.n_group, -1) + group_features = einop( + scores_for_choice, 'bn (g e_g) -> bn g e_g', g=self.n_group + ) + # [b * n, n_group] + group_scores = jax.lax.top_k(group_features, 2)[0].sum(axis=-1) + # [n, top_k_group] + _, group_idx = jax.lax.top_k(group_scores, k=self.topk_group) + # [b * n, n_group] + group_mask = jnp.zeros_like(group_scores) + # [b * n, n_group] + group_mask = jnp.put_along_axis( + group_mask, group_idx, values=1, axis=1, inplace=False + ) + # [b * n, e] + score_mask = jnp.broadcast_to( + group_mask[..., None], + (bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group), + ).reshape(bsz * seq_len, -1) + # [b * n, e] + tmp_scores = jnp.where(score_mask.astype(jnp.bool), scores_for_choice, 0.0) + # [b * n, top_k] + _, topk_idx = jax.lax.top_k(tmp_scores, k=self.top_k) + # [b * n, top_k] + topk_weight = jnp.take_along_axis(scores_for_choice, topk_idx, axis=-1) + # [b * n, e] + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(axis=-1, keepdims=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = ( + topk_weight * self.routed_scaling_factor + ) # must multiply the scaling factor + return topk_idx, topk_weight + + +class MoE(nnx.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config: Config, *, rngs: nnx.Rngs): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + self.experts = [ + MLP(config, intermediate_size=config.moe_intermediate_size, rngs=rngs) + for i in range(config.n_routed_experts) + ] + + self.gate = MoEGate(config, rngs=rngs) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = MLP( + config=config, intermediate_size=intermediate_size, rngs=rngs + ) + else: + self.shared_experts = None + + def __call__( + self, + x: jax.Array, # [b, n, h] + ): + b, n, h = x.shape + x0 = x + # [b * n, top_k], [b * n, top_k] + topk_idx, topk_weight = self.gate(x) + # [b * n, h] + # x = x.reshape(-1, x.shape[-1]) + x = einop(x, 'b n h -> (b n) h') + # [b * n, e, h] + experts_y = jnp.stack([expert(x) for expert in self.experts], axis=1) + # [b * n, top_k, h] + masked_y = jnp.take_along_axis(experts_y, topk_idx[:, :, None], axis=1) + # [b, n, h] + y = einx.dot( + '(b n) top_k h, (b n) top_k -> b n h', masked_y, topk_weight, b=b + ) + if self.shared_experts is not None: + y = y + self.shared_experts(x0) + return y + + def moe_infer( + self, + x: jax.Array, # [b * n, h] + topk_ids: jax.Array, # [b * n, top_k] + topk_weight: jax.Array, # [b * n, top_k] + ): + # [b * n, e, h] + experts_y = jnp.stack([expert(x) for expert in self.experts], axis=1) + # [b * n, top_k, h] + masked_y = jnp.take_along_axis(experts_y, topk_ids[:, :, None], axis=1) + # [b * n, h] + y = jnp.sum(masked_y * topk_weight[:, :, None], axis=1) + return y + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3 +class Attention(nnx.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Config, layer_idx: int, *, rngs: nnx.Rngs): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nnx.Linear( + self.hidden_size, + self.num_heads * self.q_head_dim, + use_bias=False, + rngs=rngs, + ) + else: + self.q_a_proj = nnx.Linear( + self.hidden_size, + self.q_lora_rank, + use_bias=config.attention_bias, + rngs=rngs, + ) + self.q_a_layernorm = RMSNorm(config.q_lora_rank) + self.q_b_proj = nnx.Linear( + self.q_lora_rank, + self.num_heads * self.q_head_dim, + use_bias=False, + rngs=rngs, + ) + + self.kv_a_proj_with_mqa = nnx.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + use_bias=config.attention_bias, + rngs=rngs, + ) + self.kv_a_layernorm = RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nnx.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + use_bias=False, + rngs=rngs, + ) + + self.o_proj = nnx.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + use_bias=config.attention_bias, + rngs=rngs, + ) + if ( + self.config.rope_scaling is None + or self.config.rope_scaling['type'] != 'yarn' + ): + raise ValueError('Only "yarn" scaling is supported for "rope_scaling".') + + scaling_factor = self.config.rope_scaling['factor'] + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + 'original_max_position_embeddings', + 'beta_fast', + 'beta_slow', + 'mscale', + 'mscale_all_dim', + ] + if key in self.config.rope_scaling + } + self.rotary_emb = YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + + self.softmax_scale = self.q_head_dim ** (-0.5) + mscale_all_dim = self.config.rope_scaling.get('mscale_all_dim', 0) + scaling_factor = self.config.rope_scaling['factor'] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def __call__( + self, + hidden_states: jax.Array, + attention_mask: jax.Array | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> tuple[jax.Array, jax.Array | None, tuple[jax.Array] | None]: + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + bsz, q_len, _ = hidden_states.shape + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.reshape(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + # q = einop(q, 'b n (head q_h) -> b head n q_h', num_heads=self.num_heads) + q_nope, q_pe = jnp.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = jnp.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1 + ) + k_pe = k_pe.reshape(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + # k_pe = einop(k_pe, 'b n (head q_h) -> b head n q_h', num_heads=1) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .reshape( + bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + .transpose(1, 2) + ) + + k_nope, value_states = jnp.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' + 'with a layer index.' + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) + * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' + f' {attn_weights.size()}' + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is' + f' {attn_output.size()}' + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + + +class DeepseekV3DecoderLayer(nnx.Module): + def __init__(self, config: Config, layer_idx: int, *, rngs: nnx.Rngs): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Attention(config=config, layer_idx=layer_idx, rngs=rngs) + + self.mlp = ( + MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else MLP(config) + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def __call__( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + **kwargs, + ) -> tuple[ + torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + 'The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.', + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['DeepseekV3DecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.', + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3Model(DeepseekV3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`] + Args: + config: Config + """ + + def __init__(self, config: Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = [ + DeepseekV3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + self._use_flash_attention_2 = ( + config._attn_implementation == 'flash_attention_2' + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError('You have to specify either input_ids or inputs_embeds') + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = ( + input_ids.device if input_ids is not None else inputs_embeds.device + ) + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): + _tied_weights_keys = ['lm_head.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update( + { + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). + [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + 'Cannot handle batch sizes > 1 if no padding token is defined.' + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/examples/deepseek_v3/model_pytorch.py b/examples/deepseek_v3/model_pytorch.py new file mode 100644 index 000000000..15ef9de9f --- /dev/null +++ b/examples/deepseek_v3/model_pytorch.py @@ -0,0 +1,1854 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PyTorch DeepSeek model.""" +import math +import warnings + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from config import Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap( + _prepare_4d_causal_attention_mask + ) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'Config' + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) + + +class DeepseekV3RotaryEmbedding(nn.Module): + def __init__( + self, dim, max_position_embeddings=2048, base=10000, device=None + ): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer('inv_freq', inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + 'cos_cached', (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + 'sin_cached', (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV3MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = ( + config.hidden_size if hidden_size is None else hidden_size + ) + self.intermediate_size = ( + config.intermediate_size + if intermediate_size is None + else intermediate_size + ) + + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + if self.topk_method == 'noaux_tc': + self.e_score_correction_bias = nn.Parameter( + torch.empty(self.n_routed_experts) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == 'sigmoid': + scores = logits.sigmoid() + else: + raise NotImplementedError( + f'insupportable scoring function for MoE gating: {self.scoring_func}' + ) + + ### select top-k experts + if self.topk_method == 'noaux_tc': + assert not self.training + scores_for_choice = scores.view( + bsz * seq_len, -1 + ) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill( + ~score_mask.bool(), 0.0 + ) # [n, e] + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError( + f'insupportable TopK function for MoE gating: {self.topk_method}' + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = ( + topk_weight * self.routed_scaling_factor + ) # must multiply the scaling factor + + return topk_idx, topk_weight + + +class DeepseekV3MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, 'ep_size') and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV3MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if not self.training: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = ( + torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + ) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape( + batch, num_key_value_heads * n_rep, slen, head_dim + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3 +class DeepseekV3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Config, layer_idx: int | None = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f'Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will ' + 'to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` ' + 'when creating this class.' + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get('mscale_all_dim', 0) + scaling_factor = self.config.rope_scaling['factor'] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + scaling_factor = self.config.rope_scaling['factor'] + if scaling_type == 'linear': + self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == 'dynamic': + self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == 'yarn': + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + 'original_max_position_embeddings', + 'beta_fast', + 'beta_slow', + 'mscale', + 'mscale_all_dim', + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f'Unknown RoPE scaling type {scaling_type}') + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> tuple[ + torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None + ]: + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' + 'with a layer index.' + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) + * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' + f' {attn_weights.size()}' + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is' + f' {attn_output.size()}' + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 +class DeepseekV3FlashAttention2(DeepseekV3Attention): + """ + DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = ( + not is_flash_attn_greater_or_equal_2_10() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> tuple[ + torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None + ]: + # DeepseekV3FlashAttention2 attention does not support output_attentions + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV3RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in' + f' {target_dtype}.' + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask + ) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape( + batch_size * kv_seq_len, num_key_value_heads, head_dim + ), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + 'eager': DeepseekV3Attention, + 'flash_attention_2': DeepseekV3FlashAttention2, +} + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + **kwargs, + ) -> tuple[ + torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + 'The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.', + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['DeepseekV3DecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.', + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3Model(DeepseekV3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`] + Args: + config: Config + """ + + def __init__(self, config: Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = ( + config._attn_implementation == 'flash_attention_2' + ) + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError('You have to specify either input_ids or inputs_embeds') + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = ( + input_ids.device if input_ids is not None else inputs_embeds.device + ) + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): + _tied_weights_keys = ['lm_head.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update( + { + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). + [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + 'Cannot handle batch sizes > 1 if no padding token is defined.' + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/examples/deepseek_v3/model_test.py b/examples/deepseek_v3/model_test.py new file mode 100644 index 000000000..a876b76cc --- /dev/null +++ b/examples/deepseek_v3/model_test.py @@ -0,0 +1,201 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================å +"""Tests for the DeepSeek model.""" + + +from absl.testing import absltest +from absl.testing import parameterized +from flax import nnx +import model as model_lib +import model_pytorch as model_lib_pt +import jax +import jax.numpy as jnp +import numpy as np +import torch + + +def bf16_pt_to_jax(pt_tensor: torch.Tensor) -> jax.Array: + return jnp.array( + pt_tensor.detach().to(torch.float32).numpy().astype(jnp.bfloat16) + ) + +def to_jax(pt_tensor: torch.Tensor, /) -> jax.Array: + return jnp.asarray(pt_tensor.detach().numpy()) + + +class ModelTest(parameterized.TestCase): + def _copy_rs_norm_weights( + self, model_jax: model_lib.RMSNorm, model_pt: model_lib_pt.DeepseekV3RMSNorm + ): + model_jax.weight.value = to_jax(model_pt.weight) + + def test_rsnorm_parity(self): + """Tests that the RSNorm layer is implemented correctly.""" + hidden_size = 16 + + # pytorch + model_pt = model_lib_pt.DeepseekV3RMSNorm(hidden_size) + x_pt = torch.randn(1, 5, hidden_size) + y_pt = model_pt(x_pt) + + # jax + model_jax = model_lib.RMSNorm(hidden_size) + self._copy_rs_norm_weights(model_jax, model_pt) + x_jax = to_jax(x_pt) + y_jax = model_jax(x_jax) + + np.testing.assert_allclose(to_jax(y_pt), y_jax, atol=1e-3) + + def _copy_embedding_weights( + self, + model_jax: model_lib.YarnRotaryEmbedding, + model_pt: model_lib_pt.DeepseekV3YarnRotaryEmbedding, + ): + model_jax.inv_freq.value = to_jax(model_pt.inv_freq) + model_jax.cos_cached.value = to_jax(model_pt.cos_cached) + model_jax.sin_cached.value = to_jax(model_pt.sin_cached) + + def test_rotary_embedding(self): + hidden_dim = 32 + seq_len = 8 + + # pytorch + model_pt = model_lib_pt.DeepseekV3YarnRotaryEmbedding( + dim=hidden_dim, max_position_embeddings=128, base=10000 + ) + x_pt = torch.randn(2, seq_len, hidden_dim) + cos_pt, sin_pt = model_pt(x_pt, seq_len=seq_len) + + # jax + model_jax = model_lib.YarnRotaryEmbedding(hidden_dim, 128, 10000) + self._copy_embedding_weights(model_jax, model_pt) + x_jax = to_jax(x_pt) + cos_jax, sin_jax = model_jax(x_jax, seq_len=seq_len) + + np.testing.assert_allclose(to_jax(cos_pt), cos_jax, atol=1e-3) + np.testing.assert_allclose(to_jax(sin_pt), sin_jax, atol=1e-3) + + def _copy_mlp_weights( + self, model_jax: model_lib.MLP, model_pt: model_lib_pt.DeepseekV3MLP + ): + model_jax.gate_proj.kernel.value = to_jax(model_pt.gate_proj.weight).T + model_jax.up_proj.kernel.value = to_jax(model_pt.up_proj.weight).T + model_jax.down_proj.kernel.value = to_jax(model_pt.down_proj.weight).T + + def test_mlp_parity(self): + """Tests that the MLP layer is implemented correctly.""" + hidden_size = 16 + mlp_dim = 32 + config = model_lib.Config() + + # pytorch + model_pt = model_lib_pt.DeepseekV3MLP(config, hidden_size, mlp_dim) + x_pt = torch.randn(1, 5, hidden_size) + y_pt = model_pt(x_pt) + + # jax + model_jax = model_lib.MLP(config, hidden_size, mlp_dim, rngs=nnx.Rngs(0)) + self._copy_mlp_weights(model_jax, model_pt) + x_jax = to_jax(x_pt) + y_jax = model_jax(x_jax) + + np.testing.assert_allclose(to_jax(y_pt), y_jax, atol=1e-3) + + def _copy_moe_gate_weights( + self, model_jax: model_lib.MoEGate, model_pt: model_lib_pt.MoEGate + ): + model_jax.weight.value = to_jax(model_pt.weight).T + if model_jax.e_score_correction_bias is not None: + model_jax.e_score_correction_bias.value = to_jax( + model_pt.e_score_correction_bias + ) + + def test_moe_gate_parity(self): + """Tests that the MoEGate layer is implemented correctly.""" + + config = model_lib.Config( + n_routed_experts=8, + hidden_size=16, + n_group=4, + topk_group=2, + ) + # pytorch + model_pt = model_lib_pt.MoEGate(config) + model_pt.eval() + x_pt = torch.randn(1, 12, 16) + topk_idx_pt, topk_weight_pt = model_pt(x_pt) + topk_idx_pt, topk_weight_pt = to_jax(topk_idx_pt), to_jax(topk_weight_pt) + + # sort pytorch output + idxs_pt = jnp.argsort(topk_idx_pt, axis=-1) + topk_idx_pt = jnp.take_along_axis(topk_idx_pt, idxs_pt, axis=-1) + topk_weight_pt = jnp.take_along_axis(topk_weight_pt, idxs_pt, axis=-1) + + # jax + model_jax = model_lib.MoEGate(config, rngs=nnx.Rngs(0)) + self._copy_moe_gate_weights(model_jax, model_pt) + x_jax = to_jax(x_pt) + topk_idx_jax, topk_weight_jax = model_jax(x_jax) + + # sort jax output + idxs_jax = jnp.argsort(topk_idx_jax, axis=-1) + topk_idx_jax = jnp.take_along_axis(topk_idx_jax, idxs_jax, axis=-1) + topk_weight_jax = jnp.take_along_axis(topk_weight_jax, idxs_jax, axis=-1) + + np.testing.assert_allclose(topk_idx_pt, topk_idx_jax, atol=1e-3) + np.testing.assert_allclose(topk_weight_pt, topk_weight_jax, atol=1e-3) + + def _copy_moe_weights( + self, jax_model: model_lib.MoE, pt_model: model_lib_pt.DeepseekV3MoE + ): + for expert_jax, expert_pt in zip(jax_model.experts, pt_model.experts): + assert isinstance(expert_pt, model_lib_pt.DeepseekV3MLP) + self._copy_mlp_weights(expert_jax, expert_pt) + + if jax_model.shared_experts is not None: + self._copy_mlp_weights(jax_model.shared_experts, pt_model.shared_experts) + + self._copy_moe_gate_weights(jax_model.gate, pt_model.gate) + + def test_moe_parity(self): + """Tests that the Moe layer is implemented correctly.""" + + batch_size = 4 + seq_len = 10 + config = model_lib.Config( + n_routed_experts=12, + hidden_size=16, + n_group=4, + topk_group=2, + num_experts_per_tok=3, + ) + # pytorch + model_pt = model_lib_pt.DeepseekV3MoE(config) + model_pt.eval() + x_pt = torch.randn(batch_size, seq_len, config.hidden_size) + y_pt = model_pt(x_pt) + + # jax + model_jax = model_lib.MoE(config, rngs=nnx.Rngs(0)) + + self._copy_moe_weights(model_jax, model_pt) + x_jax = to_jax(x_pt) + y_jax = model_jax(x_jax) + + np.testing.assert_allclose(to_jax(y_pt), y_jax, atol=1e-3) + + +if __name__ == "__main__": + absltest.main() diff --git a/examples/deepseek_v3/requirements.txt b/examples/deepseek_v3/requirements.txt new file mode 100644 index 000000000..870ce335a --- /dev/null +++ b/examples/deepseek_v3/requirements.txt @@ -0,0 +1,3 @@ +transformers==4.46.3 +einop +einx \ No newline at end of file diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py new file mode 100644 index 000000000..374a6eb92 --- /dev/null +++ b/flax/nnx/summary.py @@ -0,0 +1,552 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import defaultdict +import dataclasses +import functools +import inspect +import io +import typing as tp +from types import MappingProxyType + +import jax +import numpy as np +import rich.console +import rich.table +import rich.text +import yaml +import jax.numpy as jnp + +from flax import nnx +from flax import typing +from flax.nnx import graph, statelib, variablelib + +try: + from IPython import get_ipython + + in_ipython = get_ipython() is not None +except ImportError: + in_ipython = False + +class SizeBytes(typing.SizeBytes): + def __repr__(self) -> str: + bytes_repr = _bytes_repr(self.bytes) + return f'{self.size:,} [dim]({bytes_repr})[/dim]' + +class ObjectInfo(tp.NamedTuple): + path: statelib.PathParts + stats: dict[type[variablelib.Variable], SizeBytes] + variable_groups: defaultdict[ + type[variablelib.Variable], defaultdict[typing.Key, variablelib.Variable] + ] + +NodeStats = dict[int, ObjectInfo | None] + +def _collect_stats( + path: statelib.PathParts, + node: tp.Any, + node_stats: NodeStats, + object_types: set[type], +): + if not graph.is_node(node) and not isinstance(node, variablelib.Variable): + raise ValueError(f'Expected a graph node or Variable, got {type(node)!r}.') + + if id(node) in node_stats: + return + + stats: dict[type[variablelib.Variable], SizeBytes] = {} + variable_groups: defaultdict[ + type[variablelib.Variable], defaultdict[typing.Key, variablelib.Variable] + ] = defaultdict(lambda: defaultdict()) + node_stats[id(node)] = ObjectInfo(path, stats, variable_groups) + + if isinstance(node, nnx.Object): + node._nnx_tabulate_id = id(node) # type: ignore + object_types.add(type(node)) + + node_dict = graph.get_node_impl(node).node_dict(node) + for key, value in node_dict.items(): + if id(value) in node_stats: + continue + elif isinstance(value, variablelib.Variable): + var_type = type(value) + if issubclass(var_type, nnx.RngState): + var_type = nnx.RngState + size_bytes = SizeBytes.from_any(value.value) + if var_type in stats: + stats[var_type] += size_bytes + else: + stats[var_type] = size_bytes + variable_groups[var_type][key] = value + node_stats[id(value)] = None + elif graph.is_node(value): + _collect_stats((*path, key), value, node_stats, object_types) + # accumulate stats from children + child_info = node_stats[id(value)] + assert child_info is not None + for var_type, size_bytes in child_info.stats.items(): + if var_type in stats: + stats[var_type] += size_bytes + else: + stats[var_type] = size_bytes + +@dataclasses.dataclass(frozen=True, repr=False) +class ArrayRepr: + shape: tuple[int, ...] + dtype: tp.Any + + @classmethod + def from_array(cls, x: jax.Array | np.ndarray): + return cls(jnp.shape(x), jnp.result_type(x)) + + def __str__(self): + shape_repr = ','.join(str(x) for x in self.shape) + return f'[dim]{self.dtype}[/dim][{shape_repr}]' + + +@dataclasses.dataclass +class CallInfo: + object_id: int + type: type + path: statelib.PathParts + input_args: tuple[tp.Any, ...] + input_kwargs: dict[str, tp.Any] + outputs: tp.Any + +class SimpleObjectRepr: + def __init__(self, obj: tp.Any): + self.type = type(obj) + + def __str__(self): + return f'{self.type.__name__}(...)' + + def __repr__(self): + return f'{self.type.__name__}(...)' + + +def get_method_wrapper( + rows: list[CallInfo], + node_stats: NodeStats, + method: tp.Callable, +) -> tp.Callable: + method_name = method.__name__ + @functools.wraps(method) + def method_wrapper(obj, *args, **kwargs): + def _to_dummy_array(x): + if isinstance(x, jax.Array | np.ndarray): + return ArrayRepr.from_array(x) + elif graph.is_graph_node(x): + return SimpleObjectRepr(x) + else: + return x + + object_id: int = getattr(obj, '_nnx_tabulate_id') + input_args_info, input_kwargs_info = jax.tree.map( + _to_dummy_array, (args, kwargs) + ) + node_info = node_stats[object_id] + assert node_info is not None + path = node_info.path + if method_name != '__call__': + path = (*path, method_name) + call_info = CallInfo( + object_id=object_id, + type=type(obj), + path=path, + input_args=input_args_info, + input_kwargs=input_kwargs_info, + outputs=None, + ) + rows.append(call_info) + out = method(obj, *args, **kwargs) + call_info.outputs = jax.tree.map(_to_dummy_array, out) + return out + + return method_wrapper + + +def _call_obj( + obj, + input_args: tuple[tp.Any, ...], + input_kwargs: dict[str, tp.Any], + *, + call_method: str, + rows: list[CallInfo], + node_stats: NodeStats, + object_types: set[type], +): + original_methods: dict[type, dict[str, tp.Callable]] = {} + try: + for obj_type in object_types: + methods: dict[str, tp.Callable] = {} + for name, top_method in inspect.getmembers(obj_type, inspect.isfunction): + if not name.startswith('_') or name == '__call__': + methods[name] = top_method + method_wrapper = get_method_wrapper(rows, node_stats, top_method) + setattr(obj_type, name, method_wrapper) + + original_methods[obj_type] = methods + + top_method = getattr(type(obj), call_method) + top_method(obj, *input_args, **input_kwargs) + finally: + for obj_type, methods in original_methods.items(): + for name, top_method in methods.items(): + setattr(obj_type, name, top_method) + + +def filter_rng_streams(row: CallInfo): + return not issubclass(row.type, nnx.RngStream) + +def tabulate( + obj, + *input_args, + depth: int | None = None, + method: str = '__call__', + row_filter: tp.Callable[[CallInfo], bool] = filter_rng_streams, + table_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + column_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + console_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + **input_kwargs, +) -> str: + """Creates a summary of the graph object represented as a table. + + The table summarizes the object's state and metadata. The table is + structured as follows: + + - The first column represents the path of the object in the graph. + - The second column represents the type of the object. + - The third column represents the input arguments passed to the object's + method. + - The fourth column represents the output of the object's method. + - The following columns provide information about the object's state, + grouped by Variable types. + + Example: + + >>> from flax import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.bn = nnx.BatchNorm(dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.2, rngs=rngs) + ... + ... def __call__(self, x): + ... return nnx.relu(self.dropout(self.bn(self.linear(x)))) + ... + >>> class Foo(nnx.Module): + ... def __init__(self, rngs: nnx.Rngs): + ... self.block1 = Block(32, 128, rngs=rngs) + ... self.block2 = Block(128, 10, rngs=rngs) + ... + ... def __call__(self, x): + ... return self.block2(self.block1(x)) + ... + >>> foo = Foo(nnx.Rngs(0)) + >>> # print(nnx.tabulate(foo, jnp.ones((1, 32)))) + + Foo Summary + ┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ + ┃ path ┃ type ┃ inputs ┃ outputs ┃ BatchStat ┃ Param ┃ RngState ┃ + ┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ + │ │ Foo │ float32[1,32] │ float32[1,10] │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block1 │ Block │ float32[1,32] │ float32[1,128] │ 256 (1.0 KB) │ 4,480 (17.9 KB) │ 2 (12 B) │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block1/linear │ Linear │ float32[1,32] │ float32[1,128] │ │ bias: float32[128] │ │ + │ │ │ │ │ │ kernel: float32[32,128] │ │ + │ │ │ │ │ │ │ │ + │ │ │ │ │ │ 4,224 (16.9 KB) │ │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block1/bn │ BatchNorm │ float32[1,128] │ float32[1,128] │ mean: float32[128] │ bias: float32[128] │ │ + │ │ │ │ │ var: float32[128] │ scale: float32[128] │ │ + │ │ │ │ │ │ │ │ + │ │ │ │ │ 256 (1.0 KB) │ 256 (1.0 KB) │ │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block1/dropout │ Dropout │ float32[1,128] │ float32[1,128] │ │ │ 2 (12 B) │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block2 │ Block │ float32[1,128] │ float32[1,10] │ 20 (80 B) │ 1,310 (5.2 KB) │ │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block2/linear │ Linear │ float32[1,128] │ float32[1,10] │ │ bias: float32[10] │ │ + │ │ │ │ │ │ kernel: float32[128,10] │ │ + │ │ │ │ │ │ │ │ + │ │ │ │ │ │ 1,290 (5.2 KB) │ │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block2/bn │ BatchNorm │ float32[1,10] │ float32[1,10] │ mean: float32[10] │ bias: float32[10] │ │ + │ │ │ │ │ var: float32[10] │ scale: float32[10] │ │ + │ │ │ │ │ │ │ │ + │ │ │ │ │ 20 (80 B) │ 20 (80 B) │ │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ block2/dropout │ Dropout │ float32[1,10] │ float32[1,10] │ │ │ │ + ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ + │ │ │ │ Total │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ + └────────────────┴───────────┴────────────────┴────────────────┴────────────────────┴─────────────────────────┴──────────┘ + + Total Parameters: 6,068 (24.3 KB) + + + Note that ``block2/dropout`` is not shown in the table because it shares the + same ``RngState`` with ``block1/dropout``. + + Args: + obj: A object to summarize. It can a pytree or a graph objects + such as nnx.Module or nnx.Optimizer. + *input_args: Positional arguments passed to the object's method. + **input_kwargs: Keyword arguments passed to the object's method. + depth: The depth of the table. + method: The method to call on the object. Default is ``'__call__'``. + row_filter: A callable that filters the rows to be displayed in the table. + By default, it filters out rows with type ``nnx.RngStream``. + table_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table`` constructor. + column_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table.add_column`` when adding columns to + the table. + console_kwargs: An optional dictionary with additional keyword arguments + that are passed to `rich.console.Console` when rendering the table. + Default arguments are ``'force_terminal': True``, and ``'force_jupyter'`` + is set to ``True`` if the code is running in a Jupyter notebook, otherwise + it is set to ``False``. + + Returns: + A string summarizing the object. + """ + _console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython} + _console_kwargs.update(console_kwargs) + + obj = graph.clone(obj) # create copy to avoid side effects + node_stats: NodeStats = {} + object_types: set[type] = set() + _collect_stats((), obj, node_stats, object_types) + variable_types = { + nnx.RngState + if issubclass(variable_state.type, nnx.RngState) + else variable_state.type + for _, variable_state in nnx.state(obj).flat_state() + } + variable_types = sorted(variable_types, key=lambda t: t.__name__) + + rows: list[CallInfo] = [] + eval_fn = functools.partial( + _call_obj, + call_method=method, + rows=rows, + node_stats=node_stats, + object_types=object_types, + ) + nnx.eval_shape(eval_fn, obj, input_args, input_kwargs) + + if depth is not None: + rows = [row for row in rows if len(row.path) <= depth and row_filter(row)] + else: + rows = [row for row in rows if row_filter(row)] + + rich_table = rich.table.Table( + show_header=True, + show_lines=True, + show_footer=True, + title=f'{type(obj).__name__} Summary', + **table_kwargs, + ) + + rich_table.add_column('path', **column_kwargs) + rich_table.add_column('type', **column_kwargs) + rich_table.add_column('inputs', **column_kwargs) + rich_table.add_column('outputs', **column_kwargs) + + for var_type in variable_types: + rich_table.add_column(var_type.__name__, **column_kwargs) + + for row in rows: + node_info = node_stats[row.object_id] + assert node_info is not None + col_reprs: list[str] = [] + path_str = '/'.join(map(str, row.path)) + col_reprs.append(path_str) + col_reprs.append(row.type.__name__) + inputs_repr = '' + if row.input_args: + input_args = row.input_args + if len(row.input_args) == 1 and not row.input_kwargs: + input_args = row.input_args[0] + inputs_repr += _as_yaml_str(input_args) + if inputs_repr and row.input_kwargs: + inputs_repr += '\n' + if row.input_kwargs: + inputs_repr += _as_yaml_str(row.input_kwargs) + col_reprs.append(inputs_repr) + col_reprs.append(_as_yaml_str(row.outputs)) + + for var_type in variable_types: + attributes = {} + for name, variable in node_info.variable_groups[var_type].items(): + value = variable.value + value_repr = _render_array(value) if _has_shape_dtype(value) else '' + metadata = variable.get_metadata() + + if metadata: + attributes[name] = { + 'value': value_repr, + **metadata, + } + elif value_repr: + attributes[name] = value_repr + + if attributes: + col_repr = _as_yaml_str(attributes) + '\n\n' + else: + col_repr = '' + + size_bytes = node_info.stats.get(var_type) + if size_bytes: + col_repr += f'[bold]{size_bytes}[/bold]' + col_reprs.append(col_repr) + + rich_table.add_row(*col_reprs) + + rich_table.columns[3].footer = rich.text.Text.from_markup( + 'Total', justify='right' + ) + node_info = node_stats[id(obj)] + assert node_info is not None + for i, var_type in enumerate(variable_types): + size_bytes = node_info.stats[var_type] + rich_table.columns[i + 4].footer = str(size_bytes) + + rich_table.caption_style = 'bold' + total_size = sum(node_info.stats.values(), SizeBytes(0, 0)) + rich_table.caption = f'\nTotal Parameters: {total_size}' + + return _get_rich_repr(rich_table, _console_kwargs) + + +def _get_rich_repr(obj, console_kwargs): + f = io.StringIO() + console = rich.console.Console(file=f, **console_kwargs) + console.print(obj) + return f.getvalue() + + +def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]: + leaves = jax.tree.leaves(pytree) + size = sum(x.size for x in leaves if hasattr(x, 'size')) + num_bytes = sum( + x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') + ) + return size, num_bytes + + +def _size_and_bytes_repr(size: int, num_bytes: int) -> str: + if not size: + return '' + bytes_repr = _bytes_repr(num_bytes) + return f'{size:,} [dim]({bytes_repr})[/dim]' + + +def _bytes_repr(num_bytes): + count, units = ( + (f'{num_bytes / 1e9:,.1f}', 'GB') + if num_bytes > 1e9 + else (f'{num_bytes / 1e6:,.1f}', 'MB') + if num_bytes > 1e6 + else (f'{num_bytes / 1e3:,.1f}', 'KB') + if num_bytes > 1e3 + else (f'{num_bytes:,}', 'B') + ) + + return f'{count} {units}' + + +def _has_shape_dtype(value): + return hasattr(value, 'shape') and hasattr(value, 'dtype') + + +def _normalize_values(x): + if isinstance(x, type): + return f'type[{x.__name__}]' + elif isinstance(x, ArrayRepr | SimpleObjectRepr): + return str(x) + else: + return x + +def _maybe_pytree_to_dict(pytree: tp.Any): + path_leaves = jax.tree_util.tree_flatten_with_path(pytree)[0] + path_leaves = [ + (tuple(map(graph._key_path_to_key, path)), value) + for path, value in path_leaves + ] + if len(path_leaves) < 1: + return pytree + elif len(path_leaves) == 1 and path_leaves[0][0] == (): + return pytree + else: + return _unflatten_to_simple_structure(path_leaves) + + +def _unflatten_to_simple_structure(xs: list[tuple[tuple[tp.Any, ...], tp.Any]]): + result = [] if len(xs) > 0 and isinstance(xs[0][0][0], int) else {} + for path, value in xs: + cursor = result + for i, key in enumerate(path[:-1]): + if isinstance(cursor, list): + if key == len(cursor): + next_key = path[i + 1] + if isinstance(next_key, int): + cursor.append([]) + else: + cursor.append({}) + else: + if key not in cursor: + next_key = path[i + 1] + if isinstance(next_key, int): + cursor[key] = [] + else: + cursor[key] = {} + cursor = cursor[key] + cursor[path[-1]] = value + return result + +def _as_yaml_str(value) -> str: + if (hasattr(value, '__len__') and len(value) == 0) or value is None: + return '' + + value = jax.tree.map(_normalize_values, value) + value = _maybe_pytree_to_dict(value) + + file = io.StringIO() + yaml.safe_dump( + value, + file, + default_flow_style=False, + indent=2, + sort_keys=False, + explicit_end=False, + ) + return file.getvalue().replace('\n...', '').replace("'", '').strip() + + +def _render_array(x): + shape, dtype = jnp.shape(x), jnp.result_type(x) + shape_repr = ','.join(str(x) for x in shape) + return f'[dim]{dtype}[/dim][{shape_repr}]' + + +def _sort_variable_types(types: tp.Iterable[type]) -> list[type]: + def _variable_parents_count(t: type): + return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) + + type_sort_key = {t: (-_variable_parents_count(t), t.__name__) for t in types} + return sorted(types, key=lambda t: type_sort_key[t]) diff --git a/pyproject.toml b/pyproject.toml index f7a890fad..66cce2d47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,6 +185,8 @@ filterwarnings = [ "ignore:.*divide by zero encountered in.*:RuntimeWarning", # DeprecationWarning: numpy.core is deprecated "ignore:.*numpy.core is deprecated.*:DeprecationWarning", + # MatplotlibDeprecationWarning: + "ignore:.*The non_interactive_bk attribute was deprecated.*:Warning", ] [tool.coverage.report] @@ -206,7 +208,7 @@ exclude = [ "activation.py", "partitioning.py", "flax/core/variables.py", - "examples/", +# "examples/", ] line-length = 80 diff --git a/uv.lock b/uv.lock index 2857313f7..929701958 100644 --- a/uv.lock +++ b/uv.lock @@ -2258,7 +2258,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.11.0" +version = "0.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2276,9 +2276,9 @@ dependencies = [ { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/de/b3/a9a8a6bc08ded7634a9d85ba440400172f0a11f9341897b8fd3389fad245/orbax_checkpoint-0.11.0.tar.gz", hash = "sha256:d4a0dcc81edd29191cf5a4feb9cf2a4edd31fc5da79d7be616a04f11f2a4d484", size = 253035 } +sdist = { url = "https://files.pythonhosted.org/packages/66/3a/60a08a16997b1e4e097d03b8b9135eec44c88cb8547822efe46d2e97eb09/orbax_checkpoint-0.11.1.tar.gz", hash = "sha256:e303f4f4441ba0367a14c692d6c73070bba954205de89df608f21823dd86eaee", size = 277832 } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/32/3779fa524a2272f408ab51d869fde9ff1c0ca731eedd01e40436bcf7ba2c/orbax_checkpoint-0.11.0-py3-none-any.whl", hash = "sha256:892a124fce71f3e7c71451a2b2090c0251db1097803a119a00baa377113bc9ba", size = 360423 }, + { url = "https://files.pythonhosted.org/packages/00/06/90f972344350d12ed77b6cf28f8e06215eff9d8ca4da95679e8d0369d697/orbax_checkpoint-0.11.1-py3-none-any.whl", hash = "sha256:9cfbe49331ffdef4aedfec337da683eb8fc65221dc43924cf6daea6c27d6dc06", size = 394885 }, ] [[package]]