
RF-DETR 1.6.4 segmentation training on single GPU still fails on first optimizer step with fused AdamW dtype mismatch after clean install #959


Description

@Aasim-ComputerVision

Search before asking

  • I have searched the RF-DETR issues and found no similar bug report.

Bug

I am trying to train an instance segmentation model with RFDETRSegMedium on a single A100 GPU, using the official built-in training path in a clean environment.

After working through several environment and dependency issues, I got training to progress through:

  • package install
  • model init
  • dataset build
  • sanity validation
  • COCO metric backend setup

The run now consistently fails at the first optimizer step with:

RuntimeError: params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout

The traceback ends in fused AdamW:

File ".../torch/optim/adamw.py", line 679, in _fused_adamw
torch.fused_adamw(
RuntimeError: params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout
This is happening in a clean venv on a single GPU, so this is no longer just a DDP issue.
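As far as I can tell, the same RuntimeError can be reproduced outside RF-DETR with a standalone sketch, assuming the failure mode is that the AdamW state gets created in fp32 and the parameters are later cast to bf16 without rebuilding the optimizer (this is my guess at the mechanism, not something I verified in the RF-DETR source):

import torch

# Hypothetical standalone reproduction, not RF-DETR code: take one fused AdamW
# step while the parameter is fp32 (creating fp32 exp_avg / exp_avg_sq), then
# cast the parameter to bf16 the way a mixed-precision plugin might.
param = torch.nn.Parameter(torch.randn(8, 8, device="cuda"))
opt = torch.optim.AdamW([param], lr=1e-4, fused=True)

param.grad = torch.randn_like(param)
opt.step()  # creates fp32 optimizer state

param.data = param.data.to(torch.bfloat16)  # parameter dtype changes ...
param.grad = torch.randn_like(param)        # ... and so does its grad
opt.step()  # RuntimeError: params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout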

Environment

RF-DETR: 1.6.4
PyTorch: 2.4.1+cu124
Torchvision: 0.19.1+cu124
Torchaudio: 2.4.1+cu124
PyTorch Lightning: 2.6.1
GPU: NVIDIA A100-SXM4-80GB
Driver Version: 570.195.03
CUDA reported by nvidia-smi: 12.8
Python: 3.11
OS/container: Linux on RunPod

Minimal Reproducible Example

from rfdetr import RFDETRSegMedium

model = RFDETRSegMedium(
    num_classes=2,
    resolution=672,
    gradient_checkpointing=False,
)

model.train(
    dataset_dir="/workspace/SDMTP/dataset_root_",
    output_dir="/workspace/rf-detr/runs_seg_builtin_164",

    epochs=50,
    batch_size=4,
    grad_accum_steps=4,
    lr=1e-4,
    lr_encoder=1.5e-4,
    weight_decay=1e-4,

    amp=True,
    use_ema=True,

    mask_point_sample_ratio=16,
    mask_ce_loss_coef=7.0,
    mask_dice_loss_coef=7.0,

    eval_interval=1,
    eval_max_dets=500,
    log_per_class_metrics=False,
    tensorboard=True,

    early_stopping=True,
    early_stopping_patience=10,
    early_stopping_min_delta=0.001,
    early_stopping_use_ema=True,

    checkpoint_interval=5,
)

Additional

The run gets through model loading, dataset creation, and sanity validation, then fails at the first training optimizer step.

Relevant logs:

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
Precision bf16-mixed is not supported by the model summary. Estimated model size in MB will not be accurate. Using 32 bits instead.
...
`use_return_dict` is deprecated! Use `return_dict` instead!
...
File "/workspace/venv-rfdetr-clean/lib/python3.11/site-packages/torch/optim/adamw.py", line 679, in _fused_adamw
    torch._fused_adamw_(
RuntimeError: params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout

Expected behavior: built-in single-GPU segmentation training should start normally and complete optimizer steps without hitting the fused AdamW dtype/layout mismatch.

Important observations

  1. This is not just a DDP problem
    Earlier I was debugging DDP issues, but this issue reproduces on:

    • single GPU
    • clean environment
    • built-in training path
    • no distributed training

  2. Built-in training appears to resolve to bf16-mixed
    The log shows:

    Precision bf16-mixed is not supported by the model summary

    This suggests the built-in training path is using BF16 mixed precision on my setup. The optimizer crash appears to happen in that path.

  3. Clean environment was required to get this far
    Before reaching this optimizer-step failure, I had to resolve multiple environment/dependency issues. Once those were fixed, the optimizer bug became reproducible and consistent.

Hiccups encountered while getting to the actual bug

I am listing these because they significantly complicated diagnosis.

A. System-package blinker conflict

Using system Python and broad pip installs caused:

Cannot uninstall blinker 1.4
It is a distutils installed project...

This was solved only by switching to a clean virtual environment.
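For reference, the clean environment was just a standard venv, roughly:

python3.11 -m venv /workspace/venv-rfdetr-clean
source /workspace/venv-rfdetr-clean/bin/activate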

B. Torch/CUDA environment got silently polluted

At one point, unconstrained installs upgraded torch to a much newer CUDA build, causing errors like:

The NVIDIA driver on your system is too old (found version 12080)

This was misleading, because nvidia-smi showed a healthy modern driver. The actual problem was the Python environment being polluted with incompatible torch/CUDA wheels.

This was solved by creating a clean venv and pinning the following from the cu124 index:

torch==2.4.1
torchvision==0.19.1
torchaudio==2.4.1
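Concretely, the pinning amounted to roughly:

pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124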

C. Missing training dependencies were not obvious from the base install

After installing rfdetr, the built-in train path failed until I manually installed training/logging dependencies such as:

pytorch_lightning
torchmetrics
tensorboard
albumentations
pycocotools
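In the clean venv that was roughly:

pip install pytorch-lightning torchmetrics tensorboard albumentations pycocotools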

D. Missing faster-coco-eval

Even after that, training failed during sanity validation until I manually installed:

faster-coco-eval

The failure was:

ModuleNotFoundError: Backend faster_coco_eval in metric MeanAveragePrecision metric requires that faster-coco-eval is installed.

Only after resolving this did the run reach the actual optimizer-step failure.
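My guess (not verified against the RF-DETR source) is that the evaluation path builds torchmetrics' MeanAveragePrecision with the faster_coco_eval backend, which is imported lazily, so the base install succeeds but sanity validation fails. Something like:

from torchmetrics.detection import MeanAveragePrecision

# Assumed metric setup: the faster_coco_eval backend needs the optional
# faster-coco-eval package once the metric is constructed and used.
metric = MeanAveragePrecision(iou_type="segm", backend="faster_coco_eval")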

What seems to be happening

From the outside, this looks like:

  • built-in segmentation training chooses BF16 mixed precision
  • optimizer uses fused AdamW
  • fused AdamW then fails because params / grads / optimizer state are not aligned in dtype/device/layout

Since this now reproduces in a clean single-GPU setup, it looks like a built-in training bug or compatibility issue in the segmentation path, not just a user environment issue.
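If it helps whoever picks this up, here is a small diagnostic helper (hypothetical, not part of RF-DETR) that could be called right before the failing step to show which of params / grads / optimizer state disagrees in dtype or device:

def dump_adamw_state_dtypes(optimizer):
    # Print dtype/device of every param, its grad, and its AdamW state tensors,
    # so the entry that trips the fused kernel's same-dtype check is visible.
    for gi, group in enumerate(optimizer.param_groups):
        for p in group["params"]:
            state = optimizer.state.get(p, {})
            print(
                f"group={gi}",
                f"param={p.dtype}/{p.device}",
                f"grad={'none' if p.grad is None else p.grad.dtype}",
                f"exp_avg={state['exp_avg'].dtype if 'exp_avg' in state else 'none'}",
                f"exp_avg_sq={state['exp_avg_sq'].dtype if 'exp_avg_sq' in state else 'none'}",
            )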

Are you willing to submit a PR?

  • Yes, I'd like to help by submitting a PR!
