Skip to content

Commit 441ae4c

Browse files
committed
merge
2 parents d3ae98f + 71f2d29 commit 441ae4c

File tree

7 files changed

+106
-31
lines changed

7 files changed

+106
-31
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pip install git+https://github.com/Tongjilibo/torch4keras.git
6767
## 4. 版本历史
6868
|更新日期| 版本 | 版本说明 |
6969
|------| ----------------- |----------- |
70+
|20240204|v0.1.9 | 增加Timeit, Timeit2, timeit等时间/速度监控|
7071
|20240116|v0.1.8 | 重新整理snippets, 重写save_pretrained|
7172
|20231219|v0.1.7 | 增加SimpleStreamFileLogger和LoggerHandler, 修改Logger的格式|
7273
|20231208|v0.1.6.post2 |监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题; 修复clip_grad_norm的bug|

docs/History.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 更新历史
22

3+
- **20240204**: 增加Timeit, Timeit2, timeit等时间/速度监控
34
- **20240116**: 重新整理snippets, 重写save_pretrained
45
- **20231219**: 增加SimpleStreamFileLogger和LoggerHandler, 修改Logger的格式
56
- **20231208**: 监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题

test/test_time.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from torch4keras.snippets import timeit, Timeit
1+
from torch4keras.snippets import timeit, Timeit, Timeit2
22
import time
33

44

@@ -13,19 +13,30 @@ def func(n=10):
1313
with Timeit() as ti:
1414
for i in range(10):
1515
time.sleep(0.1)
16-
ti.lap(prefix=i, restart=False) # 统计累计耗时
16+
ti.lap(name=i, reset=False) # 统计累计耗时
1717

1818
# 上下文管理器 - 统计每段速度
1919
with Timeit() as ti:
2020
for i in range(10):
2121
time.sleep(0.1)
22-
ti.lap(count=10, prefix=i, restart=True)
22+
ti.lap(count=10, name=i, reset=True)
2323
ti(10) # 统计速度
2424

2525

2626
# 上下文管理器 - 统计速度
2727
with Timeit() as ti:
2828
for i in range(10):
2929
time.sleep(0.1)
30-
ti.lap(prefix=i, restart=True)
31-
ti(10) # 统计速度
30+
ti.lap(name=i, reset=True)
31+
ti(10) # 统计速度
32+
33+
ti = Timeit2()
34+
for i in range(10):
35+
time.sleep(0.1)
36+
ti.lap(name=i)
37+
38+
for i in range(10):
39+
time.sleep(0.1)
40+
ti.lap(name=i)
41+
ti.end() # 打印时长
42+

torch4keras/snippets/log.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ def log_info(string:str, verbose:int=1):
6262
return res
6363

6464

65+
@functools.lru_cache(None)
66+
def log_info_once(string:str, verbose=1):
67+
''' 单次warning '''
68+
return log_info(string, verbose)
69+
70+
6571
def log_warn(string:str, verbose:int=1):
6672
'''[WARNING]: message, 黄色前缀'''
6773
res = colorful('[WARNING]', color='yellow') + ' ' + string.strip()
@@ -70,6 +76,12 @@ def log_warn(string:str, verbose:int=1):
7076
return res
7177

7278

79+
@functools.lru_cache(None)
80+
def log_warn_once(string:str, verbose=1):
81+
''' 单次warning '''
82+
return log_warn(string, verbose)
83+
84+
7385
def log_error(string:str, verbose:int=1):
7486
'''[ERROR]: message, 红色前缀'''
7587
res = colorful('[ERROR]', color='red') + ' ' + string.strip()
@@ -79,9 +91,9 @@ def log_error(string:str, verbose:int=1):
7991

8092

8193
@functools.lru_cache(None)
82-
def log_warn_once(string:str, verbose=1):
94+
def log_error_once(string:str, verbose=1):
8395
''' 单次warning '''
84-
return log_warn(string, verbose)
96+
return log_error(string, verbose)
8597

8698

