diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index aa6f4f2023..dfe9c9c8ba 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -150,6 +150,15 @@ from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter, ) +from keras_hub.src.models.deepseek_v31.deepseek_v31_backbone import ( + DeepSeekV31Backbone as DeepSeekV31Backbone, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm import ( + DeepSeekV31CausalLM as DeepSeekV31CausalLM, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm_preprocessor import ( + DeepSeekV31CausalLMPreprocessor as DeepSeekV31CausalLMPreprocessor, +) from keras_hub.src.models.deit.deit_backbone import DeiTBackbone as DeiTBackbone from keras_hub.src.models.deit.deit_image_classifier import ( DeiTImageClassifier as DeiTImageClassifier, diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 97a68ab009..8e39752a51 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -22,6 +22,9 @@ from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer as DebertaV3Tokenizer, ) +from keras_hub.src.models.deepseek_v31.deepseek_v31_tokenizer import ( + DeepSeekV31Tokenizer as DeepSeekV31Tokenizer, +) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( DistilBertTokenizer as DistilBertTokenizer, ) diff --git a/keras_hub/src/models/deepseek_v31/__init__.py b/keras_hub/src/models/deepseek_v31/__init__.py new file mode 100644 index 0000000000..89f185ea5b --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/__init__.py @@ -0,0 +1,28 @@ +"""DeepSeek V3.1 model exports.""" + +from keras_hub.src.models.deepseek_v31.deepseek_v31_backbone import ( + DeepSeekV31Backbone, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm import ( + DeepSeekV31CausalLM, # noqa: F401 +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm_preprocessor import ( # noqa: E501 + DeepSeekV31CausalLMPreprocessor, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_presets import ( + backbone_presets, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_presets import ( + preprocessor_presets, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_presets import ( + tokenizer_presets, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_tokenizer import ( + DeepSeekV31Tokenizer, +) +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, DeepSeekV31Backbone) +register_presets(tokenizer_presets, DeepSeekV31Tokenizer) +register_presets(preprocessor_presets, DeepSeekV31CausalLMPreprocessor) diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_attention.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_attention.py new file mode 100644 index 0000000000..c74af3c95c --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_attention.py @@ -0,0 +1,446 @@ +"""DeepSeek V31 Multi-head Latent Attention layer.""" + +import math + +import keras +from keras import ops + + +class DeepSeekV31Attention(keras.layers.Layer): + """Multi-head Latent Attention (MLA) for DeepSeek V31. + + Implements the MLA architecture from Section 2.1.1 of the DeepSeek-V3 + paper. MLA reduces KV cache size by compressing keys and values through + a shared low-rank latent vector `c_kv` of dimension `kv_lora_rank`, + rather than caching full per-head K and V tensors. 
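+
+    As a rough illustration, with the `DeepSeekV31Backbone` defaults
+    (`kv_lora_rank=512`, `qk_rope_head_dim=64`, 128 heads, and per-head key
+    and value dims of 192 and 128), the cache holds 512 + 64 = 576 values per
+    token per layer, versus 128 * (192 + 128) = 40,960 for uncompressed
+    multi-head attention.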
+ + The attention computation splits queries and keys into two components: + + - **Content (nope) component**: `q_nope`, `k_nope` — carries semantic + information, does not receive positional encoding. + - **RoPE component**: `q_rope`, `k_rope` — receives Rotary Position + Embeddings (RoPE) for positional awareness. + + During inference only `(c_kv, k_rope)` need to be stored in the KV cache, + not the full materialized K and V tensors. The content keys and values are + recovered via the absorption matrices `w_uk` and `w_uv`: + + score = q_nope @ w_uk.T @ c_kv.T + q_rope @ k_rope.T + + YaRN (Yet another RoPE extensioN, Section 4.3) is applied to the RoPE + frequencies to support contexts longer than the training length. Different + frequency bands are scaled differently: high-frequency dimensions + (short wavelength) receive little scaling while low-frequency dimensions + (long wavelength) are scaled more aggressively. + + Args: + hidden_dim: int. Dimensionality of model hidden states. + num_query_heads: int. Number of query attention heads. + num_key_value_heads: int. Number of key/value heads. For MLA this + equals `num_query_heads` since KV are recovered from a shared + latent. + q_lora_rank: int. Rank of the query down-projection. + kv_lora_rank: int. Rank of the shared KV latent `c_kv`. This is the + per-layer KV cache size per token. + qk_nope_head_dim: int. Per-head dimension for the content (non-RoPE) + query and key components. + qk_rope_head_dim: int. Per-head dimension for the RoPE query and key + components. + v_head_dim: int. Per-head dimension for values. + rope_max_wavelength: int. Base wavelength for RoPE inverse frequencies. + Defaults to `10000`. + rope_scaling_factor: float. YaRN context extension scale factor. + Values greater than 1.0 extend the effective context length. + Defaults to `1.0`. + yarn_beta_fast: int. YaRN ramp upper threshold. Dimensions with + wavelength above `yarn_beta_fast * original_max_position` are + treated as low-frequency and receive full scaling. Defaults to + `32`. + yarn_beta_slow: int. YaRN ramp lower threshold. Dimensions with + wavelength below `yarn_beta_slow * original_max_position` are + treated as high-frequency and receive no scaling. Defaults to + `1`. + yarn_mscale: float. YaRN magnitude scaling factor applied to attention + cosine/sine embeddings. Defaults to `1.0`. + yarn_mscale_all_dim: float. If non-zero, overrides `yarn_mscale` for + all dimensions. Defaults to `0.0`. + yarn_original_max_position_embeddings: int. The context length used + during pre-training, used as the reference for YaRN ramp + thresholds. Defaults to `4096`. + attention_dropout: float. Dropout probability applied to attention + weights. Defaults to `0.0`. + kernel_initializer: string or initializer. Initializer for Dense and + raw weight matrices. Defaults to `"glorot_uniform"`. 
+ + Example: + + ```python + attn = keras_hub.layers.DeepSeekV31Attention( + hidden_dim=512, + num_query_heads=8, + num_key_value_heads=8, + q_lora_rank=256, + kv_lora_rank=128, + qk_nope_head_dim=32, + qk_rope_head_dim=16, + v_head_dim=32, + ) + hidden = keras.random.normal((2, 16, 512)) + output = attn(hidden) # (2, 16, 512) + ``` + + Reference: + - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437) + """ + + def __init__( + self, + hidden_dim, + num_query_heads, + num_key_value_heads, + q_lora_rank, + kv_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + yarn_beta_fast=32, + yarn_beta_slow=1, + yarn_mscale=1.0, + yarn_mscale_all_dim=0.0, + yarn_original_max_position_embeddings=4096, + attention_dropout=0.0, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.yarn_beta_fast = yarn_beta_fast + self.yarn_beta_slow = yarn_beta_slow + self.yarn_mscale = yarn_mscale + self.yarn_mscale_all_dim = yarn_mscale_all_dim + self.yarn_original_max_position_embeddings = ( + yarn_original_max_position_embeddings + ) + self.attention_dropout = attention_dropout + self.kernel_initializer = kernel_initializer + + # Query low-rank compression (eq. 6-8). + self.q_down_proj = keras.layers.Dense( + q_lora_rank, + use_bias=False, + kernel_initializer=kernel_initializer, + name="q_down_proj", + dtype=self.dtype_policy, + ) + self.q_up_nope_proj = keras.layers.Dense( + num_query_heads * qk_nope_head_dim, + use_bias=False, + kernel_initializer=kernel_initializer, + name="q_up_nope_proj", + dtype=self.dtype_policy, + ) + self.q_up_rope_proj = keras.layers.Dense( + num_query_heads * qk_rope_head_dim, + use_bias=False, + kernel_initializer=kernel_initializer, + name="q_up_rope_proj", + dtype=self.dtype_policy, + ) + + # KV low-rank compression (eq. 1-5). + self.kv_down_proj = keras.layers.Dense( + kv_lora_rank, + use_bias=False, + kernel_initializer=kernel_initializer, + name="kv_down_proj", + dtype=self.dtype_policy, + ) + self.k_rope_proj = keras.layers.Dense( + qk_rope_head_dim, + use_bias=False, + kernel_initializer=kernel_initializer, + name="k_rope_proj", + dtype=self.dtype_policy, + ) + + self.output_proj = keras.layers.Dense( + hidden_dim, + use_bias=False, + kernel_initializer=kernel_initializer, + name="output_proj", + dtype=self.dtype_policy, + ) + self.dropout = keras.layers.Dropout( + attention_dropout, + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + # W_UK and W_UV are stored as raw weight matrices for the MLA + # absorption trick. 
Rather than materializing full per-head K and V + # tensors, we score queries against c_kv directly using these absorbed + # matrices: score = q_nope @ W_UK.T @ c_kv.T + self.w_uk = self.add_weight( + shape=( + self.num_query_heads * self.qk_nope_head_dim, + self.kv_lora_rank, + ), + initializer=self.kernel_initializer, + name="w_uk", + ) + self.w_uv = self.add_weight( + shape=( + self.num_query_heads * self.v_head_dim, + self.kv_lora_rank, + ), + initializer=self.kernel_initializer, + name="w_uv", + ) + + self.q_down_proj.build(input_shape) + q_down_shape = list(input_shape[:-1]) + [self.q_lora_rank] + self.q_up_nope_proj.build(q_down_shape) + self.q_up_rope_proj.build(q_down_shape) + self.kv_down_proj.build(input_shape) + self.k_rope_proj.build(input_shape) + self.output_proj.build( + list(input_shape[:-1]) + + [self.num_key_value_heads * self.v_head_dim] + ) + super().build(input_shape) + + def _yarn_inv_freq(self, dtype): + """Return YaRN-scaled RoPE inverse frequencies and magnitude scale.""" + dim = self.qk_rope_head_dim + freqs = 1.0 / ( + self.rope_max_wavelength + ** (ops.arange(0, dim, 2, dtype="float32") / dim) + ) + + if self.rope_scaling_factor <= 1.0: + return ops.cast(freqs, dtype), self.yarn_mscale + + # Wavelength = 2π / freq. High-freq → small wavelength, low-freq → + # large wavelength. YaRN applies more scaling to low-freq dimensions. + + wavelengths = 2.0 * math.pi / freqs + old_ctx = float(self.yarn_original_max_position_embeddings) + beta_slow = float(self.yarn_beta_slow) + beta_fast = float(self.yarn_beta_fast) + + # ramp=0 → high-freq (no extra scaling), ramp=1 → low-freq (full scale) + ramp = ops.clip( + (wavelengths / old_ctx - beta_slow) / (beta_fast - beta_slow), + 0.0, + 1.0, + ) + scale = (1.0 - ramp) + ramp * self.rope_scaling_factor + scaled_freqs = freqs / scale + + mscale = ( + self.yarn_mscale_all_dim + if self.yarn_mscale_all_dim != 0.0 + else self.yarn_mscale + ) + return ops.cast(scaled_freqs, dtype), mscale + + def _apply_rope(self, x, start_index, dtype, inv_freq, mscale): + """Apply Rotary Position Embeddings to x of shape (B, H, S, D).""" + seq_len = ops.shape(x)[-2] + start = 0 if start_index is None else ops.cast(start_index, "float32") + positions = ops.arange(seq_len, dtype="float32") + start + freqs = ops.concatenate([ops.outer(positions, inv_freq)] * 2, axis=-1) + freqs = ops.expand_dims(ops.expand_dims(freqs, 0), 0) # (1,1,S,D) + cos = ops.cos(freqs) * ops.cast(mscale, "float32") + sin = ops.sin(freqs) * ops.cast(mscale, "float32") + x_fp32 = ops.cast(x, "float32") + half = self.qk_rope_head_dim // 2 + rotated = ops.concatenate( + [ + x_fp32[..., :half] * cos[..., :half] + - x_fp32[..., half:] * sin[..., half:], + x_fp32[..., :half] * sin[..., half:] + + x_fp32[..., half:] * cos[..., half:], + ], + axis=-1, + ) + return ops.cast(rotated, dtype) + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=0, + training=False, + ): + batch_size = ops.shape(hidden_states)[0] + seq_len = ops.shape(hidden_states)[1] + dtype = hidden_states.dtype + + # Query projections. 
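+        # The shared down-projection c_q is up-projected twice, once for the
+        # content (nope) heads and once for the RoPE heads (eq. 6-8).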
+ c_q = self.q_down_proj(hidden_states) + q_nope = ops.reshape( + self.q_up_nope_proj(c_q), + [batch_size, seq_len, self.num_query_heads, self.qk_nope_head_dim], + ) + q_nope = ops.transpose(q_nope, [0, 2, 1, 3]) # (B, H, S, D_nope) + + q_rope = ops.reshape( + self.q_up_rope_proj(c_q), + [batch_size, seq_len, self.num_query_heads, self.qk_rope_head_dim], + ) + q_rope = ops.transpose(q_rope, [0, 2, 1, 3]) # (B, H, S, D_rope) + + # KV projections. + c_kv = self.kv_down_proj(hidden_states) # (B, S, kv_lora_rank) + k_rope = self.k_rope_proj(hidden_states) # (B, S, D_rope) + k_rope = ops.expand_dims(k_rope, axis=1) # (B, 1, S, D_rope) + + # Apply RoPE to positional components. + inv_freq, mscale = self._yarn_inv_freq("float32") # compute once + q_rope = self._apply_rope( + q_rope, cache_update_index, dtype, inv_freq, mscale + ) + k_rope = self._apply_rope( + k_rope, cache_update_index, dtype, inv_freq, mscale + ) + + # KV cache: read full history and write current step. + if cache is not None: + c_kv_cache, k_rope_cache = cache + c_kv = ops.slice_update( + c_kv_cache, [0, cache_update_index, 0], c_kv + ) + k_rope_sq = ops.squeeze(k_rope, axis=1) + k_rope = ops.expand_dims( + ops.slice_update( + k_rope_cache, [0, cache_update_index, 0], k_rope_sq + ), + axis=1, + ) + new_cache = (c_kv, ops.squeeze(k_rope, axis=1)) + hist_len = ops.shape(c_kv_cache)[1] + else: + new_cache = None + hist_len = seq_len + + # Content attention scores via MLA absorption trick (eq. 10). + # q_nope @ w_uk.T projects queries into latent KV space, then scores + # against c_kv without materialising per-head K tensors. + w_uk = ops.reshape( + self.w_uk, + [self.num_query_heads, self.qk_nope_head_dim, self.kv_lora_rank], + ) + q_latent = ops.einsum("bhsd,hdk->bhsk", q_nope, w_uk) # (B,H,S,lora) + c_kv_t = ops.expand_dims( + ops.transpose(c_kv, [0, 2, 1]), axis=1 + ) # (B,1,lora,T) + score_content = ops.matmul(q_latent, c_kv_t) # (B,H,S,T) + + # RoPE attention scores. + k_rope_t = ops.transpose(k_rope, [0, 1, 3, 2]) # (B,1,D_rope,T) + score_rope = ops.matmul(q_rope, k_rope_t) # (B,H,S,T) + + scale = ops.cast( + 1.0 + / ops.sqrt( + ops.cast( + self.qk_nope_head_dim + self.qk_rope_head_dim, "float32" + ) + ), + dtype, + ) + scores = ( + ops.cast(score_content, dtype) + ops.cast(score_rope, dtype) + ) * scale + + # XLA-compatible causal mask using static shapes. + idx = ops.cast( + 0 if cache_update_index is None else cache_update_index, "int32" + ) + i_idx = ops.arange(seq_len, dtype="int32")[:, None] + idx + j_idx = ops.arange(hist_len, dtype="int32")[None, :] + causal = ops.reshape( + ops.cast(i_idx >= j_idx, dtype), [1, 1, seq_len, hist_len] + ) + + if attention_mask is not None: + pad = ops.cast(attention_mask, dtype) + if len(pad.shape) == 2: + pad = pad[:, None, None, :] + elif len(pad.shape) == 3: + pad = pad[:, None, :, :] + mask = ops.cast( + ops.logical_and( + ops.cast(causal, "bool"), ops.cast(pad, "bool") + ), + dtype, + ) + else: + mask = causal + + large_neg = ops.cast( + -3e4 if scores.dtype == "float16" else -1e9, scores.dtype + ) + scores = scores + (1.0 - mask) * large_neg + attn_weights = ops.softmax(scores, axis=-1) + attn_weights = self.dropout(attn_weights, training=training) + + # Value computation via W_UV absorption. 
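+        # Attention weights are applied to c_kv directly, so the context stays
+        # in the kv_lora_rank latent space; multiplying by W_UV afterwards
+        # recovers the per-head value outputs without materializing V.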
+ c_kv_exp = ops.expand_dims(c_kv, axis=1) # (B,1,T,lora) + ctx = ops.matmul(attn_weights, c_kv_exp) # (B,H,S,lora) + w_uv = ops.reshape( + self.w_uv, + [self.num_query_heads, self.kv_lora_rank, self.v_head_dim], + ) + attn_out = ops.einsum("bhsk,hvk->bhsv", ctx, w_uv) # (B,H,S,v_dim) + + attn_out = ops.reshape( + ops.transpose(attn_out, [0, 2, 1, 3]), + [batch_size, seq_len, self.num_key_value_heads * self.v_head_dim], + ) + output = self.output_proj(attn_out) + + if cache is not None: + return output, new_cache + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "q_lora_rank": self.q_lora_rank, + "kv_lora_rank": self.kv_lora_rank, + "qk_nope_head_dim": self.qk_nope_head_dim, + "qk_rope_head_dim": self.qk_rope_head_dim, + "v_head_dim": self.v_head_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "yarn_beta_fast": self.yarn_beta_fast, + "yarn_beta_slow": self.yarn_beta_slow, + "yarn_mscale": self.yarn_mscale, + "yarn_mscale_all_dim": self.yarn_mscale_all_dim, + "yarn_original_max_position_embeddings": ( + self.yarn_original_max_position_embeddings + ), + "attention_dropout": self.attention_dropout, + "kernel_initializer": self.kernel_initializer, + } + ) + return config diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_backbone.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_backbone.py new file mode 100644 index 0000000000..4d93cc49a1 --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_backbone.py @@ -0,0 +1,285 @@ +"""DeepSeek V31 backbone model.""" + +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.deepseek_v31.deepseek_v31_decoder_block import ( + DeepSeekV31DecoderBlock, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_decoder_block import ( + DeepSeekV31RMSNorm, +) + + +def _deepseek_v31_kernel_initializer(stddev=0.006): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export("keras_hub.models.DeepSeekV31Backbone") +class DeepSeekV31Backbone(Backbone): + """DeepSeek V31 core transformer backbone. + + Implements the full DeepSeek-V3 architecture as described in + arXiv:2412.19437. The model uses Multi-head Latent Attention (MLA) for + efficient KV caching, and a Mixture-of-Experts (MoE) feed-forward network + with sigmoid-based routing in all but the first few layers. + + The first `first_k_dense_replace` layers use a dense SwiGLU feed-forward + network; remaining layers use `DeepSeekV31MoE` with `num_routed_experts` + total experts and `num_experts_per_tok` activated per token. + + This backbone outputs the final sequence of hidden states with shape + `(batch_size, sequence_length, hidden_dim)`. + + Args: + vocabulary_size: int. Size of the token vocabulary. + num_layers: int. Number of transformer decoder layers. Defaults to + `61`. + hidden_dim: int. Dimensionality of hidden states. Defaults to `7168`. + num_query_heads: int. Number of query attention heads. Defaults to + `128`. + num_key_value_heads: int. Number of key/value heads (equal to query + heads for MLA). Defaults to `128`. + intermediate_dim: int. Inner dimensionality of FFN layers. Defaults to + `18432`. 
q_lora_rank: int. Query down-projection rank. Defaults to `1536`.
+        kv_lora_rank: int. KV latent rank. Controls the per-token KV cache
+            size. Defaults to `512`.
+        qk_nope_head_dim: int. Per-head content (non-RoPE) dimension. Defaults
+            to `128`.
+        qk_rope_head_dim: int. Per-head RoPE dimension. Defaults to `64`.
+        v_head_dim: int. Per-head value dimension. Defaults to `128`.
+        num_routed_experts: int. Total routed MoE experts per layer. Defaults
+            to `256`.
+        num_shared_experts: int. Always-active shared experts per MoE layer.
+            Defaults to `1`.
+        num_experts_per_tok: int. Number of routed experts activated per token.
+            Defaults to `8`.
+        first_k_dense_replace: int. Number of initial layers that use a dense
+            FFN instead of MoE. Defaults to `3`.
+        rope_max_wavelength: int. RoPE base wavelength. Defaults to `10000`.
+        rope_scaling_factor: float. YaRN context extension scale. Values
+            greater than 1 extend the effective context length. Defaults to
+            `1.0`.
+        yarn_original_max_position_embeddings: int. The context length used
+            during pre-training, used as the YaRN ramp reference. Defaults to
+            `4096`.
+        layer_norm_epsilon: float. Epsilon for RMSNorm layers. Defaults to
+            `1e-6`.
+        dropout: float. Dropout rate for attention and residual connections.
+            Defaults to `0.0`.
+
+    Example:
+
+    ```python
+    backbone = keras_hub.models.DeepSeekV31Backbone(
+        vocabulary_size=32000,
+        num_layers=4,
+        hidden_dim=512,
+        num_query_heads=8,
+        num_key_value_heads=8,
+        intermediate_dim=1024,
+        q_lora_rank=256,
+        kv_lora_rank=128,
+        qk_nope_head_dim=32,
+        qk_rope_head_dim=16,
+        v_head_dim=32,
+        num_routed_experts=8,
+        num_experts_per_tok=2,
+        first_k_dense_replace=1,
+    )
+    token_ids = keras.random.randint((2, 16), 0, 32000)
+    padding_mask = keras.ops.ones((2, 16), dtype="bool")
+    output = backbone({"token_ids": token_ids, "padding_mask": padding_mask})
+    # (2, 16, 512)
+    ```
+
+    Reference:
+    - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437)
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers=61,
+        hidden_dim=7168,
+        num_query_heads=128,
+        num_key_value_heads=128,
+        intermediate_dim=18432,
+        q_lora_rank=1536,
+        kv_lora_rank=512,
+        qk_nope_head_dim=128,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        num_routed_experts=256,
+        num_shared_experts=1,
+        num_experts_per_tok=8,
+        first_k_dense_replace=3,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        yarn_original_max_position_embeddings=4096,
+        layer_norm_epsilon=1e-6,
+        dropout=0.0,
+        **kwargs,
+    ):
+        dtype = kwargs.get("dtype")
+        # ===== Build sub-layers =====
+        token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            tie_weights=False,
+            embeddings_initializer=_deepseek_v31_kernel_initializer(
+                stddev=0.01
+            ),
+            name="token_embedding",
+            dtype=dtype,
+        )
+
+        transformer_layers = []
+        for i in range(num_layers):
+            transformer_layers.append(
+                DeepSeekV31DecoderBlock(
+                    hidden_dim=hidden_dim,
+                    num_query_heads=num_query_heads,
+                    num_key_value_heads=num_key_value_heads,
+                    intermediate_dim=intermediate_dim,
+                    q_lora_rank=q_lora_rank,
+                    kv_lora_rank=kv_lora_rank,
+                    qk_nope_head_dim=qk_nope_head_dim,
+                    qk_rope_head_dim=qk_rope_head_dim,
+                    v_head_dim=v_head_dim,
+                    num_routed_experts=num_routed_experts,
+                    num_shared_experts=num_shared_experts,
+                    num_experts_per_tok=num_experts_per_tok,
+                    use_moe=(i >= first_k_dense_replace),
+                    rope_max_wavelength=rope_max_wavelength,
+                    rope_scaling_factor=rope_scaling_factor,
+                    yarn_original_max_position_embeddings=yarn_original_max_position_embeddings,  # noqa: E501
+
layer_norm_epsilon=layer_norm_epsilon, + dropout=dropout, + kernel_initializer=_deepseek_v31_kernel_initializer( + stddev=0.02 + ), + name=f"transformer_layer_{i}", + dtype=dtype, + ) + ) + + layer_norm = DeepSeekV31RMSNorm( + epsilon=layer_norm_epsilon, + name="layer_norm", + dtype=dtype, + ) + + # ===== Functional model ===== + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="bool", name="padding_mask" + ) + + x = token_embedding(token_id_input) + for layer in transformer_layers: + x = layer(x, attention_mask=padding_mask_input) + sequence_output = layer_norm(x) + + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + **kwargs, + ) + + # ===== Store attributes (must be after super().__init__) ===== + self.token_embedding = token_embedding + self.transformer_layers = transformer_layers + self.layer_norm = layer_norm + + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_dim = intermediate_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.num_routed_experts = num_routed_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.yarn_original_max_position_embeddings = ( + yarn_original_max_position_embeddings + ) + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + def _build_cache(self, token_ids): + """Build an empty MLA KV cache for all transformer layers. + + Each layer's cache is a tuple `(c_kv, k_rope)` where: + - `c_kv` has shape `(batch, max_len, kv_lora_rank)` + - `k_rope` has shape `(batch, max_len, qk_rope_head_dim)` + + This is more memory-efficient than standard MHA caching, which would + store full K and V tensors of shape `(batch, heads, max_len, head_dim)`. 
+ """ + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + cache = [] + for _ in range(self.num_layers): + cache.append( + ( + ops.zeros( + [batch_size, max_length, self.kv_lora_rank], + dtype=self.compute_dtype, + ), + ops.zeros( + [batch_size, max_length, self.qk_rope_head_dim], + dtype=self.compute_dtype, + ), + ) + ) + return cache + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "hidden_dim": self.hidden_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "intermediate_dim": self.intermediate_dim, + "q_lora_rank": self.q_lora_rank, + "kv_lora_rank": self.kv_lora_rank, + "qk_nope_head_dim": self.qk_nope_head_dim, + "qk_rope_head_dim": self.qk_rope_head_dim, + "v_head_dim": self.v_head_dim, + "num_routed_experts": self.num_routed_experts, + "num_shared_experts": self.num_shared_experts, + "num_experts_per_tok": self.num_experts_per_tok, + "first_k_dense_replace": self.first_k_dense_replace, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "yarn_original_max_position_embeddings": ( + self.yarn_original_max_position_embeddings + ), + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_backbone_test.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_backbone_test.py new file mode 100644 index 0000000000..b35da860a5 --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_backbone_test.py @@ -0,0 +1,68 @@ +import pytest +from keras import ops + +from keras_hub.src.models.deepseek_v31.deepseek_v31_backbone import ( + DeepSeekV31Backbone, +) +from keras_hub.src.tests.test_case import TestCase + + +class DeepSeekV31BackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 1000, + "num_layers": 2, + "hidden_dim": 64, + "num_query_heads": 4, + "num_key_value_heads": 4, + "intermediate_dim": 128, + "q_lora_rank": 16, + "kv_lora_rank": 16, + "qk_nope_head_dim": 16, + "qk_rope_head_dim": 8, + "v_head_dim": 16, + "num_routed_experts": 4, + "num_shared_experts": 1, + "num_experts_per_tok": 2, + "first_k_dense_replace": 1, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="bool"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=DeepSeekV31Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 64), + ) + + def test_num_parameters(self): + model = DeepSeekV31Backbone(**self.init_kwargs) + self.assertGreater(model.count_params(), 0) + + def test_backbone_with_cache(self): + model = DeepSeekV31Backbone(**self.init_kwargs) + token_ids = ops.ones((2, 5), dtype="int32") + cache = model._build_cache(token_ids) + + self.assertIsInstance(cache, list) + self.assertEqual(len(cache), self.init_kwargs["num_layers"]) + for c_kv, k_rope in cache: + self.assertEqual( + c_kv.shape, (2, 5, self.init_kwargs["kv_lora_rank"]) + ) + self.assertEqual( + k_rope.shape, (2, 5, self.init_kwargs["qk_rope_head_dim"]) + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in DeepSeekV31Backbone.presets: + self.run_preset_test( + cls=DeepSeekV31Backbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm.py 
b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm.py new file mode 100644 index 0000000000..867b34e6c2 --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm.py @@ -0,0 +1,182 @@ +"""DeepSeek V31 Causal Language Model.""" + +import keras # noqa: F401 +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.deepseek_v31.deepseek_v31_backbone import ( + DeepSeekV31Backbone, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm_preprocessor import ( # noqa: E501 + DeepSeekV31CausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.DeepSeekV31CausalLM") +class DeepSeekV31CausalLM(CausalLM): + """DeepSeek V31 Causal Language Model. + + Pairs `DeepSeekV31Backbone` with a language model head for next-token + prediction. The LM head reuses the token embedding weights via + `ReversibleEmbedding` (weight tying is off by default for DeepSeek V31, + but the same layer is used for both embedding lookup and logit projection). + + Autoregressive generation uses an MLA-compatible KV cache. Rather than + storing full per-head K and V tensors at each step, only the compressed + latents `c_kv` (shape `[batch, max_len, kv_lora_rank]`) and decoupled RoPE + keys `k_rope` (shape `[batch, max_len, qk_rope_head_dim]`) are cached per + layer. This significantly reduces memory usage compared to standard MHA + caching. + + Args: + backbone: `DeepSeekV31Backbone` instance. The core transformer model. + preprocessor: Optional `DeepSeekV31CausalLMPreprocessor`. Used for + tokenizing inputs before passing to the model. + + Example: + + ```python + backbone = keras_hub.models.DeepSeekV31Backbone(...) + lm = keras_hub.models.DeepSeekV31CausalLM(backbone=backbone) + output = lm.generate("Once upon a time") + ``` + + Reference: + - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437) + """ + + backbone_cls = DeepSeekV31Backbone + preprocessor_cls = DeepSeekV31CausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + self.backbone = backbone + self.preprocessor = preprocessor + + def _build_cache(self, token_ids): + """Build an empty MLA KV cache for a given token_ids tensor. + + Returns a list of `(c_kv, k_rope)` tuples, one per transformer layer, + each pre-allocated to the full sequence length. The cache is written + to incrementally during autoregressive generation. + """ + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + cache = [] + for _ in range(self.backbone.num_layers): + cache.append( + ( + ops.zeros( + [batch_size, max_length, self.backbone.kv_lora_rank], + dtype=self.compute_dtype, + ), + ops.zeros( + [ + batch_size, + max_length, + self.backbone.qk_rope_head_dim, + ], + dtype=self.compute_dtype, + ), + ) + ) + return cache + + def call_with_cache(self, token_ids, cache, cache_update_index): + """Forward pass with explicit KV cache read/write. + + Threads the cache through each transformer layer by calling sub-layers + directly, bypassing the Keras functional model graph (which does not + support dynamic cache arguments). This is the standard KerasHub pattern + for cached autoregressive generation. 
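+
+        During generation, `generate_step` first calls this once over the
+        full prompt with `cache_update_index=0` to seed the cache, then once
+        per new token with a `(batch, 1)` input at the current position.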
+
+        Args:
+            token_ids: int32 tensor of shape `(batch, 1)` for single-step
+                decoding or `(batch, seq_len)` for prefill.
+            cache: list of `(c_kv, k_rope)` tuples from `_build_cache`.
+            cache_update_index: int. Position index into the cache to write the
+                current token's KV entries.
+
+        Returns:
+            A `(logits, hidden_states, new_cache)` tuple where `logits` has
+            shape `(batch, seq_len, vocabulary_size)`, `hidden_states` has
+            shape `(batch, seq_len, hidden_dim)`, and `new_cache` is the
+            updated list of `(c_kv, k_rope)` tuples.
+        """
+        x = self.backbone.token_embedding(token_ids)
+        new_cache = []
+        for i, layer in enumerate(self.backbone.transformer_layers):
+            x, updated_cache = layer(
+                x,
+                cache=cache[i],
+                cache_update_index=cache_update_index,
+            )
+            new_cache.append(updated_cache)
+
+        x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+        return logits, x, new_cache
+
+    def generate_step(self, inputs, stop_token_ids=None):
+        """XLA-compilable single-batch generation step.
+
+        Args:
+            inputs: dict with keys `"token_ids"` (int32 tensor) and
+                `"padding_mask"` (bool tensor).
+            stop_token_ids: tuple of int token ids. Generation stops when all
+                sequences have produced at least one stop token.
+
+        Returns:
+            Updated `inputs` dict with the same keys.
+        """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+        # Build an empty cache, then seed it with a single prefill pass over
+        # the full prompt so every prompt position has its `(c_kv, k_rope)`
+        # entry before sampling begins.
+        cache = self._build_cache(token_ids)
+        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
+
+        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+        index = ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                prompt, cache, cache_update_index
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        token_ids = self.sampler(
+            next=next,
+            prompt=token_ids,
+            cache=cache,
+            index=index,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        if stop_token_ids is not None:
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+
+        return {"token_ids": token_ids, "padding_mask": padding_mask}
diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..6bcb766256
--- /dev/null
+++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor.py
@@ -0,0 +1,62 @@
+"""DeepSeek V3.1 Causal LM Preprocessor."""
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.deepseek_v31.deepseek_v31_backbone import (
+    DeepSeekV31Backbone,
+)
+from keras_hub.src.models.deepseek_v31.deepseek_v31_tokenizer import (
+    DeepSeekV31Tokenizer,
+)
+
+
+@keras_hub_export("keras_hub.models.DeepSeekV31CausalLMPreprocessor")
+class DeepSeekV31CausalLMPreprocessor(CausalLMPreprocessor):
+    """DeepSeek V3.1 Causal LM preprocessor.
+ + This preprocessing layer is meant for use with + `keras_hub.models.DeepSeekV31CausalLM`. By default, it will take in + batches of strings, and return outputs in a `(x, y, sample_weight)` + format, where the `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this + preprocessor is attached to a `keras_hub.models.DeepSeekV31CausalLM` + instance, these methods will be called implicitly in `generate()`. + + Args: + tokenizer: A `keras_hub.models.DeepSeekV31Tokenizer` instance. + sequence_length: int. The length of the packed inputs. + add_start_token: bool. Whether to prepend the start token. + add_end_token: bool. Whether to append the end token. + + Example: + ```python + preprocessor = keras_hub.models.DeepSeekV31CausalLMPreprocessor.from_preset( + "deepseek_v31_base" + ) + + # Preprocess a batch of strings + sentences = ["Hello, world!", "How are you?"] + x, y, sample_weight = preprocessor(sentences) + ``` + """ + + backbone_cls = DeepSeekV31Backbone + tokenizer_cls = DeepSeekV31Tokenizer + + def __init__( + self, + tokenizer, + sequence_length=1024, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + add_start_token=add_start_token, + add_end_token=add_end_token, + **kwargs, + ) diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor_test.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..9b806a1bae --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor_test.py @@ -0,0 +1,84 @@ +import os # noqa: F401 + +import pytest + +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm_preprocessor import ( # noqa: E501 + DeepSeekV31CausalLMPreprocessor, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_tokenizer import ( + DeepSeekV31Tokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class DeepSeekV31CausalLMPreprocessorTest(TestCase): + def setUp(self): + # "Ġ" maps to 6, and " " maps to 7 to maintain a valid 1:1 mapping. + self.vocab = { + "<|begin▁of▁sentence|>": 151646, + "<|end▁of▁sentence|>": 151643, + "a": 2, + "b": 3, + "c": 4, + "d": 5, + "Ġ": 6, + " ": 7, + } + self.merges = [] + self.tokenizer = DeepSeekV31Tokenizer( + vocabulary=self.vocab, merges=self.merges + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = (["a b"],) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=DeepSeekV31CausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[151646, 2, 6, 3, 151643, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + [[2, 6, 3, 151643, 0, 0, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 0, 0, 0, 0]], # Pass through sample_weights. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["a b"] * 4 + + preprocessor = DeepSeekV31CausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[2, 6, 3, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + preprocessor = DeepSeekV31CausalLMPreprocessor(**self.init_kwargs) + preprocessed = preprocessor.generate_preprocess(["a b"]) + self.assertIn("token_ids", preprocessed) + self.assertIn("padding_mask", preprocessed) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=DeepSeekV31CausalLMPreprocessor, + preset="deepseek_v31_base", + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in DeepSeekV31CausalLMPreprocessor.presets: + self.run_preset_test( + cls=DeepSeekV31CausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_test.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_test.py new file mode 100644 index 0000000000..d8359ed1c1 --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_test.py @@ -0,0 +1,89 @@ +import os # noqa: F401 + +import pytest +from keras import ops # noqa: F401 + +from keras_hub.src.models.deepseek_v31.deepseek_v31_backbone import ( + DeepSeekV31Backbone, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm import ( + DeepSeekV31CausalLM, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_causal_lm_preprocessor import ( # noqa: E501 + DeepSeekV31CausalLMPreprocessor, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_tokenizer import ( + DeepSeekV31Tokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class DeepSeekV31CausalLMTest(TestCase): + def setUp(self): + # Explicit special tokens for default ID injection + # Use same vocab as preprocessor test but with larger overall vocabulary + self.vocab = { + "<|begin▁of▁sentence|>": 151646, + "<|end▁of▁sentence|>": 151643, + "a": 2, + "b": 3, + "c": 4, + "d": 5, + "Ġ": 6, + " ": 7, + } + self.merges = [] + self.tokenizer = DeepSeekV31Tokenizer( + vocabulary=self.vocab, merges=self.merges + ) + + self.preprocessor = DeepSeekV31CausalLMPreprocessor( + self.tokenizer, + sequence_length=8, + ) + + # Use large vocabulary size to match expected output shape + self.backbone = DeepSeekV31Backbone( + vocabulary_size=151650, + num_layers=2, + hidden_dim=32, + num_query_heads=4, + num_key_value_heads=4, + intermediate_dim=64, + q_lora_rank=16, + kv_lora_rank=16, + qk_nope_head_dim=16, + qk_rope_head_dim=8, + v_head_dim=16, + num_routed_experts=4, + num_experts_per_tok=2, + first_k_dense_replace=1, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["a b", "b a"],) + + def test_causal_lm_basics(self): + self.run_task_test( + cls=DeepSeekV31CausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 151650), + ) + + def test_generate(self): + causal_lm = DeepSeekV31CausalLM(**self.init_kwargs) + prompt = "a b" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in DeepSeekV31CausalLM.presets: + self.run_preset_test( + cls=DeepSeekV31CausalLM, + preset=preset, + input_data=self.train_data, + ) diff --git 
a/keras_hub/src/models/deepseek_v31/deepseek_v31_decoder_block.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_decoder_block.py new file mode 100644 index 0000000000..6de28e9cf9 --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_decoder_block.py @@ -0,0 +1,416 @@ +"""DeepSeek V31 transformer decoder block, RMSNorm, and dense FFN.""" + +import keras +from keras import ops + +from keras_hub.src.models.deepseek_v31.deepseek_v31_attention import ( + DeepSeekV31Attention, +) +from keras_hub.src.models.deepseek_v31.deepseek_v31_moe import DeepSeekV31MoE + + +class DeepSeekV31RMSNorm(keras.layers.Layer): + """Root Mean Square Layer Normalization for DeepSeek V31. + + Applies RMS normalization using float32 precision internally to avoid + numerical instability with fp16/bf16 training, then casts back to the + layer's compute dtype. This matches the reference DeepSeek implementation. + + Unlike `LayerNormalization`, RMSNorm does not subtract the mean, which + reduces computation while preserving re-scaling performance. + + Args: + epsilon: float. Small value added to the RMS denominator for numerical + stability. Defaults to `1e-6`. + + Example: + + ```python + norm = keras_hub.layers.DeepSeekV31RMSNorm(epsilon=1e-6) + x = keras.random.normal((2, 16, 512)) + output = norm(x) # (2, 16, 512) + ``` + + Reference: + - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437) + - [Zhang & Sennrich, 2019](https://arxiv.org/abs/1910.07467) + """ + + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + self.scale = self.add_weight( + name="scale", + shape=input_shape[-1:], + initializer="ones", + ) + super().build(input_shape) + + def call(self, x): + x_fp32 = ops.cast(x, "float32") + rms = ops.rsqrt( + ops.mean(ops.square(x_fp32), axis=-1, keepdims=True) + self.epsilon + ) + return ops.cast(x_fp32 * rms, self.compute_dtype) * ops.cast( + self.scale, self.compute_dtype + ) + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon}) + return config + + +class DeepSeekV31DenseFeedForward(keras.layers.Layer): + """Dense SwiGLU feed-forward network for DeepSeek V31. + + Used for the first `first_k_dense_replace` transformer layers before the + MoE layers begin. Implements the gated activation function: + + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + + Args: + hidden_dim: int. Input and output dimensionality. + intermediate_dim: int. Inner dimensionality of the gated projection. + kernel_initializer: string or initializer. Initializer for Dense kernel + weights. Defaults to `"glorot_uniform"`. 
+ + Example: + + ```python + ffn = keras_hub.layers.DeepSeekV31DenseFeedForward( + hidden_dim=512, + intermediate_dim=1024, + ) + x = keras.random.normal((2, 16, 512)) + output = ffn(x) # (2, 16, 512) + ``` + + Reference: + - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437) + """ + + def __init__( + self, + hidden_dim, + intermediate_dim, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.kernel_initializer = kernel_initializer + + self.gate_proj = keras.layers.Dense( + intermediate_dim, + activation="silu", + use_bias=False, + kernel_initializer=kernel_initializer, + name="gate_proj", + dtype=self.dtype_policy, + ) + self.up_proj = keras.layers.Dense( + intermediate_dim, + use_bias=False, + kernel_initializer=kernel_initializer, + name="up_proj", + dtype=self.dtype_policy, + ) + self.down_proj = keras.layers.Dense( + hidden_dim, + use_bias=False, + kernel_initializer=kernel_initializer, + name="down_proj", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + inner_shape = list(input_shape[:-1]) + [self.intermediate_dim] + self.down_proj.build(inner_shape) + super().build(input_shape) + + def call(self, x): + return self.down_proj(self.gate_proj(x) * self.up_proj(x)) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "kernel_initializer": self.kernel_initializer, + } + ) + return config + + +class DeepSeekV31DecoderBlock(keras.layers.Layer): + """Transformer decoder block for DeepSeek V31. + + Implements the pre-norm residual block structure from Figure 2 of the paper: + + x = x + Attention(RMSNorm(x)) + x = x + FFN(RMSNorm(x)) + + where FFN is either a dense SwiGLU network (for the first + `first_k_dense_replace` layers) or a `DeepSeekV31MoE` layer. + + The KV cache format matches the MLA architecture: each layer stores a tuple + `(c_kv, k_rope)` of compressed latents rather than full K/V tensors. + + Args: + hidden_dim: int. Dimensionality of hidden states. + num_query_heads: int. Number of query attention heads. + num_key_value_heads: int. Number of key/value heads. + intermediate_dim: int. Inner dimensionality of the FFN. + q_lora_rank: int. Query down-projection rank. + kv_lora_rank: int. KV latent rank (controls KV cache size per token). + qk_nope_head_dim: int. Per-head content (non-RoPE) dimension. + qk_rope_head_dim: int. Per-head RoPE dimension. + v_head_dim: int. Per-head value dimension. + num_routed_experts: int. Total routed experts in the MoE layer. + Defaults to `256`. + num_shared_experts: int. Always-active shared experts. Defaults to `1`. + num_experts_per_tok: int. Top-K experts activated per token. Defaults + to `8`. + use_moe: bool. If `True`, uses `DeepSeekV31MoE` as the FFN; otherwise + uses `DeepSeekV31DenseFeedForward`. Defaults to `True`. + rope_max_wavelength: int. RoPE base wavelength. Defaults to `10000`. + rope_scaling_factor: float. YaRN context extension factor. Defaults to + `1.0`. + yarn_original_max_position_embeddings: int. Pre-training context length + for YaRN ramp computation. Defaults to `4096`. + layer_norm_epsilon: float. RMSNorm epsilon. Defaults to `1e-6`. + dropout: float. Dropout rate for attention weights and residuals. + Defaults to `0.0`. + kernel_initializer: string or initializer. 
Initializer for all + sub-layer kernel weights. Defaults to `"glorot_uniform"`. + + Example: + + ```python + block = keras_hub.layers.DeepSeekV31DecoderBlock( + hidden_dim=512, + num_query_heads=8, + num_key_value_heads=8, + intermediate_dim=1024, + q_lora_rank=256, + kv_lora_rank=128, + qk_nope_head_dim=32, + qk_rope_head_dim=16, + v_head_dim=32, + num_routed_experts=8, + num_experts_per_tok=2, + use_moe=True, + ) + x = keras.random.normal((2, 16, 512)) + output = block(x) # (2, 16, 512) + ``` + + Reference: + - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437) + """ + + def __init__( + self, + hidden_dim, + num_query_heads, + num_key_value_heads, + intermediate_dim, + q_lora_rank, + kv_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, + num_routed_experts=256, + num_shared_experts=1, + num_experts_per_tok=8, + use_moe=True, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + yarn_original_max_position_embeddings=4096, + layer_norm_epsilon=1e-6, + dropout=0.0, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + + # Store every __init__ argument as + # an attribute (style guide requirement). + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_dim = intermediate_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.num_routed_experts = num_routed_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.use_moe = use_moe + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.yarn_original_max_position_embeddings = ( + yarn_original_max_position_embeddings + ) + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.kernel_initializer = kernel_initializer + + self.pre_attention_norm = DeepSeekV31RMSNorm( + epsilon=layer_norm_epsilon, + name="pre_attention_norm", + dtype=self.dtype_policy, + ) + self.attention = DeepSeekV31Attention( + hidden_dim=hidden_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + q_lora_rank=q_lora_rank, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + yarn_original_max_position_embeddings=yarn_original_max_position_embeddings, # noqa: E501 + attention_dropout=dropout, + kernel_initializer=kernel_initializer, + name="attention", + dtype=self.dtype_policy, + ) + self.pre_ffn_norm = DeepSeekV31RMSNorm( + epsilon=layer_norm_epsilon, + name="pre_ffn_norm", + dtype=self.dtype_policy, + ) + + if use_moe: + self.ffn = DeepSeekV31MoE( + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_routed_experts=num_routed_experts, + num_shared_experts=num_shared_experts, + num_experts_per_tok=num_experts_per_tok, + kernel_initializer=kernel_initializer, + name="ffn", + dtype=self.dtype_policy, + ) + else: + self.ffn = DeepSeekV31DenseFeedForward( + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + kernel_initializer=kernel_initializer, + name="ffn", + dtype=self.dtype_policy, + ) + + self.residual_dropout = keras.layers.Dropout( + dropout, + dtype=self.dtype_policy, + ) + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + 
cache_update_index=None, + training=False, + ): + # Pre-attention norm + MLA. + attn_out = self.attention( + self.pre_attention_norm(hidden_states), + attention_mask=attention_mask, + cache=cache, + cache_update_index=cache_update_index, + training=training, + ) + + if isinstance(attn_out, tuple): + attn_out, new_cache = attn_out + else: + new_cache = None + + hidden_states = hidden_states + self.residual_dropout( + attn_out, training=training + ) + + # Pre-FFN norm + FFN (MoE or Dense). + ffn_out = self.ffn(self.pre_ffn_norm(hidden_states), training=training) + hidden_states = hidden_states + self.residual_dropout( + ffn_out, training=training + ) + + if new_cache is not None: + return hidden_states, new_cache + return hidden_states + + def compute_output_spec( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + # When cache is provided the layer returns (hidden_states, new_cache). + # This override is required so Keras symbolic tracing handles the tuple + # output correctly rather than assuming a single tensor output. + try: + input_shape = hidden_states.shape + except Exception: + input_shape = None + + hidden_out = keras.KerasTensor(input_shape, dtype=self.compute_dtype) + + if cache is not None: + c_kv, k_rope = cache + new_cache_spec = ( + keras.KerasTensor( + getattr(c_kv, "shape", None), dtype=self.compute_dtype + ), + keras.KerasTensor( + getattr(k_rope, "shape", None), dtype=self.compute_dtype + ), + ) + return hidden_out, new_cache_spec + + return hidden_out + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "intermediate_dim": self.intermediate_dim, + "q_lora_rank": self.q_lora_rank, + "kv_lora_rank": self.kv_lora_rank, + "qk_nope_head_dim": self.qk_nope_head_dim, + "qk_rope_head_dim": self.qk_rope_head_dim, + "v_head_dim": self.v_head_dim, + "num_routed_experts": self.num_routed_experts, + "num_shared_experts": self.num_shared_experts, + "num_experts_per_tok": self.num_experts_per_tok, + "use_moe": self.use_moe, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "yarn_original_max_position_embeddings": ( + self.yarn_original_max_position_embeddings + ), + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "kernel_initializer": self.kernel_initializer, + } + ) + return config diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_moe.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_moe.py new file mode 100644 index 0000000000..6fa719f5e9 --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_moe.py @@ -0,0 +1,249 @@ +"""DeepSeek V31 Mixture-of-Experts layer.""" + +import keras +from keras import ops + + +class DeepSeekV31MoE(keras.layers.Layer): + """Mixture-of-Experts (MoE) layer for DeepSeek V31. + + Implements DeepSeekMoE routing as described in Section 2.1.2 of the paper. + Each token is routed to `num_experts_per_tok` out of `num_routed_experts` + routed experts, plus `num_shared_experts` always-active shared experts. 
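+
+    With the `DeepSeekV31Backbone` defaults (256 routed experts, 8 activated
+    per token, and 1 shared expert), each token passes through 9 expert FFNs
+    per MoE layer, which is why the activated parameter count is far smaller
+    than the total parameter count.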
+ + Routing uses sigmoid-based affinity scores (not softmax), and normalization + is applied only to the selected top-K scores: + + s_i = sigmoid(u^T e_i) + g_i = s_i / sum_{j in TopK} s_j for i in TopK + + Note on load balancing: The auxiliary-loss-free bias terms described in + Section 2.1.2 are a training-time mechanism and are not implemented here. + Inference uses standard top-K routing without bias correction. + + This implementation vectorizes the routed expert computation using batched + tensor operations (`ops.einsum`), which avoids graph bloat and ensures + compatibility with XLA compilation (`jit_compile=True`). + + Args: + hidden_dim: int. Dimensionality of input and output hidden states. + intermediate_dim: int. Inner dimensionality of each expert's SwiGLU FFN. + num_routed_experts: int. Total number of routed experts. Defaults to + `256`. + num_shared_experts: int. Number of always-active shared experts. + Defaults to `1`. + num_experts_per_tok: int. Number of routed experts activated per token + (top-K). Defaults to `8`. + kernel_initializer: string or initializer. Initializer for all Dense + kernel weights. Defaults to `"glorot_uniform"`. + + Example: + + ```python + moe = keras_hub.layers.DeepSeekV31MoE( + hidden_dim=512, + intermediate_dim=1024, + num_routed_experts=8, + num_shared_experts=1, + num_experts_per_tok=2, + ) + hidden = keras.random.normal((2, 16, 512)) + output = moe(hidden) # (2, 16, 512) + ``` + + Reference: + - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437) + """ + + def __init__( + self, + hidden_dim, + intermediate_dim, + num_routed_experts=256, + num_shared_experts=1, + num_experts_per_tok=8, + kernel_initializer="glorot_uniform", + epsilon=1e-9, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_routed_experts = num_routed_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.kernel_initializer = kernel_initializer + self.epsilon = epsilon + + def build(self, input_shape): + # Router: maps hidden states to per-expert affinity logits. + self.router = keras.layers.Dense( + self.num_routed_experts, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="router", + ) + self.router.build(input_shape) + + # Shared experts (SwiGLU) — always active, not gated. + if self.num_shared_experts > 0: + shared_dim = self.intermediate_dim * self.num_shared_experts + self.shared_gate_proj = keras.layers.Dense( + shared_dim, + activation="silu", + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="shared_gate_proj", + ) + self.shared_up_proj = keras.layers.Dense( + shared_dim, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="shared_up_proj", + ) + self.shared_down_proj = keras.layers.Dense( + self.hidden_dim, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="shared_down_proj", + ) + self.shared_gate_proj.build(input_shape) + self.shared_up_proj.build(input_shape) + shared_inner_shape = list(input_shape[:-1]) + [shared_dim] + self.shared_down_proj.build(shared_inner_shape) + + # Routed experts (SwiGLU) + # Stacked into batched tensors for vectorized XLA-compatible execution. 
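+        # Gate/up kernels have shape (E, H, I) and the down kernel (E, I, H),
+        # so a single einsum per projection contracts over all E experts at
+        # once instead of looping over them.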
+ self.expert_gate_kernel = self.add_weight( + shape=( + self.num_routed_experts, + self.hidden_dim, + self.intermediate_dim, + ), + initializer=self.kernel_initializer, + name="expert_gate_kernel", + ) + self.expert_up_kernel = self.add_weight( + shape=( + self.num_routed_experts, + self.hidden_dim, + self.intermediate_dim, + ), + initializer=self.kernel_initializer, + name="expert_up_kernel", + ) + self.expert_down_kernel = self.add_weight( + shape=( + self.num_routed_experts, + self.intermediate_dim, + self.hidden_dim, + ), + initializer=self.kernel_initializer, + name="expert_down_kernel", + ) + + super().build(input_shape) + + def call(self, hidden_states, training=False): + compute_dtype = hidden_states.dtype + + # Run router in float32 for numerical stability. + router_logits = ops.cast( + self.router(ops.cast(hidden_states, "float32")), "float32" + ) + + # Sigmoid affinity scores (Section 2.1.2, eq. 15). + # DeepSeek-V3 uses sigmoid instead of softmax used in V2. + affinity_scores = ops.sigmoid(router_logits) + + # Top-K selection. + top_k_scores, top_k_indices = ops.top_k( + affinity_scores, k=self.num_experts_per_tok + ) + + # Normalize only the selected K scores (eq. 13). + top_k_weights = top_k_scores / ( + ops.sum(top_k_scores, axis=-1, keepdims=True) + self.epsilon + ) + top_k_weights = ops.cast(top_k_weights, compute_dtype) + + output = ops.zeros_like(hidden_states) + + # Shared expert contribution (always active). + if self.num_shared_experts > 0: + shared_out = self.shared_down_proj( + self.shared_gate_proj(hidden_states) + * self.shared_up_proj(hidden_states) + ) + output = output + ops.cast(shared_out, compute_dtype) + + # =================================================================== + # Vectorized Routed Expert Contributions + # =================================================================== + + # 1. Create a dense routing mask of shape (..., num_routed_experts) + # one_hot shape: (..., K, E) + mask = ops.one_hot(top_k_indices, self.num_routed_experts) + + # Multiply by weights: expand top_k_weights to (..., K, 1) to broadcast + weights_expanded = ops.expand_dims(top_k_weights, axis=-1) + mask_weighted = mask * ops.cast(weights_expanded, mask.dtype) + + # Sum over K to get the final per-expert routing weights: shape (..., E) + router_mask = ops.sum(mask_weighted, axis=-2) + router_mask = ops.cast(router_mask, compute_dtype) + + # 2. Compute Gate and Up projections for all experts simultaneously + # hidden_states: (..., H) + # expert_kernels: (E, H, I) + # einsum naturally broadcasts over + # the missing E dimension to compute (..., E, I) + gate_out = ops.einsum( + "...h,ehi->...ei", hidden_states, self.expert_gate_kernel + ) + up_out = ops.einsum( + "...h,ehi->...ei", hidden_states, self.expert_up_kernel + ) + + # 3. Apply SwiGLU activation and the routing mask + expert_act = ops.silu(gate_out) * up_out + + # Expand router_mask to (..., E, 1) for broadcasting over I + router_mask_expanded = ops.expand_dims(router_mask, axis=-1) + + # Zero-out inactive experts and + # scale active ones by their affinity scores + expert_act_weighted = expert_act * router_mask_expanded + + # 4. 
Compute Down projection and sum over experts simultaneously + # expert_act_weighted: (..., E, I) + # expert_down_kernel: (E, I, H) + # This einsum performs the matmul and + # sums over the E dimension in one step + # Output shape: (..., H) + routed_out = ops.einsum( + "...ei,eih->...h", expert_act_weighted, self.expert_down_kernel + ) + + output = output + routed_out + + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_routed_experts": self.num_routed_experts, + "num_shared_experts": self.num_shared_experts, + "num_experts_per_tok": self.num_experts_per_tok, + "kernel_initializer": self.kernel_initializer, + "epsilon": self.epsilon, + } + ) + return config diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_presets.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_presets.py new file mode 100644 index 0000000000..a10b5482b0 --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_presets.py @@ -0,0 +1,69 @@ +"""DeepSeek V3 model preset configurations.""" + +# Metadata for loading pretrained model weights and configurations. +backbone_presets = { + "deepseek_v31_base": { + "metadata": { + "description": ( + "671 billion parameter, 61-layer, base DeepSeek V3 model. " + "MoE architecture with 256 routed experts (8 per token). " + "37B activated parameters." + ), + "params": 671000000000, + "path": "deepseek_v31", + "model_type": "MoE", + "tokenizer": "DeepSeekV31Tokenizer", + }, + "kaggle_handle": "kaggle://deepseek-ai/deepseek-v3/base/1", + }, + "deepseek_v31": { + "metadata": { + "description": ( + "671 billion parameter, 61-layer, instruction-tuned " + "DeepSeek V3 model. MoE architecture with 256 routed " + "experts (8 per token). 37B activated parameters." 
+ ), + "params": 671000000000, + "path": "deepseek_v31", + "model_type": "MoE", + "tokenizer": "DeepSeekV31Tokenizer", + }, + "kaggle_handle": "kaggle://deepseek-ai/deepseek-v3/instruct/1", + }, +} + +# Tokenizer presets +tokenizer_presets = { + "deepseek_v31_base": { + "metadata": { + "description": "DeepSeek V3 tokenizer.", + "path": "deepseek_v31", + }, + "kaggle_handle": "kaggle://deepseek-ai/deepseek-v3/tokenizer/1", + }, + "deepseek_v31": { + "metadata": { + "description": "DeepSeek V3 tokenizer.", + "path": "deepseek_v31", + }, + "kaggle_handle": "kaggle://deepseek-ai/deepseek-v3/tokenizer/1", + }, +} + +# Preprocessor presets +preprocessor_presets = { + "deepseek_v31_base": { + "metadata": { + "description": "DeepSeek V3 preprocessor.", + "path": "deepseek_v31", + }, + "kaggle_handle": "kaggle://deepseek-ai/deepseek-v3/preprocessor/1", + }, + "deepseek_v31": { + "metadata": { + "description": "DeepSeek V3 preprocessor.", + "path": "deepseek_v31", + }, + "kaggle_handle": "kaggle://deepseek-ai/deepseek-v3/preprocessor/1", + }, +} diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer.py new file mode 100644 index 0000000000..671452817d --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer.py @@ -0,0 +1,106 @@ +"""DeepSeek V31 tokenizer.""" + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.deepseek_v31.deepseek_v31_backbone import ( + DeepSeekV31Backbone, +) +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export("keras_hub.tokenizers.DeepSeekV31Tokenizer") +class DeepSeekV31Tokenizer(BytePairTokenizer): + """Tokenizer for DeepSeek V31 models. + + This tokenizer uses Byte-Pair Encoding (BPE) with the DeepSeek V31 + vocabulary (~128K tokens). It adds special tokens for sequence boundary + marking: + + - `bos_token` (`<|begin▁of▁sentence|>`, id 151646): prepended to every + sequence during generation. + - `eos_token` (`<|end▁of▁sentence|>`, id 151643): used as the generation + stop token. + - `pad_token_id` (0): used for padding batched inputs. + + Args: + vocabulary: dict or str. Token-to-id mapping as a Python dict, or a + path to a JSON vocabulary file. Mutually exclusive with `proto`. + merges: list or str. BPE merge rules as a list of strings `"a b"`, or + a path to a merges file. Mutually exclusive with `proto`. + proto: str. Path to a SentencePiece `.model` file. When provided, + `vocabulary` and `merges` will be extracted automatically. + + Example: + + ```python + tokenizer = keras_hub.tokenizers.DeepSeekV31Tokenizer.from_preset( + "deepseek_v31_base" + ) + tokenizer.tokenize("Hello, world!") + # [13225, 11, 1879, 0] + tokenizer.detokenize([[13225, 11, 1879, 0]]) + # ["Hello, world!"] + ``` + + Reference: + - [DeepSeek-AI et al., 2024](https://arxiv.org/abs/2412.19437) + """ + + backbone_cls = DeepSeekV31Backbone + + def __init__( + self, + vocabulary=None, + merges=None, + proto=None, + **kwargs, + ): + # Handle SentencePiece proto: extract vocab and merges before calling + # super().__init__, since BytePairTokenizer needs them at construction. + if proto is not None: + from keras.src.saving import serialization_lib + + if isinstance(proto, str) and serialization_lib.in_safe_mode(): + raise ValueError( + "Requested the loading of a SentencePiece proto file " + "outside of the model archive. 
This carries a " + "potential risk of loading arbitrary and sensitive files" + " and thus it is disallowed " + "by default. If you trust the source of the artifact, you " + "can override this error by passing `safe_mode=False` to " + "the loading function, or calling " + "`keras.config.enable_unsafe_deserialization()`." + ) + try: + import sentencepiece as spm + + sp = spm.SentencePieceProcessor() + sp.Load(proto) + vocabulary = { + sp.IdToPiece(i): i for i in range(sp.GetPieceSize()) + } + merges = [] + except ImportError: + raise ImportError( + "Loading a SentencePiece proto requires the `sentencepiece`" + " package. Install it with `pip install sentencepiece`." + ) + + # BytePairTokenizer requires at least one merge rule to initialise its + # internal StaticHashTable. Inject a harmless placeholder when the + # merge list is empty (e.g. when loading from a SentencePiece proto). + if ( + isinstance(merges, list) + and len(merges) == 0 + and vocabulary is not None + ): + merges = ["a b"] + + super().__init__(vocabulary=vocabulary, merges=merges, **kwargs) + + self._add_special_token("<|begin▁of▁sentence|>", "bos_token") + self._add_special_token("<|end▁of▁sentence|>", "eos_token") + + self.start_token = "<|begin▁of▁sentence|>" + self.start_token_id = 151646 + self.end_token_id = 151643 + self.pad_token_id = 0 diff --git a/keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer_test.py b/keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer_test.py new file mode 100644 index 0000000000..dd89b494dc --- /dev/null +++ b/keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer_test.py @@ -0,0 +1,58 @@ +import os # noqa: F401 + +import pytest + +from keras_hub.src.models.deepseek_v31.deepseek_v31_tokenizer import ( + DeepSeekV31Tokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class DeepSeekV31TokenizerTest(TestCase): + def setUp(self): + self.vocab = { + "<|begin▁of▁sentence|>": 151646, + "<|end▁of▁sentence|>": 151643, + } + for i, c in enumerate( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ Ġ" + ): + self.vocab[c] = i + 2 + + # Register the fully formed BPE chunks that get merged + # to prevent dropping them during detokenize + self.vocab["th"] = 100 + self.vocab["ea"] = 101 + + self.merges = ["t h", "e a"] + + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + } + self.input_data = ["the quick brown fox", "the earth is round"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=DeepSeekV31Tokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_special_tokens(self): + tokenizer = DeepSeekV31Tokenizer(**self.init_kwargs) + self.assertEqual(tokenizer.start_token_id, 151646) + self.assertEqual(tokenizer.end_token_id, 151643) + + def test_tokenizer_vocab_size(self): + tokenizer = DeepSeekV31Tokenizer(**self.init_kwargs) + self.assertGreater(tokenizer.vocabulary_size(), 0) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in DeepSeekV31Tokenizer.presets: + self.run_preset_test( + cls=DeepSeekV31Tokenizer, + preset=preset, + input_data=["the quick brown fox"], + )
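As a reading aid for the MoE diff above, here is a minimal, self-contained sketch of the routing and vectorized expert path that `DeepSeekV31MoE.call` implements: sigmoid affinities, top-K selection, normalization over only the selected scores, a dense one-hot routing mask, and the two `einsum` contractions. The tensor sizes are illustrative toys rather than the production 256-expert configuration, and all weights are random instead of trained.

```python
import keras
from keras import ops

# Illustrative sizes only (not the 256-expert / top-8 production config).
batch, seq, hidden = 2, 4, 8
num_experts, top_k, intermediate = 4, 2, 16

x = keras.random.normal((batch, seq, hidden))
router_kernel = keras.random.normal((hidden, num_experts))

# Sigmoid affinity scores, computed in float32 for stability.
logits = ops.matmul(ops.cast(x, "float32"), router_kernel)
scores = ops.sigmoid(logits)

# Top-K selection, then normalize only the selected scores.
top_scores, top_indices = ops.top_k(scores, k=top_k)
gates = top_scores / (ops.sum(top_scores, axis=-1, keepdims=True) + 1e-9)

# Dense per-expert routing weights: (batch, seq, num_experts).
one_hot = ops.one_hot(top_indices, num_experts)  # (batch, seq, K, E)
router_mask = ops.sum(one_hot * ops.expand_dims(gates, -1), axis=-2)

# Vectorized SwiGLU experts, gated by the routing mask.
gate_k = keras.random.normal((num_experts, hidden, intermediate))
up_k = keras.random.normal((num_experts, hidden, intermediate))
down_k = keras.random.normal((num_experts, intermediate, hidden))

gate_out = ops.einsum("bsh,ehi->bsei", x, gate_k)
up_out = ops.einsum("bsh,ehi->bsei", x, up_k)
expert_act = ops.silu(gate_out) * up_out
expert_act = expert_act * ops.expand_dims(router_mask, -1)

# Down-project and sum over experts in a single contraction: (batch, seq, hidden).
routed_out = ops.einsum("bsei,eih->bsh", expert_act, down_k)
print(routed_out.shape)  # (2, 4, 8)
```

Note that because the routing mask is applied after the gate/up projections, every expert's projection runs for every token; that is the trade-off the layer's comments point to, exchanging extra FLOPs for a static, `jit_compile`-friendly graph with no data-dependent gather/scatter.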
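Likewise, a small usage sketch for the tokenizer that mirrors the toy vocabulary in `deepseek_v31_tokenizer_test.py`. The character-level vocabulary and two merge rules are test fixtures rather than the real ~128K-token preset vocabulary, so the resulting ids are meaningless outside this example; only the special-token wiring matches the class above.

```python
from keras_hub.src.models.deepseek_v31.deepseek_v31_tokenizer import (
    DeepSeekV31Tokenizer,
)

# Toy vocabulary mirroring the unit test's setUp().
vocab = {"<|begin▁of▁sentence|>": 151646, "<|end▁of▁sentence|>": 151643}
for i, c in enumerate("abcdefghijklmnopqrstuvwxyz Ġ"):
    vocab[c] = i + 2
vocab.update({"th": 100, "ea": 101})
merges = ["t h", "e a"]

tokenizer = DeepSeekV31Tokenizer(vocabulary=vocab, merges=merges)
ids = tokenizer("the earth")
print(tokenizer.start_token_id, tokenizer.end_token_id)  # 151646 151643
print(tokenizer.detokenize(ids))  # reconstructs "the earth"
```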