Skip to content

Commit 3afee02

Browse files
committed
fix clip_grad_norm
1 parent 0e3da78 commit 3afee02

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ pip install git+https://github.com/Tongjilibo/torch4keras.git
6767
## 4. 版本历史
6868
|更新日期| 版本 | 版本说明 |
6969
|------| ----------------- |----------- |
70-
|20231207| 0.1.6 |监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题|
71-
|20230928| 0.1.5 |进度条中显示已经训练的时间|
70+
|20231208|v0.1.6.post2 |监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题; 修复clip_grad_norm的bug|
71+
|20230928|v0.1.5 |进度条中显示已经训练的时间|
7272
|20230912|v0.1.4.post2|History增加plot()方法, 增加add_module()方法,修复0.1.4的_argparse_forward的bug, 增加loss2metrics方法|
7373
7474
[更多版本](https://github.com/Tongjilibo/torch4keras/blob/master/docs/Update.md)
7575
7676
## 5. 更新历史:
77-
- **20231207**: 监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题
77+
- **20231208**: 监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题; 修复clip_grad_norm的bug
7878
- **20230928**: 进度条中显示已经训练的时间
7979
- **20230912**: History增加plot()方法, 增加add_module()方法,修复0.1.4的_argparse_forward的bug, 增加loss2metrics方法
8080

docs/Update.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
|更新日期| 版本 | 版本说明 |
44
|------| ----------------- |----------- |
5-
|20231207| 0.1.6 |监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题|
6-
|20230928| 0.1.5 |进度条中显示已经训练的时间|
5+
|20231207|v0.1.6 |监控fit过程,有报错则发送邮件提醒; 解决torch2.0的compile冲突问题|
6+
|20230928|v0.1.5 |进度条中显示已经训练的时间|
77
|20230912|v0.1.4.post2|History增加plot()方法, 增加add_module()方法,修复0.1.4的_argparse_forward的bug, 增加loss2metrics方法|
88
|20230909|v0.1.4|增加from_pretrained和save_pretrained方法,增加log_warn_once方法,compile()中可设置成员变量,默认move_to_model_device设置为True, 增加JsonConfig,增加_argparse_forward()方便下游继承改写Trainer|
99
|20230901|v0.1.3|compile()可不传参,interval不一致报warning, 去除部分self.vars, 调整move_to_model_device逻辑,DDP每个epoch重新设置随机数,save_weights()和load_weights()可以按照`pretrained`格式|

examples/tutorials_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
model = Trainer(net.to(device))
5555
optimizer = optim.Adam(net.parameters())
5656
scheduler = get_linear_schedule_with_warmup(optimizer, steps_per_epoch, steps_per_epoch*epochs)
57-
model.compile(optimizer=optimizer, scheduler=scheduler, loss=nn.CrossEntropyLoss(), metrics=['acc'])
57+
model.compile(optimizer=optimizer, scheduler=scheduler, loss=nn.CrossEntropyLoss(), metrics=['acc'], clip_grad_norm=1.0)
5858

5959
class MyEvaluator(Evaluator):
6060
# 重构评价函数

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name='torch4keras',
10-
version='v0.1.6',
10+
version='v0.1.6.post2',
1111
description='Use torch like keras',
1212
long_description=long_description,
1313
long_description_content_type="text/markdown",

torch4keras/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,13 @@ def step(self):
217217
if self.mixed_precision:
218218
self.scaler.unscale_(self.optimizer)
219219
if self.clip_grad_norm is not None: # 梯度裁剪
220-
torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_norm)
220+
torch.nn.utils.clip_grad_norm_(self.unwrap_model().parameters(), self.clip_grad_norm)
221221
self.scaler.step(self.optimizer)
222222
self.scaler.update()
223223
skip_scheduler = self.scaler.get_scale() != self.scale_before_step
224224
else:
225225
if self.clip_grad_norm is not None: # 梯度裁剪
226-
torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_norm)
226+
torch.nn.utils.clip_grad_norm_(self.unwrap_model().parameters(), self.clip_grad_norm)
227227
self.optimizer.step()
228228

229229
self.optimizer.zero_grad() # 清梯度

0 commit comments

Comments
 (0)