8799
@functools.lru_cache(None)

torch4keras/snippets/misc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import random
77
from .log import log_info, log_warn, log_error
8+
import json
89

910

1011
def seed_everything(seed:int=None):
@@ -77,7 +78,6 @@ def allowDotting(self, state=True):
7778
class JsonConfig:
7879
'''读取配置文件并返回可.操作符的字典'''
7980
def __new__(self, json_path, encoding='utf-8'):
80-
import json
8181
return DottableDict(json.load(open(json_path, "r", encoding=encoding)))
8282

8383

torch4keras/snippets/monitor.py

+67-14
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
import functools
77
from .log import log_info, log_warn, log_error
8+
from pprint import pprint
89

910

1011
def format_time(eta, hhmmss=True):
@@ -61,56 +62,108 @@ class Timeit:
6162
with Timeit() as ti:
6263
for i in range(10):
6364
time.sleep(0.1)
64-
# ti.lap(prefix=i, restart=False) # 统计累计耗时
65-
# ti.lap(prefix=i, restart=True) # 统计间隔耗时
66-
# ti.lap(count=10, prefix=i, restart=True) # 统计每段速度
65+
# ti.lap(name=i, restart=False) # 统计累计耗时
66+
# ti.lap(name=i, restart=True) # 统计间隔耗时
67+
# ti.lap(count=10, name=i, restart=True) # 统计每段速度
6768
# ti(10) # 统计速度
6869
'''
69-
def __enter__(self):
70+
def __enter__(self, template='Average speed: {:.2f}/s'):
7071
self.count = None
7172
self.start_tm = time.time()
72-
self.template = 'Average speed: {:.2f}/s'
73+
self.template = template
7374
return self
7475

7576
def __call__(self, count):
7677
self.count = count
7778

78-
def restart(self):
79+
def reset(self):
7980
'''自定义开始记录的地方'''
8081
self.start_tm = time.time()
8182

82-
def lap(self, count:int=None, prefix:str=None, restart=False):
83+
def lap(self, name:str=None, count:int=None, reset=False):
8384
'''
85+
:params name: 打印时候自定义的前缀
8486
:params count: 需要计算平均生成速度中统计的次数
85-
:params prefix: 打印时候自定义的前缀
86-
:params restart: 是否重置start_tm, True只记录时间间隔,否则记录的是从一开始的累计时间
87+
:params reset: 是否重置start_tm, True只记录时间间隔,否则记录的是从一开始的累计时间
8788
'''
8889
if count is not None:
8990
self.count = count
90-
prefix = '' if prefix is None else str(prefix).strip() + ' - '
91+
name = '' if name is None else str(name).strip() + ' - '
9192

9293
end_tm = time.time()
9394
consume = end_tm - self.start_tm
9495
if self.count is None:
96+
# 只log时间
9597
consume = format_time(consume, hhmmss=False)
9698
start1 = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_tm))
9799
end1 = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end_tm))
98-
log_info(prefix + f'Cost {consume} [{start1} < {end1}]')
100+
log_info(name + f'Cost {consume} [{start1} < {end1}]')
99101
elif consume > 0:
100102
speed = self.count / consume
101-
log_info(prefix + self.template.format(speed))
103+
log_info(name + self.template.format(speed))
102104
else:
103105
pass
104106
# log_warn('Time duration = 0')
105107

106-
if restart:
107-
self.restart()
108+
if reset:
109+
self.reset()
108110

109111
def __exit__(self, exc_type, exc_val, exc_tb):
110112
self.lap()
111113
print()
112114

113115

116+
class Timeit2:
117+
'''记录耗时
118+
119+
Example
120+
----------------------
121+
ti = Timeit2()
122+
for i in range(10):
123+
time.sleep(0.1)
124+
ti.lap(name=i)
125+
ti.end() # 打印各个步骤时长
126+
'''
127+
def __init__(self):
128+
self.reset()
129+
130+
def __call__(self, *args, **kwargs):
131+
self.lap(*args, **kwargs)
132+
133+
def reset(self):
134+
'''自定义开始记录的地方'''
135+
self.cost = dict()
136+
self.count = dict()
137+
self.start_tm = time.time()
138+
139+
def restart(self):
140+
self.start_tm = time.time()
141+
142+
def lap(self, name:str):
143+
'''
144+
:params name: 打印时候自定义的前缀
145+
'''
146+
end_tm = time.time()
147+
consume = end_tm - self.start_tm
148+
name = str(name)
149+
self.cost[name] = self.cost.get(name, 0) + consume
150+
self.count[name] = self.count.get(name, 0) + 1
151+
self.start_tm = time.time()
152+
153+
def end(self, verbose=1):
154+
for k, v in self.count.items():
155+
if v > 1:
156+
self.cost['avg_' + k] = self.cost[k] / v
157+
158+
if verbose > 0:
159+
log_info('Cost detail')
160+
pprint(self.cost)
161+
print()
162+
163+
self.reset()
164+
return self.cost
165+
166+
114167
def send_email(mail_receivers:Union[str,list], mail_subject:str, mail_msg:str="", mail_host:str=None,
115168
mail_user:str=None, mail_pwd:str=None, mail_sender:str=None):
116169
''' 发送邮件(默认使用笔者自己注册的邮箱,若含敏感信息请使用自己注册的邮箱)

