@@ -1,11 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from typing import Optional, Union
 
 import torch
+from mmengine import print_log
 from mmengine.hooks import Hook
 from mmengine.model.wrappers import is_model_wrapper
 from torch.utils._pytree import tree_flatten
 
+from xtuner.parallel.sequence import get_sequence_parallel_world_size
+
 DATA_BATCH = Optional[Union[dict, tuple, list]]
 
 
@@ -20,12 +24,39 @@ def __init__(self,
                  hidden_size=None,
                  num_layers=None,
                  vocab_size=None,
-                 mlp_ratio=None):
+                 mlp_ratio=None,
+                 is_casual=None):
         self.use_activation_checkpointing = use_activation_checkpointing
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.vocab_size = vocab_size
         self.mlp_ratio = mlp_ratio
+        self.is_casual = is_casual
+
+    @staticmethod
+    def _guess_is_casual_attn(model):
+        for module in model.modules():
+            if hasattr(module, 'is_causal'):
+                return module.is_causal
+        print_log(
+            'It\'s impossible to speculate whether casual attention was used, '
+            'and FLOPs will be calculated as `casual = True`.', 'current')
+        return True
+
+    @staticmethod
+    def _get_batch_size_and_sequence_len(data_batch):
+        data_list, _ = tree_flatten(data_batch)
+        for data in data_list:
+            if isinstance(data, torch.Tensor):
+                return data.size(0), data.size(1)
+        raise RuntimeError('No tensor found in the batch')
+
+    @staticmethod
+    def _guess_use_activation_checkpointing(model):
+        for module in model.modules():
+            if hasattr(module, 'gradient_checkpointing'):
+                return module.gradient_checkpointing
+        return False
 
     def before_run(self, runner) -> None:
         if is_model_wrapper(runner.model):
@@ -41,20 +72,18 @@ def before_run(self, runner) -> None:
         self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
                                             model.config.hidden_size)
         self.mlp_ratio *= 1.5  # has gate_proj
-        return
+        self.is_casual = self.is_casual if self.is_casual is not None \
+            else self._guess_is_casual_attn(model)
 
-    def _get_batch_size_and_sequence_len(self, data_batch):
-        data_list, _ = tree_flatten(data_batch)
-        for data in data_list:
-            if isinstance(data, torch.Tensor):
-                return data.size(0), data.size(1)
-        raise RuntimeError('No tensor found in the batch')
+        use_varlen_attn = getattr(model, 'use_varlen_attn', False)
+        if use_varlen_attn:
+            print_log(
+                'Using variable-length Flash Attention causes an inflation'
+                ' in the FLOPs calculation.',
+                'current',
+                level=logging.WARNING)
 
-    def _guess_use_activation_checkpointing(self, model):
-        for module in model.modules():
-            if hasattr(module, 'gradient_checkpointing'):
-                return module.gradient_checkpointing
-        return False
+        return
 
     def after_train_iter(self,
                          runner,
@@ -66,17 +95,50 @@ def after_train_iter(self,
 
         batch_size, sequence_len = self._get_batch_size_and_sequence_len(
             data_batch)
+        sequence_parallel_size = get_sequence_parallel_world_size()
 
         message_hub = runner.message_hub
         iter_time = message_hub.get_scalar('train/time').current()
 
-        flops_per_iteration = (
-            (3 + int(self.use_activation_checkpointing)) *
-            ((8 + self.mlp_ratio * 4) * batch_size * sequence_len *
-             self.hidden_size**2 +
-             4 * batch_size * sequence_len**2 * self.hidden_size)
-        ) * self.num_layers + \
-            6 * batch_size * sequence_len * self.hidden_size * self.vocab_size
+        # We consider a language model with l transformer layers,
+        # hidden size h, sequence length s, vocabulary size V, and
+        # training batch size B.
+        # An A_{m x k} x X_{k x n} matrix multiplication requires
+        # 2m x k x n FLOPs (factor of 2 needed to account for
+        # multiplies and adds).
+
+        # Attention Layer:
+        # qkv_proj + o_proj: 8B * s * h^2
+        # attn: 2B * s^2 * h (casual=False) and 2B * s^2 * h / 2 (casual=True)
+
+        # MLP Layer:
+        # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
+        # (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5
+        # (has gate_proj))
+
+        # The backward pass requires double the number of FLOPs since we
+        # need to calculate the gradients with respect to both input and
+        # weight tensors. In addition, we are using activation recomputation,
+        # which requires an additional forward pass before the backward pass.
+
+        # Sequence parallel will affect the FLOPs calculation in attn.
+        # Suppose the sequence length on one GPU is s and the sequence
+        # parallel world size is `sp_size`, which means the total
+        # sequence length in the attention calculation is
+        # `s * sp_size` and the number of attention heads decreases to
+        # `num_heads / sp_size`. Hence, the FLOPs in the attn calculation is:
+        # 2B * (s * sp_size)^2 * (h / sp_size) (casual=False) and
+        # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True)
+
+        flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
+        flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
+            sequence_parallel_size / (int(self.is_casual) + 1)
+        flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
+            self.hidden_size**2
+        flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
+            flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
+        flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
+            self.vocab_size
+        flops_per_iteration = flops_wo_head + flops_head
 
         avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
         tokens_per_sec_per_gpu = batch_size * sequence_len / (
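For reference, below is a minimal standalone sketch of the throughput formula this diff introduces, so the terms can be sanity-checked outside the hook. All config values are assumptions for illustration (roughly Llama-2-7B-like sizes, a 2k per-GPU sequence, no sequence parallel, a 1 s iteration time); none of them come from this commit.

```python
# Standalone sketch of the FLOPs-per-iteration formula from the diff above.
# Every value below is an assumption chosen for illustration.

batch_size = 1                        # B, per-GPU micro-batch size (assumed)
sequence_len = 2048                   # s, per-GPU sequence length (assumed)
hidden_size = 4096                    # h (assumed)
num_layers = 32                       # l (assumed)
vocab_size = 32000                    # V (assumed)
mlp_ratio = 11008 / 4096 * 1.5        # intermediate / hidden * 1.5 (gate_proj)
sequence_parallel_size = 1            # sp_size (assumed: no sequence parallel)
is_casual = True                      # causal attention halves the attn term
use_activation_checkpointing = True   # adds one extra forward pass

# Same terms as in the hook.
flops_qkvo_proj = 8 * batch_size * sequence_len * hidden_size**2
flops_attn = 4 * batch_size * sequence_len**2 * hidden_size * \
    sequence_parallel_size / (int(is_casual) + 1)
flops_mlp = 4 * mlp_ratio * batch_size * sequence_len * hidden_size**2
flops_wo_head = (3 + int(use_activation_checkpointing)) * (
    flops_qkvo_proj + flops_attn + flops_mlp) * num_layers
flops_head = 3 * 2 * batch_size * sequence_len * hidden_size * vocab_size
flops_per_iteration = flops_wo_head + flops_head

iter_time = 1.0                       # seconds per iteration (assumed)
avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
print(f'{avg_tflops_per_gpu:.1f} TFLOPs/s per GPU')
```

Note how the two new knobs enter the formula: `is_casual=True` halves the quadratic attention term, and `sequence_parallel_size` scales the per-GPU attention FLOPs back up to account for the full `s * sp_size` sequence attended across ranks, matching the comment block in the diff.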