(WIP) DeepseekV3 (and Multi-Head Latent Attention) #2012

Draft: wants to merge 25 commits into base: main

Commits (25)
249f722
v2
simoneangarano Feb 24, 2025
8350ac2
added link to results
simoneangarano Feb 24, 2025
2a599da
updated README_MLA
simoneangarano Feb 25, 2025
9ab7ed8
Updated README_MLA.md
simoneangarano Feb 25, 2025
7ce13ff
Update README_MLA.md
simoneangarano Feb 25, 2025
f46a2b1
add more comments and visual representation
simoneangarano Mar 10, 2025
c3eef3f
Merge branch 'main' of https://github.com/simoneangarano/litgpt
simoneangarano Mar 10, 2025
98579a0
Merge branch 'main' into main
Borda Mar 11, 2025
a7fb896
Merge branch 'main' into main
Borda Mar 12, 2025
8b030ec
Merge branch 'main' into main
Borda Apr 3, 2025
48fb11d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2025
5dc3985
Merge branch 'main' into main
Borda Apr 3, 2025
07b0538
typo
Borda Apr 3, 2025
27d3d40
Merge branch 'main' into main
Borda Apr 7, 2025
18af658
MLA: modified to support specifying custom values for q_lora_rank, v_…
ysjprojects Apr 13, 2025
6cf4282
clean up
ysjprojects Apr 13, 2025
15727c6
clean up
ysjprojects Apr 13, 2025
47bd94e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2025
ebeb67f
Merge branch 'main' into pr-feature-mla
ysjprojects Apr 13, 2025
43187c2
major change ref
ysjprojects Apr 23, 2025
56e62ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2025
9f35b9c
Merge branch 'main' into pr-feature-mla
ysjprojects May 15, 2025
ba55cf1
feat: deepseekv3 architecture
May 16, 2025
7e3ea78
deepseekv3
May 16, 2025
dcc89de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2025
52 changes: 52 additions & 0 deletions README_MLA.md
@@ -0,0 +1,52 @@
# Multi-Head Latent Attention (MLA)

## Overview
This document outlines the modifications made to the `litgpt` codebase to add support for the Multi-Head Latent Attention (MLA) block from [DeepSeekV2](https://arxiv.org/abs/2405.04434).

## Changes Made
1. **Configuration**: Added a `latent_attention: Optional[bool] = False` field to the model configuration (`litgpt.Config`) to enable the MLA block (see the sketch after this list).
2. **MLA module**: Implemented the MLA module as a separate component in the `litgpt` codebase.
3. **KVCacheCompressed**: Added support for the `KVCacheCompressed` class to store the key-value pairs for the MLA block.
4. **Model**: Modified the GPT model to include the **MLA block** as an alternative component based on the configuration parameter `latent_attention`.
5. **Training**: Updated the training script to support the MLA block and added the new configuration file `config_hub/pretrain/cfg.yaml` for training with it.
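
A minimal sketch of enabling the new flag programmatically (assumption: `Config.from_name` accepts the new field as a keyword override, as it does for existing `Config` fields; see the `litgpt/config.py` diff below):

```python
from litgpt.config import Config

# Hedged sketch: turn on the MLA block for a Pythia config via the new flag.
# `latent_attention` is the boolean field added by this PR to `litgpt.Config`;
# passing it as a keyword override is assumed to behave like any other field.
config = Config.from_name("pythia-160m", latent_attention=True)
print(config.latent_attention)  # True -> the GPT model builds the MLA block as its attention component
```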

## Installation
Follow the updated installation instructions in the `README.md` file.

## Usage
1. **Configuration**: Set the `latent_attention` parameter to `True` in the configuration file to enable the MLA block.
2. **Training**: Run the training script with the updated configuration file.
```bash
litgpt pretrain --config config_hub/pretrain/cfg.yaml
```
3. **Inference**: Use the trained model for inference as follows:
```bash
litgpt generate out/pretrain/mla/final/
```

## Results
Results are available at [this link](https://docs.google.com/spreadsheets/d/1-VnTDoK5JuNPGMjory_z1hQkI7y-RgiTpTsUpa3bVEg/edit?usp=sharing).

The results show that MQA and GQA considerably reduce memory usage and speed up training, but at the cost of a significant drop in performance compared to the baseline model.

The MLA block offers a better trade-off between memory usage, speed, and performance: it shows only a slight drop in performance compared to the baseline model while reducing memory usage and slightly increasing training and inference speed. Smaller projection dimensions were also tested for the MLA block, consistently reducing memory usage but with a significant drop in performance.

Overall, the gains are less pronounced than expected due to the small scale of the model (limited by GPU memory) and the short training run (~10k steps). Experiments on larger models, bigger datasets, and longer training runs are expected to highlight the benefits of the MLA block more clearly, and further experiments with layer normalization and other hyperparameters are expected to improve its performance.

## Notes
- Pythia was used as the model family for the experiments because it is available at many different scales.
- `pythia-160m` (160M parameters) was the largest model that could be trained on a single GPU with 16GB memory.
- For the same reason, the `tinystories` dataset was used for the experiments and the models were trained for only 100M tokens (~10k steps).
- Experiments on larger models, bigger datasets, and longer training are expected to further highlight the benefits of the MLA block.
- All the tested implementations use FlashAttention (as implemented in torch) by default.
- The resulting implementation of MLA depends on the `litgpt` codebase (especially the `CausalSelfAttention` class).
- The implementation of the MLA block is based on the DeepSeekV2 paper and includes support for KV caching (`KVCacheCompressed`) and decoupled RoPE (`apply_rope_mla`).
- A further improvement would be to optimize the implementation for speed and memory usage (for example, by merging matrices at inference like in LoRA).
> Fortunately, due to the associative law of matrix multiplication, we can absorb $W^{UK}$ into $W^{UQ}$, and $W^{UV}$ into $W^{O}$. Therefore, we do not need to compute keys and values out for each query. Through this optimization, we avoid the computational overhead for recomputing $k^C_t$ and $v^C_t$ during inference.

Unfortunately, this was not implemented due to time constraints; a minimal numerical sketch of the idea is shown below.
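
As a small numerical sketch of the absorption idea (not part of this PR; the weight names and dimensions below are illustrative assumptions):

```python
import torch

torch.manual_seed(0)

# Illustrative shapes: per-head up-projections from a compressed latent.
d_latent, d_head = 16, 8
W_UQ = torch.randn(d_head, d_latent, dtype=torch.float64)  # up-projects the query latent
W_UK = torch.randn(d_head, d_latent, dtype=torch.float64)  # up-projects the key latent
c_q = torch.randn(d_latent, dtype=torch.float64)           # compressed query latent
c_kv = torch.randn(d_latent, dtype=torch.float64)          # compressed (cached) key/value latent

# Naive path: materialize q and k, then take their dot product.
score_naive = (W_UQ @ c_q) @ (W_UK @ c_kv)

# Absorbed path: fold W_UK into W_UQ once, so keys never need to be recomputed.
W_absorbed = W_UQ.T @ W_UK  # (d_latent, d_latent)
score_absorbed = c_q @ (W_absorbed @ c_kv)

assert torch.allclose(score_naive, score_absorbed)
```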

## Visual Representation
The visual representation of the MLA block with my implementation notes is as follows:

![MLA Block](./mla.png)
131 changes: 131 additions & 0 deletions config_hub/pretrain/cfg.yaml
@@ -0,0 +1,131 @@
# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: pythia-160m

# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_name``. (type: Optional[Config], default: null)
model_config:
name: pythia-160m
hf_config:
org: EleutherAI
name: pythia-160m
block_size: 2048
n_layer: 12
n_embd: 768
n_head: 12
padding_multiple: 128
norm_class_name: LayerNorm
norm_qk: false

# Whether to use latent attention (MLA). (type: bool, default: false)
latent_attention: true
# Whether to use MQA (n_query_groups = 1), GQA (1 < n_query_groups < n_head), or MHA (n_query_groups = n_head).
# Not compatible with latent_attention.
n_query_groups: 12

# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
out_dir: out/pretrain/mla

# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed

# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:

# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false

# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: TinyStories

# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 1000

# Number of iterations between logging calls (type: int, default: 1)
log_interval: 100

# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)
global_batch_size: 128

# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4

# Number of iterations with learning rate warmup active (type: int, default: 2000)
lr_warmup_steps: 100

# Number of epochs to train on (type: Optional[int], default: null)
epochs:

# Total number of tokens to train on (type: Optional[int], default: 3000000000000)
max_tokens: 100000000

# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:

# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length:

# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
tie_embeddings:

# (type: Optional[float], default: 1.0)
max_norm: 1.0

# (type: float, default: 4e-05)
min_lr: 6e-5

# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 1000)
interval: 1000

# Number of tokens to generate (type: Optional[int], default: null)
max_new_tokens:

# Number of iterations (type: int, default: 100)
max_iters: 100

# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false

# Whether to evaluate on the validation set at the end of the training
final_validation: true

# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW

init_args:
# (type: float, default: 0.001)
lr: 6e-4

# (type: float, default: 0.01)
weight_decay: 0.1

# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95

# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto

# How many nodes to use. (type: int, default: 1)
num_nodes: 1

# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# modules require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/EleutherAI/pythia-160m

# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
logger_name: tensorboard

# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
13 changes: 12 additions & 1 deletion litgpt/config.py
@@ -23,6 +23,15 @@ def find_multiple(n: int, k: int) -> int:
return n + k - (n % k)


@dataclass
class MLAConfig:
q_proj_dim: Optional[int] = None
kv_proj_dim: Optional[int] = None
qk_rope_dim: Optional[int] = None
qk_nope_dim: Optional[int] = None
v_dim: Optional[int] = None


@dataclass
class Config:
name: str = ""
@@ -45,6 +54,8 @@ class Config:
# Transformer block (self-attention)
n_head: int = 32
head_size: Optional[int] = None
latent_attention: Optional[bool] = False
mla: MLAConfig = field(default_factory=MLAConfig)
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
@@ -82,7 +93,7 @@ class Config:
# Transformer block (MLP)
intermediate_size: Optional[int] = None
bias: bool = True
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE", "DeepseekV3MoE"] = "GptNeoxMLP"
gelu_approximate: str = "none"
n_expert: int = 0
n_expert_per_token: int = 0
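
Based on the `MLAConfig` fields shown in this diff, a hedged sketch of specifying custom projection sizes (the concrete numbers are illustrative assumptions for a `pythia-160m`-sized model, not values used in the PR's experiments):

```python
from litgpt.config import Config, MLAConfig

# The meaning of each field is inferred from the DeepSeekV2 MLA design and
# may differ from this PR's implementation; the values are placeholders.
mla = MLAConfig(
    q_proj_dim=384,   # assumed rank of the query down-projection
    kv_proj_dim=256,  # assumed rank of the shared key/value down-projection
    qk_rope_dim=32,   # assumed query/key dims carrying decoupled RoPE
    qk_nope_dim=32,   # assumed query/key dims without positional encoding
    v_dim=64,         # assumed value dimension
)
config = Config.from_name("pythia-160m", latent_attention=True, mla=mla)
```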