torch4keras/trainer.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,6 @@ def fit(self, train_dataloader, steps_per_epoch=None, epochs=1, callbacks=None,
363363
# forward和backward
364364
if not self.unwrap_model().training:
365365
self.unwrap_model().train() # 设置为train模式
366-
367366
tr_loss, tr_loss_detail = 0, {}
368367
for _ in range(self.grad_accumulation_steps):
369368
train_X, train_y = self._prepare_nextbatch() # 获取下一个batch的训练数据
@@ -562,23 +561,21 @@ def save_to_checkpoint(self, save_dir:str=None, model_path:str=None, optimizer_p
562561
:param mapping: dict, 模型文件的mapping
563562
:param trainable_only
564563
'''
565-
model_path = model_path or os.path.join(save_dir, 'model.pt')
566-
optimizer_path = optimizer_path or os.path.join(save_dir, 'optimizer.pt')
567-
scheduler_path = scheduler_path or os.path.join(save_dir, 'scheduler.pt')
568-
steps_params_path = steps_params_path or os.path.join(save_dir, 'steps_params.pt')
564+
model_path = model_path or os.path.join(save_dir or './', 'model.pt')
565+
optimizer_path = optimizer_path or os.path.join(save_dir or './', 'optimizer.pt')
566+
scheduler_path = scheduler_path or os.path.join(save_dir or './', 'scheduler.pt')
567+
steps_params_path = steps_params_path or os.path.join(save_dir or './', 'steps_params.pt')
569568

570569
verbose_str = ''
571570
if model_path:
572571
self.save_weights(model_path, mapping=mapping, trainable_only=trainable_only)
573572
verbose_str += f'Model weights successfuly saved to {model_path}\n'
574573
if optimizer_path:
575-
save_dir = os.path.dirname(optimizer_path)
576-
os.makedirs(save_dir, exist_ok=True)
574+
os.makedirs(os.path.dirname(optimizer_path), exist_ok=True)
577575
torch.save(self.optimizer.state_dict(), optimizer_path)
578576
verbose_str += f'Optimizer successfuly saved to {optimizer_path}\n'
579577
if scheduler_path and (self.scheduler is not None):
580-
save_dir = os.path.dirname(scheduler_path)
581-
os.makedirs(save_dir, exist_ok=True)
578+
os.makedirs(os.path.dirname(scheduler_path), exist_ok=True)
582579
torch.save(self.scheduler.state_dict(), scheduler_path)
583580
verbose_str += f'Scheduler successfuly saved to {scheduler_path}\n'
584581
if steps_params_path:

0 commit comments

Comments
 (0)