
DDP Validation Metric Logging Gets Silently Misplaced When Processes Log Different Metric Keys on Different Devices #21409

@worldlife123

Bug description

Problem

In PyTorch Lightning's DDP (Distributed Data Parallel) mode, when different processes log different metric keys during validation (based on their local data), the synchronization mechanism fails, causing:

  1. Metric values to be assigned to the wrong keys
  2. Incorrect aggregated values
  3. No clear error or warning about the mismatch

Possible Root Cause

Lightning's sync_dist=True mechanism assumes homogeneous metric keys across all processes at each step. When this assumption is violated (e.g., Process 0 logs ["loss", "metric_a"] while Process 1 logs ["loss", "metric_b"]), the per-key collectives issued on each rank no longer line up, and values end up misaligned across keys.
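
Lightning's internal sync path is more involved than this, but if each logged key triggers its own collective in local call order, the effect resembles the following sketch (plain torch.distributed, outside Lightning; buggy_sync is an illustrative helper name, and a two-process group is assumed to already be initialized):

import torch
import torch.distributed as dist

def buggy_sync(local_metrics: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Reduce each locally logged key in local call order, i.e. assume every
    rank logs the same keys in the same order."""
    synced = {}
    for key, value in local_metrics.items():
        value = value.clone()
        # Collectives are matched across ranks by CALL ORDER, not by metric name
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        synced[key] = value / dist.get_world_size()
    return synced

# Rank 0 logs {"val_loss": ..., "val_metric_a": 1.0}
# Rank 1 logs {"val_loss": ..., "val_metric_b": 2.0}
# The second all_reduce on rank 0 (for "val_metric_a") pairs with the second
# all_reduce on rank 1 (for "val_metric_b"), so both ranks report 1.5, each
# under a different key -- neither value is what the user intended.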

Minimal Example

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class BuggyModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # assumes forward() is defined
        loss = F.cross_entropy(y_hat, y)

        # Common metric - all processes log this
        self.log("val_loss", loss, sync_dist=True)

        # Different processes log different metrics
        if y[0] % 2 == 0:  # Even samples
            self.log("val_metric_even", torch.tensor(0.5), sync_dist=True)
        else:  # Odd samples
            self.log("val_metric_odd", torch.tensor(0.7), sync_dist=True)

        return {"val_loss": loss, "y": y}

Expected Behavior

One of the following:

  1. Option A (Preferred): Lightning properly handles heterogeneous metric keys by:

    • Synchronizing each key independently across processes
    • Only aggregating metrics where all processes contributed values
    • Providing clear warnings about partial metric coverage
  2. Option B: Clear error/warning when metric key mismatch is detected, guiding users to:

    • Use sync_dist=False and handle synchronization manually
    • Ensure consistent logging across processes
    • Use the all_gather API for heterogeneous metrics
  3. Option C: Add a flag like allow_heterogeneous_keys=True that enables proper synchronization of different keys.

Actual Behavior

  • Silent misalignment of metric values
  • Incorrect metrics reported to logger (TensorBoard, WandB, etc.)
  • No error or warning, making debugging extremely difficult

Impact

This affects many real-world scenarios:

  1. Multi-task learning: Different tasks may have different evaluation metrics
  2. Imbalanced datasets: Rare classes might trigger special metrics only in some batches
  3. Conditional evaluation: Some metrics only make sense for certain data subsets

Current User Workarounds (All Unsatisfactory)

  1. Log dummy values: Forces all processes to log all keys, wasting computation
  2. Disable sync_dist: Lose automatic synchronization benefits
  3. Manual all_gather: Requires significant boilerplate code (a rough sketch follows this list)
  4. Log only at epoch end: Lose per-step logging granularity
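
For workaround 3, a rough sketch is shown below; it is not Lightning API, and the class name, _val_metrics buffer, and _record helper are illustrative. The idea is to skip sync_dist entirely, collect per-key values locally, merge the per-rank dictionaries with torch.distributed.all_gather_object at epoch end, and then log the merged means (which are identical on every rank, so no further synchronization is needed):

import torch
import torch.distributed as dist
import pytorch_lightning as pl

class ManualSyncModel(pl.LightningModule):
    def on_validation_epoch_start(self):
        # Local, per-rank buffer of metric values; keys may differ across ranks
        self._val_metrics: dict[str, list[float]] = {}

    def _record(self, key: str, value: float):
        self._val_metrics.setdefault(key, []).append(float(value))

    def validation_step(self, batch, batch_idx):
        # ... compute loss/metrics, then record them WITHOUT sync_dist, e.g.:
        # self._record("val_metric_a", metric_a)
        ...

    def on_validation_epoch_end(self):
        if dist.is_available() and dist.is_initialized():
            gathered = [None] * dist.get_world_size()
            # all_gather_object handles arbitrary picklable dicts, so ranks
            # may hold entirely different key sets
            dist.all_gather_object(gathered, self._val_metrics)
        else:
            gathered = [self._val_metrics]
        # Merge per-key value lists across ranks and log the mean; after the
        # gather every rank holds the same merged dict, so plain logging
        # (sync_dist=False) is consistent
        merged: dict[str, list[float]] = {}
        for rank_metrics in gathered:
            for key, values in rank_metrics.items():
                merged.setdefault(key, []).extend(values)
        for key, values in merged.items():
            self.log(key, sum(values) / len(values), sync_dist=False)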

Suggested Fix

Maybe implement proper key-aware synchronization like this?

# Pseudo-code for key-aware synchronization (device/dtype handling omitted)
import torch
import torch.distributed as dist

def sync_metrics(metrics_dict, group=None):
    world_size = dist.get_world_size(group)
    # Gather the set of keys logged by every process
    all_keys = [None] * world_size
    dist.all_gather_object(all_keys, set(metrics_dict.keys()), group=group)
    union_of_all_keys = sorted(set.union(*all_keys))
    synced = {}
    # Synchronize each key independently, in the same (sorted) order on every rank
    for key in union_of_all_keys:
        contributed = sum(key in keys for keys in all_keys)
        # Ranks missing this key contribute zero so the collective still lines up
        value = metrics_dict.get(key, torch.zeros(())).clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM, group=group)
        # Average only over the ranks that actually logged the key
        synced[key] = value / contributed
    return synced
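
One open design question for a fix along these lines is whether a key missing on some ranks should be averaged over all ranks (treating the absent value as zero or NaN) or only over the ranks that actually logged it; the latter matches what single-process logging would report and is what the pseudo-code above does.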

**Below I provide a complete reproduction script for this bug, in which the model logs 'val_metric_a', 'val_metric_b', or 'val_metric_c' depending on the data class; these should be around 1.0, 2.0, and 3.0, respectively. However, as shown in the image below, the values reported in TensorBoard are not what was logged (version_0 uses sync_dist, version_1 does not):**

[Image: TensorBoard screenshot showing the misreported metric values]

What version are you seeing the problem on?

v2.4

Reproduced in studio

No response

How to reproduce the bug

"""
Reproduction script for PyTorch Lightning DDP validation logging bug
when processes log different metric keys.

Issue: When different processes log different metric keys during validation,
the synchronized logs become misplaced/corrupted.

Expected: Each metric should be properly synchronized across processes.
Actual: Metrics get misaligned, causing wrong values or missing metrics.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

class SyntheticDataset(Dataset):
    """Synthetic dataset that generates different data for different ranks"""
    def __init__(self, size=1000):
        self.size = size
        
    def __len__(self):
        return self.size
    
    def __getitem__(self, idx):
        # Create a 3-channel 32x32 input
        x = torch.randn(3, 32, 32)
        
        # Create different labels for different data points
        # This will cause different processes to see different types of samples
        if idx % 3 == 0:
            y = torch.tensor(0)  # Type A
        elif idx % 3 == 1:
            y = torch.tensor(1)  # Type B
        else:
            y = torch.tensor(2)  # Type C
            
        return x, y


class BuggyModel(pl.LightningModule):
    """Model that reproduces the DDP logging bug"""
    
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 3)
        
        # Track what we're logging for debugging
        self.logged_keys = []
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        # CRITICAL BUG REPRODUCTION:
        # Different processes log different metrics based on the sample type
        # This simulates real-world scenarios where different data requires
        # different evaluation metrics
        
        # Get unique labels in this batch
        unique_labels = torch.unique(y)
        
        # Store what we log for debugging
        step_logged_keys = []
        
        # Always log the common loss
        self.log('val_loss', loss, sync_dist=True, prog_bar=True)
        step_logged_keys.append('val_loss')
        
        # Log type-specific metrics (THIS IS WHERE THE BUG HAPPENS)
        if 0 in unique_labels:  # Type A samples
            # Process with Type A samples computes metric A
            metric_a = torch.rand(1).item() * 0.5 + 1  # Simulated metric
            self.log('val_metric_a', metric_a, sync_dist=True)
            step_logged_keys.append('val_metric_a')
            
        if 1 in unique_labels:  # Type B samples  
            # Process with Type B samples computes metric B
            metric_b = torch.rand(1).item() * 0.5 + 2  # Simulated metric
            self.log('val_metric_b', metric_b, sync_dist=True)
            step_logged_keys.append('val_metric_b')
            
        if 2 in unique_labels:  # Type C samples
            # Process with Type C samples computes metric C
            metric_c = torch.rand(1).item() * 0.5 + 3  # Simulated metric
            self.log('val_metric_c', metric_c, sync_dist=True)
            step_logged_keys.append('val_metric_c')
        
        # Record what was logged in this step
        self.logged_keys.append(step_logged_keys)
        
        # Print what each rank is logging (for debugging)
        rank = self.trainer.global_rank if self.trainer else 0
        print(f"Rank {rank}, Batch {batch_idx}: Logged keys = {step_logged_keys}")
        
        return {
            'val_loss': loss,
            'labels': y,
            'logged_keys': step_logged_keys
        }
    
    def on_validation_epoch_end(self):
        """Analyze the logging issue at the end of validation"""
        if hasattr(self, 'trainer') and self.trainer:
            rank = self.trainer.global_rank
            print(f"\n=== Rank {rank} Summary ===")
            print(f"Total validation steps: {len(self.logged_keys)}")
            
            # Check for inconsistent logging patterns
            all_keys = set()
            for step_keys in self.logged_keys:
                all_keys.update(step_keys)
            
            print(f"All keys logged by this rank: {sorted(all_keys)}")
            
            # Count occurrences of each key
            key_counts = {}
            for step_keys in self.logged_keys:
                for key in step_keys:
                    key_counts[key] = key_counts.get(key, 0) + 1
            
            print(f"Key frequencies: {key_counts}")
        
        # Reset for next epoch
        self.logged_keys = []
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def reproduce_bug():
    """Main function to reproduce the bug"""
    
    print("=" * 80)
    print("PyTorch Lightning DDP Validation Logging Bug Reproduction")
    print("=" * 80)
    
    # Create dataset and dataloader
    dataset = SyntheticDataset(size=32)
    
    # Lightning injects a DistributedSampler under DDP, so different ranks see
    # different samples and therefore log different metric keys
    train_loader = DataLoader(
        dataset, 
        batch_size=1,
        shuffle=True,
        num_workers=0  # Set to 0 for easier debugging
    )
    
    val_loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0
    )
    
    # Create model
    model = BuggyModel()
    
    # Setup trainer with DDP
    trainer = Trainer(
        max_epochs=2,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=2,  # use 2 processes
        strategy='ddp' if torch.cuda.is_available() else 'ddp_spawn',
        num_nodes=1,
        enable_progress_bar=True,
        log_every_n_steps=1,
        enable_checkpointing=False,
        enable_model_summary=False,
    )
    
    print("\nTraining with DDP (2 processes)...")
    print("Expected: All metrics (val_metric_a, val_metric_b, val_metric_c) should be logged correctly.")
    print("Bug: Metrics will be misplaced because processes log different keys.\n")
    
    # Train and validate
    trainer.fit(model, train_loader, val_loader)
    
    print("\n" + "=" * 80)
    print("Bug Reproduction Complete!")
    print("=" * 80)
    print("\nAnalysis:")
    print("1. Each process logs different metrics based on the data it receives.")
    print("2. Lightning tries to synchronize these logs across processes.")
    print("3. Because the metric keys differ, the synchronization gets confused.")
    print("4. Result: Some metrics show wrong values or appear in wrong steps.")
    print("\nCheck the logs above to see the issue:")
    print("- Look for 'val_metric_a', 'val_metric_b', 'val_metric_c' in TensorBoard/logger")
    print("- Notice they might show incorrect values or appear/disappear unexpectedly")


def simple_repro():
    """Even simpler reproduction - can run without GPU"""
    
    print("\n" + "=" * 80)
    print("Simplified Bug Reproduction (Single Process Simulation)")
    print("=" * 80)
    
    # Simulate what happens in DDP
    print("\nSimulating DDP with 2 processes:")
    
    # Process 0 logs in step 0
    print("\nStep 0:")
    print("  Process 0 logs: ['val_loss', 'val_metric_a', 'val_metric_b']")
    print("  Process 1 logs: ['val_loss', 'val_metric_c']")
    print("  Lightning tries to sync: Expects same keys from all processes!")
    
    # Process 0 logs in step 1
    print("\nStep 1:")
    print("  Process 0 logs: ['val_loss', 'val_metric_a']")
    print("  Process 1 logs: ['val_loss', 'val_metric_b', 'val_metric_c']")
    print("\nProblem: Key mismatch causes:")
    print("  1. Metric values get assigned to wrong keys")
    print("  2. Some metrics appear/disappear")
    print("  3. Aggregated values are incorrect")
    
    print("\nExpected behavior:")
    print("  Lightning should handle heterogeneous metric keys across processes")
    print("  or at least provide a clear error/warning")


if __name__ == "__main__":
    # Optionally, run the single-process explanation first:
    # simple_repro()

    # Full DDP reproduction (requires 2+ GPUs, or CPU with DDP support)
    reproduce_bug()

Error messages and logs

================================================================================
PyTorch Lightning DDP Validation Logging Bug Reproduction
================================================================================
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores

Training with DDP (2 processes)...
Expected: All metrics (val_metric_a, val_metric_b, val_metric_c) should be logged correctly.
Bug: Metrics will be misplaced because processes log different keys.

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
================================================================================
PyTorch Lightning DDP Validation Logging Bug Reproduction
================================================================================

Training with DDP (2 processes)...
Expected: All metrics (val_metric_a, val_metric_b, val_metric_c) should be logged correctly.
Bug: Metrics will be misplaced because processes log different keys.

Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
/home/xzy/miniconda3/envs/cbench_vid/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'val_dataloader' does 
not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to 
improve performance.
Rank 0, Batch 0: Logged keys = ['val_loss', 'val_metric_a']
Rank 0, Batch 1: Logged keys = ['val_loss', 'val_metric_c']

=== Rank 0 Summary ===
Total validation steps: 2
All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_c']
Key frequencies: {'val_loss': 2, 'val_metric_a': 1, 'val_metric_c': 1}
Rank 1, Batch 0: Logged keys = ['val_loss', 'val_metric_b']
Rank 1, Batch 1: Logged keys = ['val_loss', 'val_metric_a']

=== Rank 1 Summary ===
Total validation steps: 2
All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b']
Key frequencies: {'val_loss': 2, 'val_metric_b': 1, 'val_metric_a': 1}
/home/xzy/miniconda3/envs/cbench_vid/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'train_dataloader' does
not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to 
improve performance.
Rank 0, Batch 0: Logged keys = ['val_loss', 'val_metric_a']
Rank 0, Batch 1: Logged keys = ['val_loss', 'val_metric_c']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2/16  0:00:00 • 0:00:01 192.15it/s                               Rank 1, Batch 0: Logged keys = ['val
Rank 0, Batch 2: Logged keys = ['val_loss', 'val_metric_b']
Rank 0, Batch 3: Logged keys = ['val_loss', 'val_metric_a']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4/16  0:00:00 • 0:00:01 197.95it/s                               Rank 1, Batch 1: Logged keys = ['val
Rank 0, Batch 4: Logged keys = ['val_loss', 'val_metric_c']
Rank 0, Batch 5: Logged keys = ['val_loss', 'val_metric_b']
Rank 0, Batch 6: Logged keys = ['val_loss', 'val_metric_a']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━ 7/16  0:00:00 • 0:00:01 204.50it/s                               Rank 1, Batch 2: Logged keys = ['val
Rank 0, Batch 7: Logged keys = ['val_loss', 'val_metric_c']
Rank 0, Batch 8: Logged keys = ['val_loss', 'val_metric_b']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━ 9/16  0:00:00 • 0:00:01 204.93it/s                               Rank 1, Batch 3: Logged keys = ['val
Rank 0, Batch 9: Logged keys = ['val_loss', 'val_metric_a']
Rank 0, Batch 10: Logged keys = ['val_loss', 'val_metric_c']
Rank 0, Batch 11: Logged keys = ['val_loss', 'val_metric_b']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━ 11/16 0:00:00 • 0:00:01 206.23it/s                               Rank 1, Batch 4: Logged keys = ['val
Rank 0, Batch 12: Logged keys = ['val_loss', 'val_metric_a']
Rank 0, Batch 13: Logged keys = ['val_loss', 'val_metric_c']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━ 14/16 0:00:00 • 0:00:01 207.14it/s                               Rank 1, Batch 5: Logged keys = ['val
Rank 0, Batch 14: Logged keys = ['val_loss', 'val_metric_b']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━ 15/16 0:00:00 • 0:00:01 193.32it/s                               Rank 1, Batch 6: Logged keys = ['val
Rank 0, Batch 15: Logged keys = ['val_loss', 'val_metric_a']
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340
Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 178.06it/s                               Rank 1, Batch 7: Logged keys = ['val_loss', 'val_metric_a']
Rank 1, Batch 8: Logged keys = ['val_loss', 'val_metric_c']
Rank 1, Batch 9: Logged keys = ['val_loss', 'val_metric_b']
Rank 1, Batch 10: Logged keys = ['val_loss', 'val_metric_a']
Rank 1, Batch 11: Logged keys = ['val_loss', 'val_metric_c']
Rank 1, Batch 12: Logged keys = ['val_loss', 'val_metric_b']
Rank 1, Batch 13: Logged keys = ['val_loss', 'val_metric_a']
Rank 1, Batch 14: Logged keys = ['val_loss', 'val_metric_c']
Rank 1, Batch 15: Logged keys = ['val_loss', 'val_metric_b']

=== Rank 1 Summary ===
Total validation steps: 16
All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c']

=== Rank 0 Summary ===
Total validation steps: 16
All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c']
Key frequencies: {'val_loss': 16, 'val_metric_a': 6, 'val_metric_c': 5, 'val_metric_b': 5}
Rank 0, Batch 0: Logged keys = ['val_loss', 'val_metric_a']
Rank 0, Batch 1: Logged keys = ['val_loss', 'val_metric_c']
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/16  0:00:00 • -:--:-- 0.00it/s                                                 Rank 1, Batch 0: Log
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2/16  0:00:00 • 0:00:01 219.10it/s                                               Rank 1, Batch 1: Log
Rank 0, Batch 2: Logged keys = ['val_loss', 'val_metric_b']
Rank 0, Batch 3: Logged keys = ['val_loss', 'val_metric_a']
Rank 0, Batch 4: Logged keys = ['val_loss', 'val_metric_c']
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5/16  0:00:00 • 0:00:01 209.41it/s                                               Rank 1, Batch 2: Log
Rank 0, Batch 5: Logged keys = ['val_loss', 'val_metric_b']
Rank 0, Batch 6: Logged keys = ['val_loss', 'val_metric_a']
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━ 7/16  0:00:00 • 0:00:01 208.61it/s                                               Rank 1, Batch 3: Log
Rank 0, Batch 7: Logged keys = ['val_loss', 'val_metric_c']
Rank 0, Batch 8: Logged keys = ['val_loss', 'val_metric_b']
Rank 0, Batch 9: Logged keys = ['val_loss', 'val_metric_a']
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━ 9/16  0:00:00 • 0:00:01 209.89it/s                                               Rank 1, Batch 4: Log
Rank 0, Batch 10: Logged keys = ['val_loss', 'val_metric_c']
Rank 0, Batch 11: Logged keys = ['val_loss', 'val_metric_b']
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━ 12/16 0:00:00 • 0:00:01 210.92it/s                                               Rank 1, Batch 5: Log
Rank 0, Batch 12: Logged keys = ['val_loss', 'val_metric_a']
Rank 0, Batch 13: Logged keys = ['val_loss', 'val_metric_c']
Rank 0, Batch 14: Logged keys = ['val_loss', 'val_metric_b']
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━ 14/16 0:00:00 • 0:00:01 211.86it/s                                               Rank 1, Batch 6: Log
Rank 0, Batch 15: Logged keys = ['val_loss', 'val_metric_a']
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116
Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 211.71it/s                                               Rank 1, Batch 7: Logged keys = ['val_loss', 'val_metric_a']
Rank 1, Batch 8: Logged keys = ['val_loss', 'val_metric_c']
Rank 1, Batch 9: Logged keys = ['val_loss', 'val_metric_b']
Rank 1, Batch 10: Logged keys = ['val_loss', 'val_metric_a']
Rank 1, Batch 11: Logged keys = ['val_loss', 'val_metric_c']
Rank 1, Batch 12: Logged keys = ['val_loss', 'val_metric_b']
Rank 1, Batch 13: Logged keys = ['val_loss', 'val_metric_a']
Rank 1, Batch 14: Logged keys = ['val_loss', 'val_metric_c']
Rank 1, Batch 15: Logged keys = ['val_loss', 'val_metric_b']

=== Rank 1 Summary ===
Total validation steps: 16
All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c']

=== Rank 0 Summary ===
Total validation steps: 16
All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c']
Key frequencies: {'val_loss': 16, 'val_metric_a': 6, 'val_metric_c': 5, 'val_metric_b': 5}
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116`Trainer.fit` stopped: `max_epochs=2` reached.

================================================================================
Bug Reproduction Complete!
================================================================================

Analysis:
1. Each process logs different metrics based on the data it receives.
2. Lightning tries to synchronize these logs across processes.
3. Because the metric keys differ, the synchronization gets confused.
4. Result: Some metrics show wrong values or appear in wrong steps.

Check the logs above to see the issue:
- Look for 'val_metric_a', 'val_metric_b', 'val_metric_c' in TensorBoard/logger
- Notice they might show incorrect values or appear/disappear unexpectedly
Epoch 1/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.099

================================================================================
Bug Reproduction Complete!
================================================================================

Analysis:
1. Each process logs different metrics based on the data it receives.
2. Lightning tries to synchronize these logs across processes.
3. Because the metric keys differ, the synchronization gets confused.
4. Result: Some metrics show wrong values or appear in wrong steps.

Check the logs above to see the issue:
- Look for 'val_metric_a', 'val_metric_b', 'val_metric_c' in TensorBoard/logger
- Notice they might show incorrect values or appear/disappear unexpectedly

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA TITAN Xp
    - NVIDIA TITAN X (Pascal)
    - available: True
    - version: 11.8
  • Lightning:
    - adabelief-pytorch: 0.2.0
    - lightning: 2.4.0
    - lightning-utilities: 0.15.2
    - pytorch-lightning: 2.6.0
    - pytorch-msssim: 1.0.0
    - torch: 2.4.1+cu118
    - torchaudio: 2.4.1+cu118
    - torchmetrics: 1.8.2
    - torchvision: 0.19.1+cu118
  • Packages:
    - absl-py: 2.3.1
    - adabelief-pytorch: 0.2.0
    - addict: 2.4.0
    - aiohappyeyeballs: 2.6.1
    - aiohttp: 3.13.2
    - aiosignal: 1.4.0
    - aliyun-python-sdk-core: 2.16.0
    - aliyun-python-sdk-kms: 2.16.5
    - attrs: 25.4.0
    - autograd: 1.8.0
    - av: 16.0.1
    - boto3: 1.42.2
    - botocore: 1.42.2
    - brotlipy: 0.7.0
    - cbench: 0.2
    - certifi: 2025.11.12
    - cffi: 2.0.0
    - charset-normalizer: 3.4.4
    - click: 8.3.1
    - colorama: 0.4.6
    - compressai: 1.2.3
    - contourpy: 1.3.3
    - crcmod: 1.7
    - cryptography: 46.0.3
    - cycler: 0.12.1
    - cython: 3.2.2
    - einops: 0.8.1
    - entmax: 1.1
    - filelock: 3.14.0
    - fonttools: 4.61.0
    - frozenlist: 1.8.0
    - fsspec: 2025.9.0
    - grpcio: 1.76.0
    - idna: 3.11
    - imageio: 2.37.2
    - jinja2: 3.1.6
    - jmespath: 0.10.0
    - kiwisolver: 1.4.9
    - lightning: 2.4.0
    - lightning-utilities: 0.15.2
    - litdata: 0.2.58
    - markdown: 3.10
    - markdown-it-py: 4.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.10.7
    - mdurl: 0.1.2
    - mmcv: 2.2.0
    - mmengine: 0.10.7
    - model-index: 0.1.11
    - mpmath: 1.3.0
    - multidict: 6.7.0
    - networkx: 2.6.3
    - numpy: 2.2.6
    - nvidia-cublas-cu11: 11.11.3.6
    - nvidia-cuda-cupti-cu11: 11.8.87
    - nvidia-cuda-nvrtc-cu11: 11.8.89
    - nvidia-cuda-runtime-cu11: 11.8.89
    - nvidia-cudnn-cu11: 9.1.0.70
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.3.0.86
    - nvidia-cusolver-cu11: 11.4.1.48
    - nvidia-cusparse-cu11: 11.7.5.86
    - nvidia-nccl-cu11: 2.20.5
    - nvidia-nvtx-cu11: 11.8.86
    - obstore: 0.8.2
    - opencv-python: 4.12.0.88
    - opendatalab: 0.0.10
    - openmim: 0.3.9
    - openxlab: 0.1.3
    - ordered-set: 4.1.0
    - oss2: 2.17.0
    - packaging: 24.2
    - pandas: 2.3.3
    - pillow: 11.3.0
    - pip: 25.3
    - platformdirs: 4.5.0
    - propcache: 0.4.1
    - protobuf: 6.33.1
    - ptflops: 0.7.5
    - pybind11: 3.0.1
    - pybind11-stubgen: 0.16.2
    - pycocotools: 2.0.10
    - pycparser: 2.23
    - pycryptodome: 3.23.0
    - pygments: 2.19.2
    - pyparsing: 3.2.5
    - pyrclone-wrapper: 0.0.3
    - python-dateutil: 2.9.0.post0
    - pytorch-lightning: 2.6.0
    - pytorch-msssim: 1.0.0
    - pytz: 2023.4
    - pyyaml: 6.0.3
    - requests: 2.28.2
    - rich: 13.4.2
    - s3transfer: 0.16.0
    - scipy: 1.16.3
    - setuptools: 60.2.0
    - six: 1.17.0
    - sympy: 1.14.0
    - tabulate: 0.9.0
    - tensorboard: 2.20.0
    - tensorboard-data-server: 0.7.2
    - termcolor: 3.2.0
    - thop: 0.1.1.post2209072238
    - tifffile: 2025.10.16
    - torch: 2.4.1+cu118
    - torchaudio: 2.4.1+cu118
    - torchmetrics: 1.8.2
    - torchvision: 0.19.1+cu118
    - tqdm: 4.65.2
    - triton: 3.0.0
    - typing-extensions: 4.15.0
    - tzdata: 2025.2
    - urllib3: 1.26.20
    - werkzeug: 3.1.4
    - wheel: 0.45.1
    - yapf: 0.43.0
    - yarl: 1.22.0
    - zstandard: 0.25.0
    - zstd: 1.5.7.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.14
    - release: 5.15.0-139-generic
    - version: #149~20.04.1-Ubuntu SMP Wed Apr 16 08:29:56 UTC 2025

More info

No response

cc @ethanwharris

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.4.x
