forked from mlcommons/training
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmlperf_logging_utils.py
More file actions
424 lines (377 loc) · 14.1 KB
/
mlperf_logging_utils.py
File metadata and controls
424 lines (377 loc) · 14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import torch
import torch.distributed as dist
from mlperf_logging import mllog
from mlperf_logging.mllog import constants
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
is_torch_xla_available,
)
if is_torch_xla_available():
import torch_xla.runtime as xr
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized.

    Returns:
        bool: False if ``dist.is_available()`` is false, or if it is true but
        ``dist.is_initialized()`` is false; True otherwise.
    """
    # Short-circuit keeps the original order: initialization is only queried
    # when the distributed package is available.
    return dist.is_available() and dist.is_initialized()
def get_rank():
    """Return the global rank of the current process.

    Returns:
        int: ``xr.global_ordinal()`` when torch_xla is available; otherwise
        ``dist.get_rank()`` when torch.distributed is available and
        initialized, and 0 when it is not.
    """
    # XLA takes precedence over torch.distributed whenever it is installed.
    if is_torch_xla_available():
        return xr.global_ordinal()
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def barrier():
    """Call ``torch.distributed.barrier()`` if a process group is initialized.

    Does nothing when torch.distributed is unavailable or not initialized.
    """
    if is_dist_avail_and_initialized():
        # ``dist`` is the module-level alias for ``torch.distributed``.
        dist.barrier()
class ClmLogger:
    """Thin wrapper around the ``mllog`` logger that only logs on rank 0.

    Attributes set in ``__init__``:
        mllogger: the object returned by ``mllog.get_mllogger()``.
        target_eval_loss: copied from ``config.target_eval_loss``.
    """

    def __init__(self, config, filename=None, default_stack_offset=2):
        """Configure ``mllog`` and remember the target evaluation loss.

        Args:
            config: object providing ``target_eval_loss`` and, when no other
                log path is given, ``run_dir``.
            filename: explicit log file path. When falsy, the
                ``COMPLIANCE_FILE`` environment variable is used; when that is
                also unset or empty, ``<config.run_dir>/mlperf_compliance.log``.
            default_stack_offset: forwarded unchanged to ``mllog.config``.
                NOTE(review): presumably 2 so records are attributed to the
                caller of this wrapper rather than the wrapper itself;
                confirm against the mllog documentation.
        """
        self.mllogger = mllog.get_mllogger()

        # Resolve the log path by precedence: argument, environment, run dir.
        # ``config.run_dir`` is only read when both earlier sources are falsy.
        log_path = filename or os.getenv("COMPLIANCE_FILE")
        if not log_path:
            log_path = os.path.join(config.run_dir, "mlperf_compliance.log")

        mllog.config(default_stack_offset=default_stack_offset, filename=log_path)
        self.target_eval_loss = config.target_eval_loss

    # NOTE: the three methods below call the underlying logger directly, with
    # no shared helper in between, so the call-stack depth seen by mllog is
    # the same for each of them.

    def event(self, key, value=None, metadata=None, sync=False, log_rank=None):
        """Log a point-in-time event on rank 0; do nothing on other ranks.

        ``sync`` and ``log_rank`` are accepted but not used by this method.
        """
        if get_rank() != 0:
            return
        self.mllogger.event(key=key, value=value, metadata=metadata)

    def start(self, key, value=None, metadata=None, sync=False, log_rank=None):
        """Log an interval start on rank 0; do nothing on other ranks.

        ``sync`` and ``log_rank`` are accepted but not used by this method.
        """
        if get_rank() != 0:
            return
        self.mllogger.start(key=key, value=value, metadata=metadata)

    def end(self, key, value=None, metadata=None, sync=False, log_rank=None):
        """Log an interval end on rank 0; do nothing on other ranks.

        ``sync`` and ``log_rank`` are accepted but not used by this method.
        """
        if get_rank() != 0:
            return
        self.mllogger.end(key=key, value=value, metadata=metadata)
class MLPerfCallback(TrainerCallback):
    """Trainer callback that writes MLPerf compliance records via ``ClmLogger``.

    Construction logs CACHE_CLEAR and INIT_START. ``on_train_begin`` logs the
    submission and hyperparameter records, INIT_STOP, RUN_START and the first
    BLOCK_START. ``on_step_end`` logs train loss, block boundaries, evaluation
    accuracy and the final RUN_STOP.
    """

    def __init__(self, config):
        """Create the logger and open the MLPerf init interval.

        Args:
            config: object providing ``global_train_batch_size`` and
                ``max_length``, plus whatever ``ClmLogger`` reads from it.
        """
        super().__init__()
        self.mllogger = ClmLogger(config)
        self.submission_info = {
            "submission_benchmark": "mixture-of-expert",  # TODO change task name
            "submission_division": "closed",
            "submission_org": "Google",
            "submission_platform": "reference",
            "submission_status": "reference",
        }
        self.mllogger.event(
            key=constants.CACHE_CLEAR,
            value="True",
        )
        self.mllogger.start(key=constants.INIT_START, value="")
        self.config = config
        # Tokens consumed per optimizer step; used as the unit for every
        # "samples_count" metadata value logged by this callback.
        self.global_batch_tokens = config.global_train_batch_size * config.max_length

    def on_train_begin(self, args, state, control, **kwargs):
        """Log run metadata, close the init interval and open the run.

        Args:
            args: training arguments; this method reads
                ``per_device_train_batch_size``, ``gradient_accumulation_steps``,
                ``seed``, ``sched.warmup_ratio``, ``max_steps``,
                ``weight_decay``, ``max_grad_norm`` and ``lr``.
            state: unused here.
            control: unused here.

        Raises:
            ValueError: if neither CUDA nor torch_xla is available.
        """
        if torch.cuda.is_available():
            num_devices = int(os.getenv("WORLD_SIZE", 1))
        elif is_torch_xla_available():
            num_devices = xr.global_runtime_device_count()
        else:
            raise ValueError("The pipeline should be either cuda or xla backend.")

        # Stored on the instance only; the GLOBAL_BATCH_SIZE record below is
        # taken from the config instead.
        self.global_batch_size = int(
            args.per_device_train_batch_size
            * args.gradient_accumulation_steps
            * num_devices
        )

        self.mllogger.event(
            key=constants.SUBMISSION_BENCHMARK,
            value=self.submission_info["submission_benchmark"],
        )
        self.mllogger.event(
            key=constants.SUBMISSION_DIVISION,
            value=self.submission_info["submission_division"],
        )
        self.mllogger.event(
            key=constants.SUBMISSION_ORG, value=self.submission_info["submission_org"]
        )
        self.mllogger.event(
            key=constants.SUBMISSION_PLATFORM,
            value=self.submission_info["submission_platform"],
        )
        self.mllogger.event(
            key=constants.SUBMISSION_STATUS,
            value=self.submission_info["submission_status"],
        )
        self.mllogger.event(
            key=constants.GLOBAL_BATCH_SIZE,
            value=self.config.global_train_batch_size,
        )
        self.mllogger.event(
            key=constants.EVAL_SAMPLES,
            value=12694503,
        )
        self.mllogger.event(key=constants.SEED, value=args.seed)
        self.mllogger.event(
            key=constants.OPT_LR_WARMUP_FACTOR, value=args.sched.warmup_ratio
        )
        self.mllogger.event(key=constants.OPT_LR_TRAINING_STEPS, value=args.max_steps)
        self.mllogger.event(
            key=constants.OPT_ADAMW_WEIGHT_DECAY, value=args.weight_decay
        )
        self.mllogger.event(
            key=constants.OPT_GRADIENT_CLIP_NORM, value=args.max_grad_norm
        )
        self.mllogger.event(key=constants.OPT_BASE_LR, value=args.lr)
        self.mllogger.event(
            key=constants.GRADIENT_ACCUMULATION_STEPS,
            value=args.gradient_accumulation_steps,
        )

        # device warmup should be done here
        self.mllogger.end(key=constants.INIT_STOP, value="")

        # run on all ranks to allow sync
        barrier()
        self.mllogger.start(constants.RUN_START, value="")
        self.mllogger.start(
            constants.BLOCK_START,
            value="",
            metadata={
                "samples_count": 0,
            },
        )

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Log per-block records whenever ``global_step`` hits an eval boundary.

        On steps where ``state.global_step % state.eval_steps == 0`` this logs
        the latest train loss (for steps > 0), closes the current block, logs
        the latest eval loss if one is present in ``state.log_history[-1]``,
        and then either ends the run (target loss reached or ``max_steps``
        reached) or opens the next block.

        Args:
            args: unused here.
            state: trainer state; reads ``global_step``, ``eval_steps``,
                ``max_steps`` and ``log_history``.
            control: trainer control; ``should_log`` and
                ``should_training_stop`` may be set to True.

        Returns:
            The same ``control`` object.
        """
        tokens_seen = state.global_step * self.global_batch_tokens
        at_eval_boundary = state.global_step % state.eval_steps == 0

        if at_eval_boundary and state.global_step > 0:
            self.mllogger.event(
                "train_loss",
                value=state.log_history[-1]["train/loss"] if state.log_history else -1,
                metadata={
                    "samples_count": tokens_seen if state.log_history else -1,
                },
            )
            control.should_log = True

        if at_eval_boundary:
            self.mllogger.end(
                constants.BLOCK_STOP,
                value="",
                metadata={"samples_count": tokens_seen},
            )

            # Read the eval loss defensively and only log EVAL_ACCURACY when a
            # value actually exists; an unguarded lookup here would raise on an
            # empty history or an entry without "eval/loss".
            latest_eval_loss = float("nan")
            if state.log_history and "eval/loss" in state.log_history[-1]:
                latest_eval_loss = state.log_history[-1]["eval/loss"]
                self.mllogger.event(
                    constants.EVAL_ACCURACY,
                    value=latest_eval_loss,
                    metadata={"samples_count": tokens_seen},
                )

            # NaN compares False, so a missing eval loss never counts as success.
            if latest_eval_loss <= self.mllogger.target_eval_loss:
                control.should_training_stop = True
                # run on all ranks to allow sync
                barrier()
                self.mllogger.end(
                    constants.RUN_STOP,
                    value=latest_eval_loss,
                    metadata={
                        "samples_count": tokens_seen,
                        "status": "success",
                    },
                )
            elif state.global_step >= state.max_steps:
                # ``elif``: when the target is reached on the final step, only
                # the "success" RUN_STOP above is logged, not a second "fail".
                control.should_training_stop = True
                self.mllogger.end(
                    constants.RUN_STOP,
                    value=latest_eval_loss,
                    metadata={
                        "samples_count": tokens_seen,
                        "status": "fail",
                    },
                )

            if not control.should_training_stop:
                self.mllogger.start(
                    constants.BLOCK_START,
                    value="",
                    metadata={"samples_count": tokens_seen},
                )

        return control
