@@ -148,18 +148,24 @@ def _log_first_step(self, resume_step, train_X):
148
148
print (colorful ('[Label]: ' , color = 'green' ), + train_X )
149
149
150
150
def _forward (self , * inputs , ** input_kwargs ):
151
+ '''调用模型的forward,方便下游继承的时候可以自定义使用哪个模型的forward
152
+ '''
153
+ return self ._argparse_forward (self .unwrap_model (), * inputs , ** input_kwargs )
154
+
155
+ @staticmethod
156
+ def _argparse_forward (model , * inputs , ** input_kwargs ):
151
157
'''调用模型的forward
152
158
如果传入了网络结构module,则调用module的forward;如果是继承方式,则调用自身的forward
153
159
'''
154
160
if (len (inputs )== 1 ) and isinstance (inputs [0 ], (tuple ,list )): # 防止([])嵌套
155
161
inputs = inputs [0 ]
156
162
157
163
if isinstance (inputs , torch .Tensor ): # tensor不展开
158
- return self . unwrap_model () .forward (inputs , ** input_kwargs )
164
+ return model .forward (inputs , ** input_kwargs )
159
165
elif isinstance (inputs , (tuple , list )):
160
- return self . unwrap_model () .forward (* inputs , ** input_kwargs )
166
+ return model .forward (* inputs , ** input_kwargs )
161
167
else :
162
- return self . unwrap_model () .forward (inputs , ** input_kwargs )
168
+ return model .forward (inputs , ** input_kwargs )
163
169
164
170
def train_step (self , train_X , train_y ):
165
171
''' Perform a training step on a batch of inputs. '''
0 commit comments