from time import time
from typing import Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

from colossalai.accelerator import get_accelerator
from colossalai.cluster import DistCoordinator

def divide(x: float, y: float) -> float:
    if y == 0:
        return float("inf")
    elif y == float("inf"):
        return float("nan")
    return x / y

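# Average a Python float across all ranks of the default process group; a no-op when
# running on a single process.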
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    if world_size == 1:
        return x
    tensor = torch.tensor([x], device=get_accelerator().get_current_device())
    dist.all_reduce(tensor)
    tensor = tensor / world_size
    return tensor.item()

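# Return a torch.profiler context when profiling is enabled; otherwise return a DummyProfiler
# that only counts step() calls, so the training loop can call step() unconditionally.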
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    class DummyProfiler:
        def __init__(self):
            self.step_number = 0

        def step(self):
            self.step_number += 1

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            pass

    if enable_flag:
        return profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
            on_trace_ready=tensorboard_trace_handler(save_dir),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
    else:
        return DummyProfiler()

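# Wall-clock timer that accumulates elapsed time over repeated start()/end() pairs.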
class Timer:
    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.0

    def start(self) -> None:
        self.start_time = time()

    def end(self) -> None:
        assert self.start_time is not None
        self.duration += time() - self.start_time
        self.start_time = None

    def reset(self) -> None:
        self.duration = 0.0

class PerformanceEvaluator:
    """
    Callback to evaluate the performance (throughput and TFLOPS) of the model during training.

    Args:
        model_numel: The number of parameters of the model.
        num_layers: The number of transformer layers of the model.
        hidden_size: The hidden size of the model.
        vocab_size: The vocabulary size of the model.
        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
        ignore_steps: The number of warm-up steps to ignore when measuring performance.
        dp_world_size: The data-parallel world size. Defaults to the global world size.
    """

    def __init__(
        self,
        model_numel: int,
        num_layers: int,
        hidden_size: int,
        vocab_size: int,
        enable_grad_checkpoint: bool = False,
        ignore_steps: int = 0,
        dp_world_size: Optional[int] = None,
    ) -> None:
        self.model_numel = model_numel
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_steps = ignore_steps
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.coordinator = DistCoordinator()
        self.dp_world_size = dp_world_size or self.coordinator.world_size
        self.disable: bool = False
        self.timer = Timer()
        self.num_samples: int = 0
        self.flop_megatron = 0
        self.flop: int = 0
        self.num_tokens: int = 0

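    # The first `ignore_steps` steps are treated as warm-up and excluded from all measurements.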
    def on_step_start(self, step: int) -> None:
        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
        if self.disable:
            return
        # get_accelerator().synchronize()
        self.timer.start()

    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
        if self.disable:
            return
        # get_accelerator().synchronize()
        self.timer.end()

        batch_size, seq_len = input_ids.shape
        self.num_samples += batch_size
        self.num_tokens += batch_size * seq_len
        checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint)
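        # Analytic estimate following the Megatron-LM paper: roughly 24 * B * s * L * h^2 FLOPs
        # for a forward pass through the transformer layers, with correction terms for attention
        # (s / 6h) and the output logits (V / 16Lh). checkpoint_activations_factor is 3 for
        # forward + backward, and 4 when gradient checkpointing adds a recomputation pass.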
        self.flop_megatron += (
            24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)
        ) * (
            1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
        )
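        # Simpler parameter-count estimate: about 2 * num_params FLOPs per token for a forward
        # pass, scaled by the same forward/backward (and recomputation) factor.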
        self.flop += batch_size * seq_len * self.model_numel * 2 * checkpoint_activations_factor

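    # Durations are averaged across all ranks; the FLOP counters describe the work of one
    # data-parallel replica, so they are divided by the model-parallel world size to report
    # a per-GPU TFLOPS figure.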
    def on_fit_end(self) -> None:
        avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
        avg_throughput_samples = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
        avg_throughput_tokens = self.num_tokens * self.dp_world_size / (avg_duration + 1e-12)
        mp_world_size = self.coordinator.world_size // self.dp_world_size
        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
        self.coordinator.print_on_master(
            f"num_samples: {self.num_samples}, num_tokens: {self.num_tokens}, dp_world_size: {self.dp_world_size}, "
            f"flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
            f"avg_throughput_samples: {avg_throughput_samples}, avg_throughput_tokens: {avg_throughput_tokens}"
        )
        self.coordinator.print_on_master(
            f"Throughput_Samples: {avg_throughput_samples:.2f} samples/sec, "
            f"Throughput_Tokens: {avg_throughput_tokens:.2f} tokens/sec, "
            f"TFLOPS per GPU by Megatron formula: {avg_tflops_per_gpu_megatron:.2f}, "
            f"TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
        )
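
# Typical usage (an illustrative sketch, not part of this module; `model`, `model_numel`,
# `train_dataloader`, and the surrounding training loop are placeholders supplied by the
# training script, which is assumed to run under a distributed launcher):
#
#     evaluator = PerformanceEvaluator(
#         model_numel=model_numel,
#         num_layers=model.config.num_hidden_layers,
#         hidden_size=model.config.hidden_size,
#         vocab_size=model.config.vocab_size,
#         enable_grad_checkpoint=True,
#         ignore_steps=2,
#     )
#     for step, batch in enumerate(train_dataloader):
#         evaluator.on_step_start(step)
#         ...  # forward, backward, optimizer step
#         evaluator.on_step_end(batch["input_ids"])
#     evaluator.on_fit_end()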