Skip to content

Commit 4c80135

Browse files
committed
v0.1.0.post2
1 parent 75db510 commit 4c80135

File tree

7 files changed

+25
-23
lines changed

7 files changed

+25
-23
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ backup
99
test
1010
.DS_Store
1111
*.pt
12-
*.log
12+
*.log
13+
ckpt

README.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ pip install git+https://github.com/Tongjilibo/torch4keras.git
6565
- 简单示例: [tutorials_mnist](https://github.com/Tongjilibo/torch4keras/blob/master/examples/tutorials_mnist.py)
6666
6767
## 4. 版本说明
68-
- **v0.1.0**: 允许调整进度条的显示参数, 进度条和日志同步(如果进度条平滑了则日志也平滑), 自动把tensor转到model.device上, 允许打印第一个batch来检查样本
68+
- **v0.1.0.post2**: 20230725 修复v0.1.0的bug,主要是进度条和log的标签平滑的问题
69+
- **v0.1.0**: 20230724 允许调整进度条的显示参数, 进度条和日志同步(如果进度条平滑了则日志也平滑), 自动把tensor转到model.device上, 允许打印第一个batch来检查样本
6970
- **v0.0.9**:20230716 增加auto_set_cuda_devices自动选择显卡,增加log_info,log_warn, log_error等小函数
7071
- **v0.0.8**:20230625 增加EmailCallback和WandbCallback, 增加AccelerateTrainer和DeepSpeedTrainer, grad_accumulation_steps内算一个batch,修改Trainer中部分成员函数
7172
- **v0.0.7.post3**: 20230517 修复保存scheduler
@@ -80,7 +81,8 @@ pip install git+https://github.com/Tongjilibo/torch4keras.git
8081
- **v0.0.1**:20221019 初始版本
8182
8283
## 5. 更新:
83-
- **20230721**: 允许调整进度条的显示参数, 进度条和日志同步(如果进度条平滑了则日志也平滑), 自动把tensor转到model.device上, 允许打印第一个batch来检查样本
84+
- **20230725**: 修复v0.1.0的bug,主要是进度条和log的标签平滑的问题
85+
- **20230724**: 允许调整进度条的显示参数, 进度条和日志同步(如果进度条平滑了则日志也平滑), 自动把tensor转到model.device上, 允许打印第一个batch来检查样本
8486
- **20230716**:增加auto_set_cuda_devices自动选择显卡,增加log_info,log_warn, log_error等小函数
8587
- **20230625**:增加EmailCallback和WandbCallback, 增加AccelerateTrainer和DeepSpeedTrainer, grad_accumulation_steps内算一个batch,修改Trainer中部分成员函数
8688
- **20230517**:Checkpoint Callback增加保存scheduler, save_weights可自行创建目录,Logger, Tensorboard模块加入lr, 修改predict和add_trainer

examples/tutorials_mnist.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torchvision
55
from torch4keras.model import BaseModel, Trainer
66
from torch4keras.snippets import seed_everything
7-
from torch4keras.callbacks import Checkpoint, Evaluator, EarlyStopping, Summary, Logger, EmailCallback, WandbCallback
7+
from torch4keras.callbacks import Checkpoint, Evaluator, EarlyStopping, Summary, Logger, EmailCallback, WandbCallback, Tensorboard
88
from transformers.optimization import get_linear_schedule_with_warmup
99
from torch.utils.data import TensorDataset, DataLoader
1010
from tqdm import tqdm
@@ -54,7 +54,7 @@
5454
model = Trainer(net.to(device))
5555
optimizer = optim.Adam(net.parameters())
5656
scheduler = get_linear_schedule_with_warmup(optimizer, steps_per_epoch, steps_per_epoch*epochs)
57-
model.compile(optimizer=optimizer, scheduler=scheduler, loss=nn.CrossEntropyLoss(), metrics=['acc'], bar='tqdm')
57+
model.compile(optimizer=optimizer, scheduler=scheduler, loss=nn.CrossEntropyLoss(), metrics=['acc'])
5858

5959
class MyEvaluator(Evaluator):
6060
# 重构评价函数
@@ -78,7 +78,9 @@ def evaluate(self):
7878
scheduler_path='./ckpt/{epoch}/scheduler_{epoch}_{test_acc:.5f}.pt',
7979
steps_params_path='./ckpt/{epoch}/steps_params_{epoch}_{test_acc:.5f}.pt')
8080
early_stop = EarlyStopping(monitor='test_acc', verbose=1)
81-
logger = Logger('./ckpt/log.log', interval=100)
82-
email = EmailCallback(receivers='[email protected]')
83-
wandb = WandbCallback(save_code=True)
84-
hist = model.fit(train_dataloader, steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[Summary(), evaluator, logger, ckpt, early_stop])
81+
logger = Logger('./ckpt/log.log', interval=100) # log文件
82+
ts_board = Tensorboard('./ckpt/tensorboard', method='step', interval=100) # tensorboard
83+
email = EmailCallback(receivers='[email protected]') # 发送邮件
84+
wandb = WandbCallback(save_code=True) # wandb
85+
hist = model.fit(train_dataloader, steps_per_epoch=steps_per_epoch, epochs=epochs,
86+
callbacks=[Summary(), evaluator, logger, ts_board, ckpt, early_stop])

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name='torch4keras',
10-
version='v0.1.0',
10+
version='v0.1.0.post2',
1111
description='Use torch like keras',
1212
long_description=long_description,
1313
long_description_content_type="text/markdown",

