
about train #10

@wzz-z

Hello, and thank you very much for your work. I have reproduced your project and obtained good results.
One thing puzzles me, though: when I train QKFormer with the training loop I am used to (shown below),

    for epoch in range(1,config['epochs']+1):
        epoch_start_time = time.time()  # timer for entire epoch
        model.train()
        total_loss_epoch = 0
        epoch_iter = 0
        for iteration, data in enumerate(train_loader):
            img = data[0].cuda()  
            label = data[1].float().cuda()
            with suppress():
                pre_label = model(img).squeeze(1).float() 
                loss_dict = cal_loss(pre_label, label)
                loss_total = loss_dict['total']
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step() 

the following error is raised:

Traceback (most recent call last):
  File "/opt/data/private/wzz/egg/train.py", line 91, in <module>
    loss_total.backward()
  File "/root/anaconda3/envs/edformer/lib/python3.10/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/anaconda3/envs/edformer/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I suspect this may be because some operation in the SNN activations causes the backward graph to be traversed a second time, but after carefully comparing our two training loops I still cannot find any key difference.
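To illustrate that suspicion, here is a minimal, self-contained sketch (not QKFormer's actual code): the hypothetical `StatefulNeuron` below keeps its membrane potential across forward calls, so the second iteration's backward reaches into the first iteration's already-freed graph and raises exactly this error, while clearing (or detaching) the state each iteration removes the link.

    import torch
    import torch.nn as nn

    class StatefulNeuron(nn.Module):
        """Hypothetical stand-in for an SNN neuron that keeps its membrane potential."""
        def __init__(self):
            super().__init__()
            self.v = 0.                      # state kept across forward calls

        def forward(self, x):
            self.v = self.v + torch.tanh(x)  # self.v now references this iteration's graph
            return self.v

        def reset(self):
            self.v = 0.                      # dropping the state breaks the cross-iteration link

    model = nn.Sequential(nn.Linear(4, 4), StatefulNeuron())
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for step in range(2):
        out = model(torch.randn(8, 4))
        loss = out.sum()
        optimizer.zero_grad()
        loss.backward()                      # step 2 raises "Trying to backward through the
        optimizer.step()                     # graph a second time" unless the state is cleared
        # model[1].reset()                   # uncomment (or detach self.v) to avoid the error

If QKFormer's neurons are SpikingJelly modules (an assumption on my part), calling SpikingJelly's functional.reset_net(model) after each optimizer.step() would play the same role as reset() above, which might explain the difference between the two training loops.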
I have also ruled out a problem with the network itself: when I run the following code in model.py, the output is normal.

if __name__ == '__main__':
    input = torch.randn(2, 3, 256, 256).cuda()
    model = spiking_transformer(
        img_size_h=256, img_size_w=256, patch_size=16, in_channels=3, num_classes=2,
        embed_dims=256, num_heads=4, mlp_ratios=2, qkv_bias=False, qk_scale=None,
        drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
        depths=3, sr_ratios=1, T=4, pretrained_cfg=None,
    ).cuda()
    output = model(input)
    print(output)
    from torchinfo import summary
    summary(model, input_size=(2, 3, 256, 256))

I would greatly appreciate your advice.
