-
Notifications
You must be signed in to change notification settings - Fork 29
Open
Description
作者您好!
我在训练STANet时出现了在计算损失函数阶段tar和pred的shape不匹配问题(我没有对核心代码做过修改),报错信息如下:
然后我打断点进行排查:
在 src/impl/trainers/cd_trainer.py 的 train_epoch 方法中定位到 pred = self._process_model_out(out) 这行代码(见下图),此时 out 的 shape 为 [batch, 64, 256, 256]
紧接着代码跳到 src/impl/trainers/cd_trainer_metric.py 的 _process_model_out 方法(见下图),此时执行 out.squeeze(1) ,但是 out 的 shape 为 [batch, 64, 256, 256] ,所以 out.squeeze(1) 不会对 out 做任何操作,out 的 shape 仍为 [batch, 64, 256, 256]
然后代码跳到 src/impl/trainers/cd_trainer_metric.py 中计算损失函数(见下图),此时 tar 的 shape 为 [batch, 256 , 256] ,但 pred 的 shape 为 [batch, 64, 256, 256] 所以直接进行乘法运算就出现了错误
想请教下作者这个问题该如何解决,是否需要对代码做些改动?
Metadata
Metadata
Assignees
Labels
No labels