Issue in matcher - ValueError: matrix contains invalid numeric entries #784

@svengoluza

Hi, I get the following error while training a model:

Epoch: [10/100]:  51%|█████████████████████████████████████████████████████████                                                      | 169/329 [09:41<09:10,  3.44s/it, lr=0.000100, class_loss=4.04, box_loss=0.09, loss=35.33, max_mem=24481 MB]
Traceback (most recent call last):                                                                                                                           
  File "<frozen runpy>", line 198, in _run_module_as_main                                                                                                    
  File "<frozen runpy>", line 88, in _run_code                                                                                                               
  File "/home/ubuntu/sgoluza/instance_segmentation/src/training/trainer.py", line 158, in <module>                                                           
    main()                                                                                                                                                   
  File "/home/ubuntu/sgoluza/instance_segmentation/src/training/trainer.py", line 143, in main                                                               
    checkpoint_path = train(config=config, model_size=model_size, resume_training_weights=resume_training_weights)                                           
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                           
  File "/home/ubuntu/sgoluza/instance_segmentation/src/training/trainer.py", line 85, in train                                                               
    model.train(**train_kwargs)                                                                                                                              
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/rfdetr/detr.py", line 105, in train                                    
    self.train_from_config(config, **kwargs)                                                                                                                 
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/rfdetr/detr.py", line 281, in train_from_config                        
    self.model.train(                                                                                                                                        
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/rfdetr/main.py", line 416, in train                                    
    train_stats = train_one_epoch(                                                                                                                           
                  ^^^^^^^^^^^^^^^^                                                                                                                           
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/rfdetr/engine.py", line 187, in train_one_epoch                        
    loss_dict = criterion(outputs, new_targets)                                                                                                              
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                              
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl          
    return self._call_impl(*args, **kwargs)                                                                                                                  
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl                  
    return forward_call(*args, **kwargs)                                                                                                                     
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                      
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/rfdetr/models/lwdetr.py", line 706, in forward
    indices = self.matcher(outputs_without_aux, targets, group_detr=group_detr)                                                                              
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                              
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                   
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)                                      
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                      
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)                                              
           ^^^^^^^^^^^^^^^^^^^^^                                              
  File "/home/ubuntu/sgoluza/instance_segmentation/.venv/lib/python3.12/site-packages/rfdetr/models/matcher.py", line 186, in forward
    indices_g = [linear_sum_assignment(c[i]) for i, c in enumerate(C_g.split(sizes, -1))]                                                                    
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                  
ValueError: matrix contains invalid numeric entries

I’m using RFDETRSegXLarge with rfdetr 1.5.2, CUDA 13.1, and PyTorch 2.8.0. You can find the full training configuration below. The same error consistently occurs around epoch 10.

Two days ago, I ran the same training configuration without augmentations on rfdetr 1.4.2, and everything worked smoothly. The only change I made was upgrading to version 1.5.2 so I could use augmentations.

I suspect the issue may be related to augmentations. I’ve already tried turning AMP on and off, and I checked my dataset multiple times: the annotations are clean and there are no background-only images.
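One way to narrow down where the non-finite values first appear is to scan the model outputs and the targets right before the criterion call. The helper below is a hypothetical debugging sketch, not part of rfdetr; the key names `pred_logits`, `pred_boxes`, `boxes` and `labels` in the toy data are assumed from common DETR-style conventions and may differ in the actual code.

```python
import torch


def find_nonfinite(obj, prefix=""):
    """Return the paths of all tensors in a nested structure that hold NaN or inf.

    Hypothetical debugging helper, not part of rfdetr.
    """
    bad = []
    if torch.is_tensor(obj):
        # Only floating-point tensors can hold NaN or inf.
        if obj.is_floating_point() and not torch.isfinite(obj).all():
            bad.append(prefix or "<root>")
    elif isinstance(obj, dict):
        for key, value in obj.items():
            bad += find_nonfinite(value, f"{prefix}.{key}" if prefix else str(key))
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            bad += find_nonfinite(value, f"{prefix}[{i}]")
    return bad


# Toy data standing in for one training step (shapes and keys are assumptions).
outputs = {
    "pred_logits": torch.zeros(2, 3),
    "pred_boxes": torch.tensor([[0.5, 0.5, float("nan"), 0.2]]),
}
targets = [{"boxes": torch.tensor([[0.1, 0.1, 0.2, 0.2]]), "labels": torch.tensor([1])}]

print(find_nonfinite(outputs, "outputs"))  # ['outputs.pred_boxes']
print(find_nonfinite(targets, "targets"))  # []
```

If the targets come back clean but the outputs do not, the NaN most likely originates in the forward pass rather than in the augmented annotations.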

I also inspected the source code and found a potential bug in rfdetr/models/matcher.py lines 176-178:

max_cost = C.max() if C.numel() > 0 else 0
C[C.isinf() | C.isnan()] = max_cost * 2

.max() propagates NaN: if any element of C is NaN, C.max() returns NaN, so the replacement C[mask] = NaN * 2 is a no-op. The cleanup was intended to handle non-finite values but silently fails in exactly the cases it's supposed to catch. I'm still investigating, but wanted to ask whether anyone else has run into the same problem.
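The behaviour described above can be reproduced on a small tensor. In the sketch below, `sanitize_cost_reported` mirrors the two lines quoted from matcher.py, while `sanitize_cost_finite` is a hypothetical alternative (not the library's code) that computes the replacement value from the finite entries only. The `abs() * 2 + 1` replacement is my own assumption, chosen so that the substituted cost stays above every finite cost even when the maximum is zero or negative.

```python
import torch


def sanitize_cost_reported(C):
    """Mirror of the cleanup quoted above (as reported in this issue)."""
    C = C.clone()
    max_cost = C.max() if C.numel() > 0 else 0
    # If C holds a NaN, max_cost is NaN and this assignment changes nothing useful.
    C[C.isinf() | C.isnan()] = max_cost * 2
    return C


def sanitize_cost_finite(C):
    """Hypothetical alternative: take the maximum over finite entries only."""
    C = C.clone()
    finite = torch.isfinite(C)
    if finite.all():
        return C
    # Fall back to 0 when no entry is finite at all.
    max_cost = C[finite].max() if finite.any() else C.new_tensor(0.0)
    # Assumption: the replacement must exceed every finite cost.
    C[~finite] = max_cost.abs() * 2 + 1
    return C


C = torch.tensor([[0.5, float("nan")], [1.5, 2.0]])

print(torch.isnan(sanitize_cost_reported(C)).any())  # tensor(True): NaN survives
print(torch.isfinite(sanitize_cost_finite(C)).all())  # tensor(True): NaN replaced by 5.0
```

Note that this only keeps the matcher from crashing. If the cost matrix contains NaN, the model outputs or the targets were already non-finite upstream, and that root cause still needs to be found.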


from dataclasses import dataclass, field
from pathlib import Path

AUGMENTATION_CONFIGS = {
    "HorizontalFlip": {"p": 0.5},
    "CLAHE": {"clip_limit": 2.0, "tile_grid_size": (8, 8), "p": 0.3},
    "GaussianBlur": {"blur_limit": (3, 7), "p": 0.2},
    "GaussNoise": {"std_range": (0.01, 0.05), "p": 0.2},
    "RandomBrightnessContrast": {
        "brightness_limit": 0.15,
        "contrast_limit": 0.15,
        "p": 0.3,
    },
    "Sharpen": {"alpha": (0.2, 0.5), "lightness": (0.5, 1.0), "p": 0.2},
}

@dataclass
class TrainingConfig:
    """
    Parameters passed directly to rfdetr model.train().
    """
    # data related paths
    dataset_dir: Path
    output_dir: Path
    # architecture related hyperparameters
    resolution: int = 624
    group_detr: int = 13
    # training related hyperparameters
    epochs: int = 100
    batch_size: int = 4
    grad_accum_steps: int = 4
    lr: float = 1e-4
    amp: bool = False
    gradient_checkpointing: bool = False
    num_workers: int = 2
    multi_scale: bool = True
    aug_config: dict = field(default_factory=lambda: AUGMENTATION_CONFIGS)
    # logging hyperparameters
    print_freq: int = 100
    tensorboard: bool = True
    run_test: bool = False
    progress_bar: bool = True
    # validation hyperparameters
    early_stopping: bool = True
    early_stopping_patience: int = 20
    use_ema: bool = True
    early_stopping_use_ema: bool = False
    num_select: int = 100
    eval_max_dets: int = 500
    fp16_eval: bool = False
    # reproducibility
    seed: int = 42
