[Bug] distill for yolox_s failed #184

Closed
@tanghy2016

Describe the bug

The original backbone of yolox_s is CSPDarknet. I replaced it with ResNet18 in the student model and kept CSPDarknet in the teacher. The full configuration is as follows:

algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMDetArchitecture',
        model=dict(
            type='mmdet.YOLOX',
            input_size=(640, 640),
            random_size_range=(15, 25),
            random_size_interval=10,
            backbone=dict(
                type='ResNet',
                depth=18,
                num_stages=4,
                out_indices=(1, 2, 3),
                norm_cfg=dict(type='BN', requires_grad=True),
                norm_eval=True,
                style='pytorch',
                init_cfg=dict(
                    type='Pretrained', checkpoint='torchvision://resnet18')),
            neck=dict(
                type='YOLOXPAFPN',
                in_channels=[128, 256, 512],
                out_channels=128,
                num_csp_blocks=1),
            bbox_head=dict(
                type='YOLOXHead',
                num_classes=1,
                in_channels=128,
                feat_channels=128),
            train_cfg=dict(
                assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
            test_cfg=dict(
                score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))),
    with_student_loss=True,
    with_teacher_loss=False,
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=dict(
            type='mmdet.YOLOX',
            init_cfg=dict(
                type='Pretrained',
                checkpoint='/root/minio_model/1518897383465840642/epoch_19.pth'
            ),
            input_size=(640, 640),
            random_size_range=(15, 25),
            random_size_interval=10,
            backbone=dict(
                type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
            neck=dict(
                type='YOLOXPAFPN',
                in_channels=[128, 256, 512],
                out_channels=128,
                num_csp_blocks=1),
            bbox_head=dict(
                type='YOLOXHead',
                num_classes=1,
                in_channels=128,
                feat_channels=128),
            train_cfg=dict(
                assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
            test_cfg=dict(
                score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))),
        teacher_trainable=False,
        components=[
            dict(
                student_module='neck.out_convs.0.conv',
                teacher_module='neck.out_convs.0.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ]),
            dict(
                student_module='neck.out_convs.1.conv',
                teacher_module='neck.out_convs.1.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ]),
            dict(
                student_module='neck.out_convs.2.conv',
                teacher_module='neck.out_convs.2.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ])
        ]))

When training reaches epoch 2, the following error occurs:

2022-06-20 07:27:45,022 - mmdet - INFO - Epoch [2][2/8]	lr: 6.250e-04, eta: 0:05:27, time: 0.702, data_time: 0.108, memory: 6444, student.loss_cls: 0.4989, student.loss_bbox: 4.8508, student.loss_obj: 19.1068, distiller.loss_cwd_logits.0: 0.2968, loss: 24.7533
Traceback (most recent call last):
  File "MMCV/train.py", line 52, in <module>
    main()
  File "MMCV/train.py", line 46, in main
    train_distill(args)
  File "/root/project/xbrain-ai/XbrainAiModels/MMCV/tools/train_distill.py", line 39, in train_distill
    train_mmdet(args, cfg)
  File "/root/project/xbrain-ai/XbrainAiModels/MMCV/tools/train_mmdet.py", line 43, in train_mmdet
    train_mmdet_model(
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/apis/mmdet/train.py", line 206, in train_mmdet_model
    runner.run(data_loader, cfg.workflow)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 29, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "/opt/conda/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/models/algorithms/general_distill.py", line 49, in train_step
    distill_losses = self.distiller.compute_distill_loss(data)
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/models/distillers/single_teacher.py", line 240, in compute_distill_loss
    losses[loss_name] = loss_module(s_out, t_out)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/models/losses/cwd.py", line 41, in forward
    assert preds_S.shape[-2:] == preds_T.shape[-2:]
AssertionError

In the correct case, both preds_S and preds_T should have shape (batch_size, 3, 80, 80), (batch_size, 3, 40, 40), or (batch_size, 3, 20, 20). When the error occurs, however, my debugging shows a variety of other spatial sizes.
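
The arithmetic behind the varying sizes can be sketched as follows. This is only a guess at the cause, not a confirmed diagnosis: it assumes that mmdet's YOLOX multiplies a random integer drawn from random_size_range by 32 to pick a new input size every random_size_interval iterations, that the three neck outputs have strides 8, 16 and 32, and that the student and teacher draw their sizes independently. The helper names `global_iter` and `feature_sizes` are hypothetical; the numbers come from the config and log above.

```python
# Hypothetical helpers; constants below are assumptions about mmdet's YOLOX.
SIZE_MULTIPLIER = 32      # assumption: input side = 32 * random integer
STRIDES = (8, 16, 32)     # assumption: strides of the three neck outputs


def global_iter(epoch, iter_in_epoch, iters_per_epoch):
    """1-based global iteration for a log line such as 'Epoch [2][2/8]'."""
    return (epoch - 1) * iters_per_epoch + iter_in_epoch


def feature_sizes(size_factor):
    """Side length of each neck output for a square input of 32 * size_factor."""
    side = SIZE_MULTIPLIER * size_factor
    return tuple(side // stride for stride in STRIDES)


random_size_range = (15, 25)   # from the config above
random_size_interval = 10      # from the config above

# The last logged step before the crash is 'Epoch [2][2/8]'.
last_ok = global_iter(2, 2, 8)
print(last_ok, last_ok % random_size_interval == 0)   # 10 True

# Feature-map sides for the smallest, default and largest input size.
for factor in (random_size_range[0], 20, random_size_range[1]):
    print(SIZE_MULTIPLIER * factor, feature_sizes(factor))
# 480 (60, 30, 15)
# 640 (80, 40, 20)
# 800 (100, 50, 25)

# If the student draws 480 and the teacher 640, the spatial sizes differ,
# which is the condition the assertion in cwd.py rejects.
print(feature_sizes(15)[0] == feature_sizes(20)[0])   # False
```

Under these assumptions, the last successful step is global iteration 10, exactly one random_size_interval, so the crash would line up with the first resize. If that is the cause, setting random_size_range=(20, 20) in both the student and the teacher should pin both models to 640 x 640; this is an untested suggestion.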