class MLPerfLightningCallback(Callback):
    """Lightning callback that logs MLPerf block, eval and run boundaries.

    Every ``samples_count`` metadata value is computed as
    ``trainer.global_step * global_batch_size * sequence_length``.
    """

    def __init__(self, logger, global_batch_size: int, sequence_length: int):
        """Store the logger and the factors used for ``samples_count``.

        Args:
            logger: object exposing ``start``/``end`` methods that accept
                ``metadata`` and ``sync`` keyword arguments.
            global_batch_size: stored as ``self.gbs``.
            sequence_length: stored as ``self.seq``.
        """
        super().__init__()
        self.gbs = global_batch_size
        self.seq = sequence_length
        self.mllogger = logger
        # When True, on_train_end reports status "success" instead of "aborted".
        self.force_success = False

    def __deepcopy__(self, memo):
        """Return a copy that shares the logger instead of deep-copying it.

        The ``force_success`` flag is carried over so a copied callback reports
        the same final status as the original.
        """
        clone = MLPerfLightningCallback(self.mllogger, self.gbs, self.seq)
        clone.force_success = self.force_success
        return clone

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Delegate to the base class implementation."""
        return super().on_train_batch_start(trainer, pl_module, batch, batch_idx)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Delegate to the base class implementation."""
        return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
        """Log BLOCK_STOP followed by EVAL_START at the current step."""
        self.mllogger.end(
            constants.BLOCK_STOP,
            metadata={"samples_count": trainer.global_step * self.gbs * self.seq},
            sync=False,
        )
        self.mllogger.start(
            key=constants.EVAL_START,
            metadata={"samples_count": trainer.global_step * self.gbs * self.seq},
            sync=False,
        )
        return super().on_validation_start(trainer, pl_module)

    @rank_zero_only
    def on_validation_end(self, trainer, pl_module):
        """Log the next BLOCK_START unless the trainer is about to stop."""
        if not trainer.should_stop:
            self.mllogger.start(
                constants.BLOCK_START,
                metadata={"samples_count": trainer.global_step * self.gbs * self.seq},
                sync=False,
            )
        return super().on_validation_end(trainer, pl_module)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Log the first BLOCK_START with a ``samples_count`` of 0."""
        self.mllogger.start(
            constants.BLOCK_START, metadata={"samples_count": 0}, sync=False
        )

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        """Log RUN_STOP if the trainer says it has not been logged yet.

        Nothing is logged when the trainer has no ``run_stop_logged``
        attribute or when that attribute is truthy. The status is "success"
        when ``force_success`` is set and "aborted" otherwise.
        """
        if hasattr(trainer, "run_stop_logged") and not trainer.run_stop_logged:
            self.mllogger.end(
                constants.RUN_STOP,
                metadata={
                    "samples_count": trainer.global_step * self.gbs * self.seq,
                    "status": "aborted" if not self.force_success else "success",
                },
            )
        return super().on_train_end(trainer, pl_module)
class MetricsLogger(Logger):
    """Lightning logger that forwards selected metrics to an MLPerf logger.

    ``set_trainer`` must be called before ``log_metrics`` handles a train or
    validation loss, because ``samples_count`` is derived from
    ``self.trainer.global_step``.
    """

    def __init__(
        self,
        logger,
        nodes: int,
        global_batch_size: int,
        learning_rate: float,
        sequence_length: int,
    ):
        """Store the wrapped logger and the run parameters.

        Args:
            logger: object exposing ``event``/``start``/``end`` methods.
            nodes: stored as ``self.nodes``.
            global_batch_size: stored as ``self.gbs``.
            learning_rate: stored as ``self.lr``.
            sequence_length: stored as ``self.seq``.
        """
        super().__init__()
        self.mllogger = logger
        self.nodes = nodes
        self.gbs = global_batch_size
        self.lr = learning_rate
        self.seq = sequence_length
        self.experiment = None

    def __deepcopy__(self, memo):
        """Return a copy sharing the wrapped logger and, if set, the trainer."""
        clone = MetricsLogger(
            logger=self.mllogger,
            nodes=self.nodes,
            global_batch_size=self.gbs,
            learning_rate=self.lr,
            sequence_length=self.seq,
        )
        if hasattr(self, "trainer"):
            clone.trainer = self.trainer
        return clone

    def set_trainer(self, trainer):
        """Attach ``trainer`` and reset its ``run_stop_logged`` flag to False."""
        self.trainer = trainer
        trainer.run_stop_logged = False

    @rank_zero_only
    def log_metrics(self, metrics, step):
        """Forward train and validation losses found in ``metrics``.

        Args:
            metrics: mapping that may contain ``"reduced_train_loss"`` and/or
                ``"val_loss"``; other keys are ignored.
            step: unused; the trainer's ``global_step`` is used instead.
        """

        def progress():
            # Built on demand so ``self.trainer`` is only touched when one of
            # the handled keys is actually present.
            return {"samples_count": self.trainer.global_step * self.gbs * self.seq}

        if "reduced_train_loss" in metrics:
            self.mllogger.event(
                "train_loss_update",
                value=metrics["reduced_train_loss"],
                metadata=progress(),
            )

        if "val_loss" not in metrics:
            return

        # A validation loss closes the evaluation interval: log the accuracy
        # record first, then EVAL_STOP.
        self.mllogger.event(
            constants.EVAL_ACCURACY,
            value=metrics["val_loss"],
            metadata=progress(),
        )
        self.mllogger.end(
            key=constants.EVAL_STOP,
            metadata=progress(),
            sync=False,
        )

    @rank_zero_only
    def log_hyperparams(self, params, *args, **kwargs):
        """Log CACHE_CLEAR, INIT_START, GLOBAL_BATCH_SIZE and OPT_BASE_LR.

        ``params`` and any extra arguments are ignored.
        """
        self.mllogger.event(key=constants.CACHE_CLEAR, value=True)
        self.mllogger.start(key=constants.INIT_START)
        self.mllogger.event(
            key=constants.GLOBAL_BATCH_SIZE,
            value=self.gbs,
            sync=False,
        )
        self.mllogger.event(key=constants.OPT_BASE_LR, value=self.lr)
        # TODO: records from the disabled draft that are not emitted yet, with
        # the sources that draft intended (this class has no ``cfg`` attribute):
        #   mlperf_submission_log(benchmark="mixtral_8x22b", num_nodes=self.nodes)
        #   SEED                         <- cfg.model.seed (sync=False, unique=True)
        #   TRAIN_SAMPLES, EVAL_SAMPLES  <- 0
        #   OPT_LR_WARMUP_FACTOR         <- cfg.model.optim.sched.warmup_ratio
        #   OPT_ADAMW_WEIGHT_DECAY       <- cfg.model.optim.weight_decay
        #   OPT_GRADIENT_CLIP_NORM       <- cfg.trainer.gradient_clip_val
        #   GRADIENT_ACCUMULATION_STEPS  <- int(os.getenv("MINIBS", "1"))
        #                                   // cfg.model.micro_batch_size
        #   OPT_LR_TRAINING_STEPS        <- cfg.trainer.max_steps

    @property
    def name(self):
        """Logger name reported to Lightning."""
        return "mlperf-metrics"

    @property
    def version(self):
        """Logger version reported to Lightning."""
        return 1