-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
64 lines (54 loc) · 2.22 KB
/
utils.py
File metadata and controls
64 lines (54 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import jax
from datetime import datetime
# https://github.com/karpathy/nanochat/blob/bc51da8baca66c54606bdd75c861c82ced90dcb0/nanochat/common.py#L183C1-L190C13
class DummyWandb:
    """No-op stand-in exposing the subset of the wandb run API used by this code.

    Every logging method accepts arbitrary arguments, does nothing and
    returns ``None``, so callers can log unconditionally.
    """

    def __init__(self):
        # Second-resolution local timestamp serves as the run identifier.
        created_at = datetime.now()
        self.id = f"{created_at:%Y%m%d_%H%M%S}"

    def log(self, *args, **kwargs):
        """Accept and discard a metrics payload."""

    def log_artifact(self, *args, **kwargs):
        """Accept and discard an artifact."""

    def log_model(self, *args, **kwargs):
        """Accept and discard a model."""

    def finish(self):
        """Do nothing; present so callers can always close the run."""
class MetricLogger:
    """Print per-step metrics and optionally forward them to a wandb-style run.

    Args:
        batch_size: Number of samples per step; used to derive
            ``samples_per_second`` from the ``"step_time"`` metric.
        prefix: Namespace prepended (as ``"{prefix}/{key}"``) to every metric
            key sent to ``wandb_run``.
        buffer: When true, each ``log`` call emits the metrics handed to the
            *previous* call and stores the current ones for the next call.
            Presumably this avoids waiting on values that are still being
            computed on device -- confirm against the training loop.
        wandb_run: Object with a ``log(data, step=...)`` method, or ``None``
            (or any falsy value) to only print.
    """

    def __init__(self, batch_size, prefix, buffer=True, wandb_run=None):
        self.batch_size = batch_size
        self.prefix = prefix
        self.buffer = buffer
        self.wandb_run = wandb_run
        # Metrics held back by one call when buffering is enabled.
        self.prev_metrics = None

    def _human_format(self, num: float, billions: bool = False, divide_by_1024: bool = False) -> str:
        """Render ``num`` compactly with a magnitude suffix, e.g. ``1500`` -> ``"1.5K"``.

        Values with ``abs(num) < 1`` and infinities are formatted with three
        significant digits and no suffix. Other values are first rounded to
        three significant digits, then scaled.

        Args:
            num: Value to format.
            billions: Unused; kept so the signature stays compatible.
            divide_by_1024: Scale by 1024 per suffix instead of 1000. The
                loop threshold remains 1000 in both modes.

        Returns:
            The formatted string.
        """
        # Adapted from
        # https://github.com/huggingface/nanotron/blob/7bc9923285a03069ebffe994379a311aceaea546/src/nanotron/logging/base.py#L268
        # Infinity never shrinks under division, so without this guard it
        # would exhaust the suffix table and come out as "infE".
        if abs(num) < 1 or abs(num) == float("inf"):
            return "{:.3g}".format(num)
        suffixes = ("", "K", "M", "B", "T", "P", "E")
        divisor = 1024.0 if divide_by_1024 else 1000.0
        num = float("{:.3g}".format(num))
        magnitude = 0
        while abs(num) >= 1000 and magnitude < len(suffixes) - 1:
            num /= divisor
            magnitude += 1
        # Fixed-point text with trailing zeros and a dangling "." removed.
        digits = "{:f}".format(num).rstrip("0").rstrip(".")
        return "{}{}".format(digits, suffixes[magnitude])

    def _pretty_print(self, metrics, step):
        """Print ``step`` and every metric on one ``" | "``-separated line."""
        fields = [f"step: {step}"]
        fields.extend(f"{k}: {self._human_format(v)}" for k, v in metrics.items())
        print(" | ".join(fields))

    def _emit(self, metrics):
        """Convert one metrics mapping to floats, print it and forward it."""
        # Shallow copy so the caller's mapping keeps its "step" entry.
        pending = dict(metrics)
        # "step" is required (KeyError if absent); coerce it to a plain int
        # because it is removed before the float conversion below.
        step = int(pending.pop("step"))
        # Turn every leaf into a Python float (original note: "move to cpu").
        pending = jax.tree.map(float, pending)
        # Throughput is only defined for a present, non-zero step time.
        step_time = pending.get("step_time")
        if step_time:
            pending["samples_per_second"] = self.batch_size / step_time
        self._pretty_print(pending, step)
        if self.wandb_run:
            prefixed = {f"{self.prefix}/{k}": v for k, v in pending.items()}
            self.wandb_run.log(prefixed, step=step)

    def log(self, metrics):
        """Record ``metrics`` for one step.

        Args:
            metrics: Mapping that contains a ``"step"`` entry plus numeric
                metric values. If ``"step_time"`` is present and non-zero,
                ``samples_per_second`` is added as
                ``batch_size / step_time``. The mapping is not modified.

        With buffering enabled, the mapping passed to the previous call is
        emitted and the first call emits nothing; call ``flush`` to emit the
        final pending mapping.

        Raises:
            KeyError: If the mapping being emitted has no ``"step"`` entry.
        """
        if self.buffer:
            self.prev_metrics, to_emit = metrics, self.prev_metrics
        else:
            to_emit = metrics
        if not to_emit:
            return
        self._emit(to_emit)

    def flush(self):
        """Emit metrics still held by the buffer, if any, and clear it."""
        held, self.prev_metrics = self.prev_metrics, None
        if held:
            self._emit(held)