Hello, and thank you for your work — I have reproduced your project and obtained good results.

One thing puzzles me, though. When I train QKFormer with the training loop I usually use (below),
```python
for epoch in range(1, config['epochs'] + 1):
    epoch_start_time = time.time()  # timer for entire epoch
    model.train()
    total_loss_epoch = 0
    epoch_iter = 0
    for iteration, data in enumerate(train_loader):
        img = data[0].cuda()
        label = data[1].float().cuda()
        with suppress():
            pre_label = model(img).squeeze(1).float()
            loss_dict = cal_loss(pre_label, label)
            loss_total = loss_dict['total']
        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()
```
I get the following error:
```
Traceback (most recent call last):
  File "/opt/data/private/wzz/egg/train.py", line 91, in <module>
    loss_total.backward()
  File "/root/anaconda3/envs/edformer/lib/python3.10/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/anaconda3/envs/edformer/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```
I suspect this happens because some operation in the SNN activation causes the backward graph to be traversed a second time, but after carefully comparing our two training scripts I still cannot find any key difference.
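To make that guess concrete: a toy module that keeps its membrane potential across forward calls (the names below are hypothetical, purely for illustration — not taken from QKFormer) reproduces exactly this traceback.

```python
import torch
import torch.nn as nn

class ToyNeuron(nn.Module):
    """Hypothetical stand-in for a spiking neuron that retains its
    membrane potential across forward calls."""
    def __init__(self):
        super().__init__()
        self.mem = None

    def forward(self, x):
        if self.mem is None:
            self.mem = torch.zeros_like(x)
        # The retained `self.mem` still points into the previous
        # iteration's autograd graph.
        self.mem = self.mem + x
        return self.mem

linear = nn.Linear(4, 4)
neuron = ToyNeuron()

# Iteration 1: backward() frees the graph, as usual.
loss = neuron(linear(torch.randn(2, 4))).sum()
loss.backward()

# Iteration 2: backward() follows `neuron.mem` back into the freed
# graph of iteration 1 and raises the same RuntimeError.
try:
    loss = neuron(linear(torch.randn(2, 4))).sum()
    loss.backward()
except RuntimeError as e:
    print(type(e).__name__, "-", e)
```

Could QKFormer's neurons be holding state like this between iterations?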
I have also ruled out a problem on the network side: when I run the following code in model.py, the output is completely normal.
```python
if __name__ == '__main__':
    input = torch.randn(2, 3, 256, 256).cuda()
    model = spiking_transformer(
        img_size_h=256, img_size_w=256, patch_size=16, in_channels=3, num_classes=2,
        embed_dims=256, num_heads=4, mlp_ratios=2, qkv_bias=False, qk_scale=None,
        drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
        depths=3, sr_ratios=1, T=4, pretrained_cfg=None,
    ).cuda()
    output = model(input)
    print(output)

    from torchinfo import summary
    summary(model, input_size=(2, 3, 256, 256))
```
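Although, now that I think about it, this test may pass precisely because a single forward pass (and even a single backward) never touches a stale graph — the error can only appear from the second backward() onward. With a toy stateful module (again, hypothetical names, not QKFormer's), clearing the state once per iteration makes the loop run cleanly:

```python
import torch
import torch.nn as nn

class ToyNeuron(nn.Module):
    """Hypothetical stateful neuron; `reset_state` is my own name."""
    def __init__(self):
        super().__init__()
        self.mem = None

    def forward(self, x):
        if self.mem is None:
            self.mem = torch.zeros_like(x)
        self.mem = self.mem + x
        return self.mem

    def reset_state(self):
        self.mem = None  # drop the old graph entirely

linear = nn.Linear(4, 4)
neuron = ToyNeuron()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

for step in range(3):
    img = torch.randn(2, 4)
    loss = neuron(linear(img)).sum()
    optimizer.zero_grad()
    loss.backward()       # never hits the freed-graph error...
    optimizer.step()
    neuron.reset_state()  # ...because the state is cleared each step
print("3 steps completed without RuntimeError")
```

An alternative that keeps the state's value but cuts the graph would be `self.mem = self.mem.detach() + x`. Is a per-iteration reset like this (e.g. something along the lines of spikingjelly's `functional.reset_net`, if that is the library involved — I have not confirmed it) expected in the QKFormer training loop?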
I would sincerely appreciate your advice.