Commit 0f31481

Merge branch 'main' into refactor-llm
2 parents 5fffd8c + 7de581c commit 0f31481

15 files changed: +457 -401 lines changed

README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@
 🔍 Explore our models on
 [![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤗%20Huggingface)](https://huggingface.co/xtuner)
 [![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤖%20ModelScope)](https://www.modelscope.cn/organization/xtuner)
+[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🧰%20OpenXLab)](https://openxlab.org.cn/usercenter/xtuner)
 
 English | [简体中文](README_zh-CN.md)
 
```

requirements/runtime.txt

Lines changed: 2 additions & 4 deletions
```diff
@@ -18,8 +18,6 @@ tiktoken
 # limit pytorch version <= 2.1.2 as there may be some bugs in triton 2.2
 torch<=2.1.2
 torchvision<=0.16.2
-# Minimum 4.34.0 to support added_tokens_decoder of tokenizer
-# Exclude 4.34.1, 4.35.0, 4.35.1, 4.35.2 to avoid BC-break,
-# see https://github.com/huggingface/transformers/pull/27020, https://github.com/huggingface/transformers/pull/27073
-transformers>=4.34.0,!=4.34.1,!=4.35.0,!=4.35.1,!=4.35.2
+# Minimum 4.36.0 to support `Cache` data structure used by KV Cache
+transformers>=4.36.0
 transformers_stream_generator
```

xtuner/engine/hooks/evaluate_chat_hook.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -29,7 +29,8 @@ def __init__(self,
                  every_n_iters=None,
                  max_new_tokens=600,
                  stop_word=None,
-                 stop_words=[]):
+                 stop_words=[],
+                 generation_kwargs={}):
         self.evaluation_inputs = evaluation_inputs
         if isinstance(self.evaluation_inputs, str):
             self.evaluation_inputs = [self.evaluation_inputs]
@@ -69,8 +70,9 @@ def __init__(self,
         if image_processor is not None:
             self.image_processor = BUILDER.build(image_processor)
         self.stop_criteria = StoppingCriteriaList()
+
         # default generation config
-        self.gen_config = GenerationConfig(
+        default_generation_kwargs = dict(
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=0.1,
@@ -79,8 +81,10 @@ def __init__(self,
             eos_token_id=self.tokenizer.eos_token_id,
             pad_token_id=self.tokenizer.pad_token_id
             if self.tokenizer.pad_token_id is not None else
-            self.tokenizer.eos_token_id,
-        )
+            self.tokenizer.eos_token_id)
+        default_generation_kwargs.update(generation_kwargs)
+        self.gen_config = GenerationConfig(**default_generation_kwargs)
+
         self.stop_criteria = StoppingCriteriaList()
         for word in stop_words:
             self.stop_criteria.append(
```
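For readers of the hook change above: the new `generation_kwargs` argument is merged over the hook's built-in defaults before the `GenerationConfig` is constructed, so user-supplied keys win. A minimal standalone sketch of that merge pattern (the override values are illustrative, not taken from the commit):

```python
from transformers import GenerationConfig

# Defaults mirroring the hook's built-in generation settings.
default_generation_kwargs = dict(
    max_new_tokens=600,
    do_sample=True,
    temperature=0.1,
)

# Hypothetical user override passed via the new `generation_kwargs` argument.
generation_kwargs = {'temperature': 0.7, 'top_p': 0.9}

# dict.update gives the user-supplied values priority over the defaults.
default_generation_kwargs.update(generation_kwargs)
gen_config = GenerationConfig(**default_generation_kwargs)
print(gen_config.temperature, gen_config.top_p)  # 0.7 0.9
```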

xtuner/engine/hooks/throughput_hook.py

Lines changed: 82 additions & 20 deletions
```diff
@@ -1,11 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from typing import Optional, Union
 
 import torch
+from mmengine import print_log
 from mmengine.hooks import Hook
 from mmengine.model.wrappers import is_model_wrapper
 from torch.utils._pytree import tree_flatten
 
+from xtuner.parallel.sequence import get_sequence_parallel_world_size
+
 DATA_BATCH = Optional[Union[dict, tuple, list]]
 
 
@@ -20,12 +24,39 @@ def __init__(self,
                  hidden_size=None,
                  num_layers=None,
                  vocab_size=None,
-                 mlp_ratio=None):
+                 mlp_ratio=None,
+                 is_casual=None):
         self.use_activation_checkpointing = use_activation_checkpointing
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.vocab_size = vocab_size
         self.mlp_ratio = mlp_ratio
+        self.is_casual = is_casual
+
+    @staticmethod
+    def _guess_is_casual_attn(model):
+        for module in model.modules():
+            if hasattr(module, 'is_causal'):
+                return module.is_causal
+        print_log(
+            'It\'s impossible to speculate whether casual attention was used, '
+            'and FLOPs will be calculated as `casual = True`.', 'current')
+        return True
+
+    @staticmethod
+    def _get_batch_size_and_sequence_len(data_batch):
+        data_list, _ = tree_flatten(data_batch)
+        for data in data_list:
+            if isinstance(data, torch.Tensor):
+                return data.size(0), data.size(1)
+        raise RuntimeError('No tensor found in the batch')
+
+    @staticmethod
+    def _guess_use_activation_checkpointing(model):
+        for module in model.modules():
+            if hasattr(module, 'gradient_checkpointing'):
+                return module.gradient_checkpointing
+        return False
 
     def before_run(self, runner) -> None:
         if is_model_wrapper(runner.model):
@@ -41,20 +72,18 @@ def before_run(self, runner) -> None:
         self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
                                             model.config.hidden_size)
         self.mlp_ratio *= 1.5  # has gate_proj
-        return
+        self.is_casual = self.is_casual if self.is_casual is not None \
+            else self._guess_is_casual_attn(model)
 
-    def _get_batch_size_and_sequence_len(self, data_batch):
-        data_list, _ = tree_flatten(data_batch)
-        for data in data_list:
-            if isinstance(data, torch.Tensor):
-                return data.size(0), data.size(1)
-        raise RuntimeError('No tensor found in the batch')
+        use_varlen_attn = getattr(model, 'use_varlen_attn', False)
+        if use_varlen_attn:
+            print_log(
+                'Using variable-length Flash Attention causes an inflation'
+                ' in the FLOPs calculation.',
+                'current',
+                level=logging.WARNING)
 
-    def _guess_use_activation_checkpointing(self, model):
-        for module in model.modules():
-            if hasattr(module, 'gradient_checkpointing'):
-                return module.gradient_checkpointing
-        return False
+        return
 
     def after_train_iter(self,
                          runner,
@@ -66,17 +95,50 @@ def after_train_iter(self,
 
         batch_size, sequence_len = self._get_batch_size_and_sequence_len(
             data_batch)
+        sequence_parallel_size = get_sequence_parallel_world_size()
 
         message_hub = runner.message_hub
         iter_time = message_hub.get_scalar('train/time').current()
 
-        flops_per_iteration = (
-            (3 + int(self.use_activation_checkpointing)) *
-            ((8 + self.mlp_ratio * 4) * batch_size * sequence_len *
-             self.hidden_size**2 +
-             4 * batch_size * sequence_len**2 * self.hidden_size)
-        ) * self.num_layers + \
-            6 * batch_size * sequence_len * self.hidden_size * self.vocab_size
+        # We consider a language model with 𝑙 transformer layers,
+        # hidden size h, sequence length s, vocabulary size V, and
+        # training batch size B.
+        # A $A_{mxk}$ x $X_{kxn}$ matrix multiplication requires 2𝑚 ×𝑘 ×𝑛 FLOPs
+        # (factor of 2 needed to account for multiplies and adds).
+
+        # Attention Layer:
+        # qkv_proj + o_proj: 8B * s * h^2
+        # attn: 2B * s^2 * h (casual=False) and 2B * s^2 * h / 2 (casual=True)
+
+        # MLP Layer:
+        # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
+        # (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5
+        # (has gate_proj))
+
+        # The backward pass requires double the number of FLOPs since we
+        # need to calculate the gradients with respect to both input and
+        # weight tensors. In addition, we are using activation recomputation,
+        # which requires an additional forward pass before the backward pass.
+
+        # While sequence parallel will affect the FLOPs calculation in attn.
+        # Suppose the sequence length in one GPU is s and the sequence
+        # parallel world size is `sp_size`, which means the total
+        # sequence length in the attention calculation is
+        # `s * sp_size` and the number of attention heads decrease to
+        # `num_heads / sp_size`. Hence, the FLOPs in attn calculation is:
+        # 2B * (s * sp_size)^2 * (h / sp_size) (casual=False) and
+        # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True)
+
+        flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
+        flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
+            sequence_parallel_size / (int(self.is_casual) + 1)
+        flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
+            self.hidden_size**2
+        flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
+            flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
+        flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
+            self.vocab_size
+        flops_per_iteration = flops_wo_head + flops_head
 
         avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
         tokens_per_sec_per_gpu = batch_size * sequence_len / (
```
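The comments in the new `after_train_iter` spell out the FLOPs model; the same arithmetic can be written as a standalone helper. Below is a minimal sketch under the commit's assumptions (causal attention halves the attention term, sequence parallelism multiplies it by the parallel size, and `3 * 2 * B * s * h * V` covers forward plus backward for the lm_head); the function name and example shapes are illustrative, not part of the commit:

```python
def estimate_flops_per_iteration(batch_size, seq_len, hidden_size, num_layers,
                                 vocab_size, mlp_ratio, is_causal=True,
                                 sp_size=1, activation_checkpointing=True):
    """Rough per-iteration FLOPs following the breakdown in the diff above."""
    flops_qkvo_proj = 8 * batch_size * seq_len * hidden_size**2
    flops_attn = (4 * batch_size * seq_len**2 * hidden_size * sp_size
                  / (int(is_causal) + 1))
    flops_mlp = 4 * mlp_ratio * batch_size * seq_len * hidden_size**2
    # 3x for forward + backward; +1 forward when activations are recomputed.
    flops_wo_head = ((3 + int(activation_checkpointing))
                     * (flops_qkvo_proj + flops_attn + flops_mlp) * num_layers)
    flops_head = 3 * 2 * batch_size * seq_len * hidden_size * vocab_size
    return flops_wo_head + flops_head


# Illustrative Llama-7B-like shapes; mlp_ratio = intermediate / hidden * 1.5.
tflops = estimate_flops_per_iteration(
    batch_size=1, seq_len=2048, hidden_size=4096, num_layers=32,
    vocab_size=32000, mlp_ratio=11008 / 4096 * 1.5) / 1e12
print(f'{tflops:.1f} TFLOPs per iteration')
```
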
Lines changed: 10 additions & 23 deletions
```diff
@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Mapping, Optional, Sequence, Union
+from typing import Optional, Union
 
-import torch
 import torch.distributed as dist
 from mmengine import MessageHub
 from mmengine.hooks import Hook
@@ -11,20 +10,6 @@
 
 class VarlenAttnArgsToMessageHubHook(Hook):
 
-    args = ('cumulative_len', 'max_seqlen')
-
-    def cast_data(self, data):
-        if isinstance(data, Mapping):
-            return {key: self.cast_data(data[key]) for key in data}
-        elif isinstance(data, (str, bytes)) or data is None:
-            return data
-        elif isinstance(data, Sequence):
-            return type(data)(self.cast_data(sample) for sample in data)  # type: ignore # noqa: E501 # yapf:disable
-        elif isinstance(data, torch.Tensor):
-            return data.cuda()
-        else:
-            return data
-
     def before_train_iter(self,
                           runner,
                           batch_idx: int,
@@ -35,10 +20,13 @@ def before_train_iter(self,
         assert 'data' in data_batch.keys()
         data = data_batch['data']
 
-        for arg in self.args:
-            assert arg in data
-            message_hub.update_info(f'{arg}_rank_{rank}',
-                                    self.cast_data(data.pop(arg)))
+        cumulative_len = data.pop('cumulative_len')
+        assert len(cumulative_len) == 1
+        cumulative_len = cumulative_len[0].cuda()
+        message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
+
+        max_seqlen = data.pop('max_seqlen')
+        message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)
 
     def after_train_iter(self,
                          runner,
@@ -47,6 +35,5 @@ def after_train_iter(self,
                          outputs: Optional[dict] = None) -> None:
         rank = dist.get_rank()
         message_hub = MessageHub.get_instance('varlen_attn_args')
-        message_hub.update_info(f'cumulative_len_rank_{rank}', None)
-        message_hub.update_info(f'max_seqlen_rank_{rank}', None)
```
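The rewritten hook now moves `cumulative_len` to CUDA eagerly and publishes both values under per-rank keys on an mmengine `MessageHub`, where a dispatched attention forward can read them back and reset them after the iteration. A single-process sketch of that producer/consumer round trip (rank fixed to 0 and toy values in place of real batch tensors):

```python
from mmengine import MessageHub

rank = 0  # the real hook uses dist.get_rank()
message_hub = MessageHub.get_instance('varlen_attn_args')

# Producer side (before_train_iter): publish per-rank varlen-attention args.
message_hub.update_info(f'cumulative_len_rank_{rank}', [0, 5, 9])  # toy values
message_hub.update_info(f'max_seqlen_rank_{rank}', 5)

# Consumer side (e.g. a dispatched attention forward) reads the same keys.
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
print(cumulative_len, max_seqlen)

# Cleanup (after_train_iter): reset the keys so stale args never leak.
message_hub.update_info(f'cumulative_len_rank_{rank}', None)
message_hub.update_info(f'max_seqlen_rank_{rank}', None)
```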

xtuner/model/llava.py

Lines changed: 65 additions & 1 deletion
```diff
@@ -1,13 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import math
 from collections import OrderedDict
 
+import torch
 import torch.nn as nn
 from mmengine.config import Config, ConfigDict
 from mmengine.model import BaseModel
 from peft import get_peft_model, prepare_model_for_kbit_training
+from transformers import AutoConfig
 
 from xtuner.registry import BUILDER
 from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
+from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
 from .utils import (LoadWoInit, find_all_linear_names,
                     get_peft_model_state_dict, guess_load_checkpoint,
                     make_inputs_require_grad,
@@ -26,11 +30,15 @@ def __init__(self,
                  projector_depth=2,
                  llm_lora=None,
                  visual_encoder_lora=None,
-                 use_activation_checkpointing=True):
+                 use_activation_checkpointing=True,
+                 max_position_embeddings=None):
         super().__init__()
         self.freeze_llm = freeze_llm
         self.freeze_visual_encoder = freeze_visual_encoder
         with LoadWoInit():
+            if isinstance(llm, dict):
+                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
+
             self.llm = self._build_from_cfg_or_module(llm)
             self.visual_encoder = self._build_from_cfg_or_module(
                 visual_encoder)
@@ -157,6 +165,62 @@ def state_dict(self, *args, **kwargs):
                        for k, v in state_dict.items() if 'projector.' in k})
         return to_return
 
+    @staticmethod
+    def _prepare_for_long_context_training(cfg, llm_cfg,
+                                           max_position_embeddings):
+
+        orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
+        if orig_rope_scaling is None:
+            orig_rope_scaling = {'factor': 1}
+
+        orig_rope_scaling_factor = orig_rope_scaling[
+            'factor'] if 'factor' in orig_rope_scaling.keys() else 1
+        orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
+        if orig_ctx_len:
+            orig_ctx_len *= orig_rope_scaling_factor
+            if max_position_embeddings > orig_ctx_len:
+                scaling_factor = float(
+                    math.ceil(max_position_embeddings / orig_ctx_len))
+                llm_cfg.rope_scaling = {
+                    'type': 'linear',
+                    'factor': scaling_factor
+                }
+
+        # hardcode for internlm2
+        llm_cfg.attn_implementation = 'flash_attention_2'
+        cfg.config = llm_cfg
+
+        return cfg, llm_cfg
+
+    @staticmethod
+    def _prepare_for_flash_attn(cfg, llm_cfg):
+        cls_name = type(llm_cfg).__name__
+        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
+                             'MixtralConfig', 'Qwen2Config',
+                             'Starcoder2Config', 'Starcoder2Config')
+        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
+                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
+                               'Starcoder2Config', 'Starcoder2Config')
+
+        if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
+            cfg.torch_dtype = torch.bfloat16 \
+                if torch.cuda.is_bf16_supported() else torch.float16
+            cfg.attn_implementation = 'flash_attention_2'
+        elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
+            cfg.attn_implementation = 'sdpa'
+
+        return cfg, llm_cfg
+
+    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
+        pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
+        llm_cfg = AutoConfig.from_pretrained(
+            pretrained_model_name_or_path, trust_remote_code=True)
+        cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
+        if max_position_embeddings is not None:
+            cfg, llm_cfg = self._prepare_for_long_context_training(
+                cfg, llm_cfg, max_position_embeddings)
+        return cfg
+
     def _build_from_cfg_or_module(self, cfg_or_mod):
         if isinstance(cfg_or_mod, nn.Module):
             return cfg_or_mod
```
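The new `_prepare_for_long_context_training` picks a linear RoPE scaling factor by rounding up the ratio of the requested context length to the model's effective context (its original `max_position_embeddings` times any pre-existing rope factor). A small worked example of that rule with illustrative numbers:

```python
import math

orig_ctx_len = 4096              # model's max_position_embeddings
orig_rope_scaling_factor = 1     # no pre-existing rope scaling
max_position_embeddings = 10000  # desired training context length

effective_ctx_len = orig_ctx_len * orig_rope_scaling_factor
if max_position_embeddings > effective_ctx_len:
    scaling_factor = float(
        math.ceil(max_position_embeddings / effective_ctx_len))
    rope_scaling = {'type': 'linear', 'factor': scaling_factor}
    print(rope_scaling)  # {'type': 'linear', 'factor': 3.0}
```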

xtuner/model/modules/dispatch/__init__.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -13,7 +13,7 @@
 from .yi import yi_attn_forward
 
 IS_LOW_VERSION_TRANSFORMERS = digit_version(
-    transformers.__version__) < digit_version('4.36')
+    transformers.__version__) < digit_version('4.38')
 SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.0.0')
 SUPPORT_FLASH2 = False
 
@@ -48,7 +48,7 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
     if use_varlen_attn:
         assert SUPPORT_FLASH2 and SUPPORT_TRITON, \
             'flash_attn and triton is required if you want to use varlen_attn.'
-    elif not SUPPORT_FLASH:
+    elif not SUPPORT_FLASH2:
         return
 
     from .llama import (llama_attn_forward, llama_attn_forward_legacy,
@@ -57,8 +57,10 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
 
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
-        if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2',
-                                     'LlamaSdpaAttention'):
+        # Do not need to dispatch if
+        # type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is
+        # required when using sequence parallel
+        if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'):
             if use_varlen_attn:
                 print_log('dispatch llama varlen attn forward', 'current')
                 if IS_LOW_VERSION_TRANSFORMERS:
```
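The threshold bump from 4.36 to 4.38 relies on mmengine's `digit_version` for the comparison; a minimal sketch of the same version gates in isolation (checks only, no attention dispatching):

```python
import torch
import transformers
from mmengine.utils import digit_version

# Same gates as the hunk above: the legacy attention forward is selected for
# transformers < 4.38, and SDPA/flash paths require torch >= 2.0.
IS_LOW_VERSION_TRANSFORMERS = digit_version(
    transformers.__version__) < digit_version('4.38')
SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.0.0')
print(IS_LOW_VERSION_TRANSFORMERS, SUPPORT_FLASH1)
```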
