Commit 520ce99

[Fix] Fix throughput hook (#527)
fix throughput hook
1 parent 1dd5cbd commit 520ce99

File tree

1 file changed: +82 -20 lines changed


xtuner/engine/hooks/throughput_hook.py

Lines changed: 82 additions & 20 deletions
@@ -1,11 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from typing import Optional, Union
 
 import torch
+from mmengine import print_log
 from mmengine.hooks import Hook
 from mmengine.model.wrappers import is_model_wrapper
 from torch.utils._pytree import tree_flatten
 
+from xtuner.parallel.sequence import get_sequence_parallel_world_size
+
 DATA_BATCH = Optional[Union[dict, tuple, list]]
 
 
@@ -20,12 +24,39 @@ def __init__(self,
                  hidden_size=None,
                  num_layers=None,
                  vocab_size=None,
-                 mlp_ratio=None):
+                 mlp_ratio=None,
+                 is_casual=None):
         self.use_activation_checkpointing = use_activation_checkpointing
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.vocab_size = vocab_size
         self.mlp_ratio = mlp_ratio
+        self.is_casual = is_casual
+
+    @staticmethod
+    def _guess_is_casual_attn(model):
+        for module in model.modules():
+            if hasattr(module, 'is_causal'):
+                return module.is_causal
+        print_log(
+            'It\'s impossible to speculate whether casual attention was used, '
+            'and FLOPs will be calculated as `casual = True`.', 'current')
+        return True
+
+    @staticmethod
+    def _get_batch_size_and_sequence_len(data_batch):
+        data_list, _ = tree_flatten(data_batch)
+        for data in data_list:
+            if isinstance(data, torch.Tensor):
+                return data.size(0), data.size(1)
+        raise RuntimeError('No tensor found in the batch')
+
+    @staticmethod
+    def _guess_use_activation_checkpointing(model):
+        for module in model.modules():
+            if hasattr(module, 'gradient_checkpointing'):
+                return module.gradient_checkpointing
+        return False
 
     def before_run(self, runner) -> None:
         if is_model_wrapper(runner.model):
@@ -41,20 +72,18 @@ def before_run(self, runner) -> None:
         self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
                                             model.config.hidden_size)
         self.mlp_ratio *= 1.5  # has gate_proj
-        return
+        self.is_casual = self.is_casual if self.is_casual is not None \
+            else self._guess_is_casual_attn(model)
 
-    def _get_batch_size_and_sequence_len(self, data_batch):
-        data_list, _ = tree_flatten(data_batch)
-        for data in data_list:
-            if isinstance(data, torch.Tensor):
-                return data.size(0), data.size(1)
-        raise RuntimeError('No tensor found in the batch')
+        use_varlen_attn = getattr(model, 'use_varlen_attn', False)
+        if use_varlen_attn:
+            print_log(
+                'Using variable-length Flash Attention causes an inflation'
+                ' in the FLOPs calculation.',
+                'current',
+                level=logging.WARNING)
 
-    def _guess_use_activation_checkpointing(self, model):
-        for module in model.modules():
-            if hasattr(module, 'gradient_checkpointing'):
-                return module.gradient_checkpointing
-        return False
+        return
 
     def after_train_iter(self,
                          runner,
@@ -66,17 +95,50 @@ def after_train_iter(self,
 
         batch_size, sequence_len = self._get_batch_size_and_sequence_len(
             data_batch)
+        sequence_parallel_size = get_sequence_parallel_world_size()
 
         message_hub = runner.message_hub
         iter_time = message_hub.get_scalar('train/time').current()
 
-        flops_per_iteration = (
-            (3 + int(self.use_activation_checkpointing)) *
-            ((8 + self.mlp_ratio * 4) * batch_size * sequence_len *
-             self.hidden_size**2 +
-             4 * batch_size * sequence_len**2 * self.hidden_size)
-        ) * self.num_layers + \
-            6 * batch_size * sequence_len * self.hidden_size * self.vocab_size
+        # We consider a language model with 𝑙 transformer layers,
+        # hidden size h, sequence length s, vocabulary size V, and
+        # training batch size B.
+        # A $A_{mxk}$ x $X_{kxn}$ matrix multiplication requires 2𝑚 ×𝑘 ×𝑛 FLOPs
+        # (factor of 2 needed to account for multiplies and adds).
+
+        # Attention Layer:
+        # qkv_proj + o_proj: 8B * s * h^2
+        # attn: 2B * s^2 * h (casual=False) and 2B * s^2 * h / 2 (casual=True)
+
+        # MLP Layer:
+        # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
+        # (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5
+        # (has gate_proj))
+
+        # The backward pass requires double the number of FLOPs since we
+        # need to calculate the gradients with respect to both input and
+        # weight tensors. In addition, we are using activation recomputation,
+        # which requires an additional forward pass before the backward pass.
+
+        # While sequence parallel will affect the FLOPs calculation in attn.
+        # Suppose the sequence length in one GPU is s and the sequence
+        # parallel world size is `sp_size`, which means the total
+        # sequence length in the attention calculation is
+        # `s * sp_size` and the number of attention heads decrease to
+        # `num_heads / sp_size`. Hence, the FLOPs in attn calculation is:
+        # 2B * (s * sp_size)^2 * (h / sp_size) (casual=False) and
+        # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True)
+
+        flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
+        flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
+            sequence_parallel_size / (int(self.is_casual) + 1)
+        flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
+            self.hidden_size**2
+        flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
+            flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
+        flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
+            self.vocab_size
+        flops_per_iteration = flops_wo_head + flops_head
 
         avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
         tokens_per_sec_per_gpu = batch_size * sequence_len / (
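
To make the new accounting concrete, the sketch below re-derives the per-iteration FLOPs from the comments in this patch for an assumed Llama-2-7B-style configuration (hidden_size=4096, num_layers=32, vocab_size=32000, intermediate_size=11008, batch size 1, sequence length 2048, causal attention, activation checkpointing on, no sequence parallelism). The helper name `estimate_flops_per_iteration` and every number here are illustrative assumptions, not part of the commit or the xtuner API.

# Illustrative sketch only: mirrors the FLOPs estimate added in this commit,
# outside the hook, so the factors are easy to check by hand.


def estimate_flops_per_iteration(batch_size,
                                 sequence_len,
                                 hidden_size,
                                 num_layers,
                                 vocab_size,
                                 mlp_ratio,
                                 is_casual=True,
                                 use_activation_checkpointing=True,
                                 sequence_parallel_size=1):
    # qkv_proj + o_proj: 8 * B * s * h^2
    flops_qkvo_proj = 8 * batch_size * sequence_len * hidden_size**2
    # attention matmuls: 4 * B * s^2 * h, scaled by sp_size (the full sequence
    # is s * sp_size) and halved when the attention mask is causal
    flops_attn = (4 * batch_size * sequence_len**2 * hidden_size *
                  sequence_parallel_size / (int(is_casual) + 1))
    # up_proj + down_proj (+ gate_proj folded into mlp_ratio): 4 * B * s * h^2 * mlp_ratio
    flops_mlp = 4 * mlp_ratio * batch_size * sequence_len * hidden_size**2
    # forward (1x) + backward (2x) + recomputed forward (1x with checkpointing)
    flops_wo_head = ((3 + int(use_activation_checkpointing)) *
                     (flops_qkvo_proj + flops_attn + flops_mlp) * num_layers)
    # LM head (not recomputed): 3 covers forward + 2x backward, 2 for multiply-adds
    flops_head = 3 * 2 * batch_size * sequence_len * hidden_size * vocab_size
    return flops_wo_head + flops_head


# Assumed Llama-2-7B-like settings, purely for illustration.
flops = estimate_flops_per_iteration(
    batch_size=1,
    sequence_len=2048,
    hidden_size=4096,
    num_layers=32,
    vocab_size=32000,
    mlp_ratio=11008 / 4096 * 1.5)  # intermediate/hidden, * 1.5 for gate_proj

iter_time = 1.0  # assumed seconds per iteration
print(f'{flops / 1e12:.1f} TFLOPs per iteration '
      f'-> {flops / 1e12 / iter_time:.1f} TFLOPS per GPU at {iter_time:.1f} s/iter')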

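As a usage note, the new `is_casual` argument can either be set explicitly when the hook is registered or left as `None`, in which case the hook inspects the model's attention modules via `_guess_is_casual_attn`. A minimal sketch of an mmengine-style config entry follows; the class name `ThroughputHook` and the import path are assumptions based on the file touched by this commit, not verified exports.

# Sketch of enabling the hook in an mmengine-style training config
# (assumed class name `ThroughputHook`; argument values are illustrative).
from xtuner.engine.hooks.throughput_hook import ThroughputHook

custom_hooks = [
    dict(
        type=ThroughputHook,
        is_casual=True),  # set explicitly, or leave as None to auto-detect
]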