
Segmentation DDP Failure multiple trials #698

@Aasim-ComputerVision

Description


Search before asking

  • I have searched the RF-DETR issues and found no similar bug report.

Bug

Training won't start with DDP enabled for rf-detr segmentation
DDP Trials Attempted

Below is every configuration I tried.

Trial 1 - Basic DDP (torch.distributed.launch)

python -m torch.distributed.launch --nproc_per_node=2 --use_env train_seg_1024.py

Result:
RuntimeError: Expected to mark a variable ready only once.
Parameter segmentation_head.bias has been marked as ready twice.

Crash occurs during:
train_one_epoch()
scaler.scale(losses).backward()

Trial 2 - Using torchrun (recommended method)

torchrun --nproc_per_node=2 train_seg_1024_ddp.py

Same error:
segmentation_head.bias marked ready twice

Trial 3 - Adjusted effective batch size

Tried preserving effective batch size:

    GPUs = 2
    batch_size = 4
    grad_accum = 2
    Effective = 16
    
    GPUs = 2
    batch_size = 2
    grad_accum = 4
    Effective = 16

Same crash.
This is not related to gradient accumulation.
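
For reference, the arithmetic behind both configurations (a trivial sketch, assuming the usual convention that the effective batch is per-GPU batch size × number of GPUs × gradient accumulation steps):

def effective_batch_size(num_gpus: int, batch_size: int, grad_accum: int) -> int:
    # per-GPU batch * world size * accumulation steps
    return num_gpus * batch_size * grad_accum

assert effective_batch_size(2, 4, 2) == 16
assert effective_batch_size(2, 2, 4) == 16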

Trial 4 - Disabled evaluation completely

Changed:
run_test=False
fp16_eval=False
early_stopping=False

Still crashes.
So this is not evaluation-related.

Trial 5 - AMP On / Off

Tried both:
amp=True
amp=False

No change.
Still fails at backward pass.

Trial 6 - Reduced num_workers

Tried:
num_workers=6
num_workers=4
num_workers=2

No difference.
This is not a DataLoader issue.

Trial 7 - Explicit local_rank GPU binding

Added:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

Still crashes.
So this is not a rank-to-GPU mapping issue.

Trial 8 - Disabled wandb and tensorboard

wandb=False
tensorboard=False

No change.

Trial 9 - Removed fp16_eval

Set:
fp16_eval=False

Still crashes.

Exact Failure
Occurs during training backward under DDP:
RuntimeError: Expected to mark a variable ready only once.
Parameter at index 280 with name segmentation_head.bias has been marked as ready twice.

Observations

Error occurs before evaluation starts.
Always references segmentation_head.bias.
Only happens in segmentation model.
Detection-only model does not show this behavior.
Happens during backward pass inside DDP.

Interpretation

This PyTorch DDP error occurs when:
A parameter participates in multiple backward graphs in one iteration.
Re-entrant backward occurs.
A module parameter is reused outside forward.
DDP expects a static graph but the model graph is dynamic.
Since only segmentation_head.bias triggers this, it strongly suggests that the RF-DETR segmentation head is not DDP-safe in version 1.4.3.

Likely Root Cause
Segmentation adds:
Mask head
Auxiliary mask losses
Multi-branch backward paths
This likely causes multiple autograd hooks to fire on the same parameter in a single iteration (a toy illustration of that pattern is sketched below).
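
To make the suspected mechanism concrete, here is a toy sketch (not rf-detr code; the module is purely hypothetical) of the reentrant-backward pattern PyTorch's error message describes, where the same parameters are checkpointed twice in one forward pass. Under DDP this should trigger the same "marked as ready twice" error, even in a single-process group:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Linear(8, 8)

    def forward(self, x):
        # the same block is wrapped in two reentrant checkpoint calls, so its
        # parameters take part in two reentrant backward passes per iteration
        h = checkpoint(self.block, x, use_reentrant=True)
        return checkpoint(self.block, h, use_reentrant=True)

def main():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    model = DDP(Net())
    out = model(torch.randn(4, 8))
    # expected to raise: "Expected to mark a variable ready only once. ..."
    out.sum().backward()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

If the segmentation head ends up in an equivalent situation (its parameters contributing to more than one backward path per iteration), that would explain why only segmentation_head.bias is reported.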

PyTorch 2.x supports:

DistributedDataParallel(..., static_graph=True)
or:
model._set_static_graph()

Both assume the graph does not change across iterations.
Currently neither is enabled internally (a usage sketch follows below).
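
For illustration, a minimal sketch of how either form would be applied if the DDP wrapping were done by hand (rf-detr wraps the model internally, so core_model below is just a placeholder module, not the library's actual model; launch with torchrun):

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun --nproc_per_node=2 this_script.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

core_model = nn.Linear(16, 16).cuda(local_rank)  # placeholder for the real model

ddp_model = DDP(
    core_model,
    device_ids=[local_rank],
    static_graph=True,  # tell the reducer the autograd graph is fixed across iterations
)
# Equivalent post-construction form:
# ddp_model._set_static_graph()

dist.destroy_process_group()

Both forms tell DDP that the set of parameters used and the order in which their hooks fire will not change between iterations, which is what lets it handle cases like activation checkpointing, where a parameter's hook fires more than once per iteration.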

Questions
Is RF-DETR Segmentation officially supported under DDP in 1.4.3?
Should static_graph=True be enabled internally for segmentation?
Is there a known workaround for segmentation DDP training?

Environment

RF-DETR Version: 1.4.3
Model: RFDETRSegMedium (Segmentation)
GPUs: 2 × RTX 5090
CUDA: 12.7
OS: Ubuntu 22.04
Launch method: torchrun --nproc_per_node=2

Dataset Details
Train images: 15,325
Val/Test images: 6,640
Training resolution: 432

Minimal Reproducible Example

from rfdetr import RFDETRSegMedium

model = RFDETRSegMedium()

model.train(
    dataset_dir="/workspace/SDMTP/dataset_root",
    output_dir="/workspace/rf-detr/runs_seg_M1_uni_crops_dataset_Medium_ddp/",
    epochs=30,
    batch_size=2,
    grad_accum_steps=4,
    amp=True,
    run_test=False,
    fp16_eval=False,
    eval_max_dets=100,
    num_workers=6,
)

torchrun --nproc_per_node=2 train_seg_1024_ddp.py

Additional

Error occurs early in the first training epoch, before evaluation starts
Always references segmentation_head.bias
Happens in segmentation model
Happens during backward pass inside DDP

Exact error log from one of the runs:

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop. 2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
[rank0]: Parameter at index 280 with name segmentation_head.bias has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/rf-detr/train_seg_ddp.py", line 44, in <module>
[rank1]:     model.train(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/rfdetr/detr.py", line 97, in train
[rank1]:     self.train_from_config(config, **kwargs)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/rfdetr/detr.py", line 238, in train_from_config
[rank1]:     self.model.train(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/rfdetr/main.py", line 360, in train
[rank1]:     train_stats = train_one_epoch(
[rank1]: Parameter at index 280 with name segmentation_head.bias has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.
[rank0]:[W217 07:36:50.984191104 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W0217 07:36:50.975000 2242 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 2308 closing signal SIGTERM
E0217 07:36:51.341000 2242 torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 0 (pid: 2307) of binary: /usr/local/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 143, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 277, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train_seg_ddp.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-02-17_07:36:50
  host      : b5381081e05b
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 2307)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Are you willing to submit a PR?

  • Yes, I'd like to help by submitting a PR!