torch4keras/callbacks.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def update(self, current, values=None):
6262
if k not in self.stateful_metrics:
6363
if k not in self._values:
6464
self._values[k] = [v * (current - self._seen_so_far), current - self._seen_so_far]
65-
elif (self.smooth_interval is not None) and (current % self.smooth_interval == 0):
65+
elif (self.smooth_interval is not None) and (current % self.smooth_interval == 1):
6666
# 如果定义了累积smooth_interval,则需要重新累计
6767
self._values[k] = [v, 1]
6868
else:
@@ -435,7 +435,7 @@ def smooth_values(self, current, values=None):
435435
if k not in self.stateful_metrics:
436436
if k not in self._values:
437437
self._values[k] = [v * (current - self._seen_so_far), current - self._seen_so_far]
438-
elif (self.smooth_interval is not None) and (current % self.smooth_interval == 0):
438+
elif (self.smooth_interval is not None) and (current % self.smooth_interval == 1):
439439
# 如果定义了累积smooth_interval,则需要重新累计
440440
self._values[k] = [v, 1]
441441
else:
@@ -887,15 +887,13 @@ class Tensorboard(Callback):
887887
:param method: str, 控制是按照epoch还是step来计算,默认为'epoch', 可选{'step', 'epoch'}
888888
:param interval: int, 保存tensorboard的间隔
889889
:param prefix: str, tensorboard分栏的前缀,默认为'train'
890-
:param on_epoch_end_scalar_epoch: bool, epoch结束后是横轴是按照epoch还是global_step来记录
891890
'''
892-
def __init__(self, log_dir, method='epoch', interval=10, prefix='train', on_epoch_end_scalar_epoch=True, **kwargs):
891+
def __init__(self, log_dir, method='epoch', interval=10, prefix='train', **kwargs):
893892
super(Tensorboard, self).__init__(**kwargs)
894893
assert method in {'step', 'epoch'}, 'Args `method` only support `step` or `epoch`'
895894
self.method = method
896895
self.interval = interval
897896
self.prefix = prefix+'/' if len(prefix.strip()) > 0 else '' # 控制默认的前缀,用于区分栏目
898-
self.on_epoch_end_scalar_epoch = on_epoch_end_scalar_epoch # 控制on_epoch_end记录的是epoch还是global_step
899897

900898
from tensorboardX import SummaryWriter
901899
os.makedirs(log_dir, exist_ok=True)
@@ -904,8 +902,7 @@ def __init__(self, log_dir, method='epoch', interval=10, prefix='train', on_epoc
904902
def on_epoch_end(self, global_step, epoch, logs=None):
905903
if self.method == 'epoch':
906904
# 默认记录的是epoch
907-
log_step = epoch+1 if self.on_epoch_end_scalar_epoch else global_step+1
908-
self.process(log_step, logs)
905+
self.process(epoch+1, logs)
909906

910907
def on_batch_end(self, global_step, local_step, logs=None):
911908
# 默认记录的是global_step
@@ -998,22 +995,22 @@ def __init__(self, receivers, subject='', method='epoch', interval=10, mail_host
998995

999996
def on_epoch_end(self, global_step, epoch, logs=None):
1000997
if self.method == 'epoch':
1001-
msg = json.dumps({k:f'{v:.5f}' for k,v in logs.items() if k!='size'}, indent=2, ensure_ascii=False)
998+
msg = json.dumps({k:f'{v:.5f}' for k,v in logs.items() if k not in SKIP_METRICS}, indent=2, ensure_ascii=False)
1002999
subject = f'[INFO] Epoch {epoch+1} performance'
10031000
if self.subject != '':
10041001
subject = self.subject + ' | ' + subject
10051002
self._email(subject, msg)
10061003

10071004
def on_batch_end(self, global_step, local_step, logs=None):
10081005
if (self.method == 'step') and ((global_step+1) % self.interval == 0):
1009-
msg = json.dumps({k:f'{v:.5f}' for k,v in logs.items() if k!='size'}, indent=2, ensure_ascii=False)
1006+
msg = json.dumps({k:f'{v:.5f}' for k,v in logs.items() if k not in SKIP_METRICS}, indent=2, ensure_ascii=False)
10101007
subject = f'[INFO] Step {global_step} performance'
10111008
if self.subject != '':
10121009
subject = self.subject + ' | ' + subject
10131010
self._email(subject, msg)
10141011

10151012
def on_train_end(self, logs=None):
1016-
msg = json.dumps({k:f'{v:.5f}' for k,v in logs.items() if k!='size'}, indent=2, ensure_ascii=False)
1013+
msg = json.dumps({k:f'{v:.5f}' for k,v in logs.items() if k not in SKIP_METRICS}, indent=2, ensure_ascii=False)
10171014
subject = f'[INFO] Finish training'
10181015
if self.subject != '':
10191016
subject = self.subject + ' | ' + subject

torch4keras/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def compile(self, loss, optimizer, scheduler=None, clip_grad_norm=None, mixed_pr
4545
:param metrics: str/List[str]/dict, 训练过程中需要打印的指标, loss相关指标默认会打印, 目前支持accuracy, 也支持自定义metric,形式为{key: func}
4646
:param grad_accumulation_steps: int, 梯度累积步数,默认为1
4747
:param bar: str, 使用进度条的种类,从kwargs中解析,默认为keras, 可选keras, tqdm, progressbar2
48-
:param progbar_config: 进度条的配置,如果使用指标平滑会更新到后续其他callbacks中(比如Logger),实现进度条显示和日志会保持一致
48+
:param progbar_config: 进度条的配置,默认是对整个epoch计算均值指标
4949
bar: str, 默认为keras
5050
stateful_metrics: List[str], 表示不使用指标平滑仅进行状态记录的metric,指标抖动会更加明显,默认为None表示使用指标平滑
5151
smooth_interval: int, 表示指标平滑时候的累计步数,默认为None表示对整个epoch进行平滑

torch4keras/snippets.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def send_email(receivers, subject, msg="", mail_host=None, mail_user=None, mail_
246246
smtpObj.login(mail_user, mail_pwd) # 登录到服务器
247247
smtpObj.sendmail(sender, receivers, message.as_string()) # 发送
248248
smtpObj.quit() # 退出
249-
print('[INFO] Send email success')
249+
log_info('Send email success')
250250
except smtplib.SMTPException as e:
251251
log_error('Send email error : '+str(e))
252252
return str(e)
@@ -357,7 +357,7 @@ def print_trainable_parameters(module):
357357
all_param += num_params
358358
if param.requires_grad:
359359
trainable_params += num_params
360-
print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
360+
log_info(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
361361

362362

363363
def get_parameter_device(parameter):

0 commit comments

Comments
 (0)