Commit 75e77fd

Tongjilibo authored and committed
add _argparse_forward()
1 parent 573bf43 commit 75e77fd

2 files changed: +10 -4 lines changed

README.md (+1 -1)
@@ -85,7 +85,7 @@ pip install git+https://github.com/Tongjilibo/torch4keras.git
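 - **v0.0.1**: 20221019 Initial version

 ## 5. Updates:
-- **20230907**: Added the from_pretrained and save_pretrained methods; added the log_warn_once method; member variables can be set in compile(); move_to_model_device now defaults to True; added JsonConfig
+- **20230907**: Added the from_pretrained and save_pretrained methods; added the log_warn_once method; member variables can be set in compile(); move_to_model_device now defaults to True; added JsonConfig; added _argparse_forward() so that downstream code can more easily subclass and override Trainer
 - **20230901**: compile() can be called without arguments; a warning is raised when interval values are inconsistent; removed some self.vars; adjusted the move_to_model_device logic; DDP resets the random seed every epoch; save_weights() and load_weights() can follow the `pretrained` format
 - **20230821**: Restructured the code and added trainer.py to ease downstream integration
 - **20230812**: Fixed DeepSpeedTrainer; fixed DDP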

torch4keras/trainer.py (+9 -3)
@@ -148,18 +148,24 @@ def _log_first_step(self, resume_step, train_X):
         print(colorful('[Label]: ', color='green'), + train_X)

     def _forward(self, *inputs, **input_kwargs):
+        '''Call the model's forward. Downstream subclasses can override this to choose which model's forward is used.
+        '''
+        return self._argparse_forward(self.unwrap_model(), *inputs, **input_kwargs)
+
+    @staticmethod
+    def _argparse_forward(model, *inputs, **input_kwargs):
         '''Call the model's forward.
         If a network module was passed in, call that module's forward; if the trainer is used via inheritance, call its own forward.
         '''
         if (len(inputs)==1) and isinstance(inputs[0], (tuple,list)): # guard against ([]) nesting
             inputs = inputs[0]

         if isinstance(inputs, torch.Tensor): # do not unpack a tensor
-            return self.unwrap_model().forward(inputs, **input_kwargs)
+            return model.forward(inputs, **input_kwargs)
         elif isinstance(inputs, (tuple, list)):
-            return self.unwrap_model().forward(*inputs, **input_kwargs)
+            return model.forward(*inputs, **input_kwargs)
         else:
-            return self.unwrap_model().forward(inputs, **input_kwargs)
+            return model.forward(inputs, **input_kwargs)

     def train_step(self, train_X, train_y):
         ''' Perform a training step on a batch of inputs. '''
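
The split between `_forward` and the static `_argparse_forward` can be illustrated with a small sketch. It is not the torch4keras implementation: `MiniTrainer`, `TwoModelTrainer`, `Adder`, and `Multiplier` are hypothetical stand-ins, and the `torch.Tensor` check is left out so that the example runs without torch. It only aims to show how a subclass might override `_forward` to pick a different model while reusing the shared argument handling.

```python
class MiniTrainer:
    """Hypothetical stand-in for the Trainer in this diff (no torch dependency)."""

    def __init__(self, module):
        self.module = module

    def unwrap_model(self):
        # In this sketch the module is never wrapped, so return it unchanged.
        return self.module

    def _forward(self, *inputs, **input_kwargs):
        # Default behaviour: forward through the trainer's own model.
        return self._argparse_forward(self.unwrap_model(), *inputs, **input_kwargs)

    @staticmethod
    def _argparse_forward(model, *inputs, **input_kwargs):
        # A single tuple/list argument is unpacked, so f([a, b]) acts like f(a, b).
        if len(inputs) == 1 and isinstance(inputs[0], (tuple, list)):
            inputs = inputs[0]
        return model.forward(*inputs, **input_kwargs)


class Adder:
    def forward(self, a, b, scale=1):
        return (a + b) * scale


class Multiplier:
    def forward(self, a, b, scale=1):
        return a * b * scale


class TwoModelTrainer(MiniTrainer):
    """Hypothetical downstream subclass that routes forward calls to a second model."""

    def __init__(self, module, other):
        super().__init__(module)
        self.other = other

    def _forward(self, *inputs, **input_kwargs):
        # Reuse the shared argument handling, but choose a different model.
        return self._argparse_forward(self.other, *inputs, **input_kwargs)


base = MiniTrainer(Adder())
custom = TwoModelTrainer(Adder(), Multiplier())

result_list = base._forward([2, 3])          # list is unpacked into (2, 3)
result_args = base._forward(2, 3, scale=10)  # positional and keyword args pass through
result_other = custom._forward([2, 3])       # routed to Multiplier instead of Adder
```

Because `_argparse_forward` takes the model as an explicit argument, the subclass does not need to duplicate the unpacking logic; it only decides which model receives the call.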
