
[BUG] IndexError in BF16_Optimizer.destroy() when optimizer=None (DummyOptim) #7752

@ptx9363

Describe the bug
When initializing DeepSpeed with optimizer=None (which creates DummyOptim) and BF16 enabled, calling engine.destroy() raises an IndexError.
The error occurs because BF16_Optimizer.destroy() unconditionally indexes bf16_groups even when using_real_optimizer == False and bf16_groups is empty.
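The mismatch can be illustrated without DeepSpeed. The sketch below is a pure-Python model, not DeepSpeed's code. `destroy_loop_raises` is a hypothetical helper that mirrors the loop shape described above: it iterates over the optimizer's param groups and indexes bf16_groups at the same position. It assumes, as the reported IndexError implies, that the dummy optimizer exposes at least one param group while bf16_groups stays an empty list.

```python
def destroy_loop_raises(param_groups, bf16_groups):
    """Return True if the unguarded destroy() loop shape raises IndexError.

    Hypothetical helper: iterates over param groups and indexes
    bf16_groups with the same position, as described in the report.
    """
    try:
        for i, _ in enumerate(param_groups):
            for _param in bf16_groups[i]:
                pass  # real per-parameter cleanup would happen here
    except IndexError:
        return True
    return False


# Assumed DummyOptim-like state: one param group, bf16_groups never filled.
print(destroy_loop_raises([{"params": []}], []))  # True
# With one BF16 group per param group, the loop completes normally.
print(destroy_loop_raises([{"params": []}], [[]]))  # False
```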


To Reproduce
Steps to reproduce the behavior:

  1. Minimal inference script to reproduce the bug:
#!/usr/bin/env python3
"""
Minimal reproduction script for BF16_Optimizer destroy bug.

This script reproduces the IndexError that occurs when:
1. DeepSpeedEngine is initialized with optimizer=None (triggers DummyOptim)
2. BF16 is enabled
3. engine.destroy() is called

The bug occurs because BF16_Optimizer.destroy() tries to access
self.bf16_groups[i] but bf16_groups remains empty when using DummyOptim.
"""

import os
import torch
import torch.nn as nn
import deepspeed
from deepspeed.runtime.utils import DummyOptim


def create_simple_model():
    """Create a minimal model for testing"""
    return nn.Sequential(
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 1),
    )


def create_ds_config():
    """Create minimal DeepSpeed config with BF16 enabled"""
    return {
        "bf16": {
            "enabled": True
        },
        "zero_optimization": {
            # Use stage 0 to avoid ZeRO complications
            "stage": 0
        },
        "train_batch_size": 1,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
    }


def reproduce_bug():
    """Reproduce the BF16_Optimizer destroy bug"""

    # Set environment variables early to avoid MPI issues
    os.environ["LOCAL_RANK"] = "0"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    # Remove MPI environment variables completely to disable MPI detection
    mpi_env_vars = [
        "OMPI_COMM_WORLD_LOCAL_RANK",
        "OMPI_COMM_WORLD_RANK",
        "OMPI_COMM_WORLD_SIZE",
        "OMPI_UNIVERSE_SIZE",
        "OMPI_COMM_WORLD_LOCAL_SIZE",
    ]
    for var in mpi_env_vars:
        if var in os.environ:
            del os.environ[var]

    print("Creating simple model...")
    model = create_simple_model()

    print("Creating DeepSpeed config...")
    ds_config = create_ds_config()

    print("Initializing PyTorch distributed environment...")
    try:
        import torch.distributed as dist

        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl",
                init_method="env://",
                world_size=1,
                rank=0,
            )
        print("PyTorch distributed environment initialized")
    except Exception as e:
        print(f"Failed to initialize distributed environment: {e}")
        return False

    print("Initializing DeepSpeed with optimizer=None (this will create DummyOptim)...")
    try:
        # This will trigger the bug path:
        # 1. optimizer=None -> creates DummyOptim
        # 2. BF16_Optimizer detects DummyOptim -> using_real_optimizer=False
        # 3. bf16_groups remains empty []
        # 4. destroy() tries to access bf16_groups[i] -> IndexError
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            optimizer=None,  # This is key - triggers DummyOptim creation
            config=ds_config,
            dist_init_required=False,  # Already initialized above
        )

        print("DeepSpeed initialized successfully")
        print(f"Engine optimizer type: {type(engine.optimizer)}")
        print(f"BF16_Optimizer using_real_optimizer: {engine.optimizer.using_real_optimizer}")
        print(f"BF16_Optimizer bf16_groups length: {len(engine.optimizer.bf16_groups)}")

        print("Calling engine.destroy() - this should trigger IndexError...")
        engine.destroy()  # This will fail with IndexError

        print("No error occurred: the bug was not reproduced.")
        return False

    except IndexError as e:
        print(f"REPRODUCED: IndexError occurred as expected: {e}")
        print("This confirms the bug exists!")
        return True

    except Exception as e:
        print(f"UNEXPECTED ERROR: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    print("=" * 60)
    print("BF16_Optimizer Destroy Bug Reproduction")
    print("=" * 60)

    bug_reproduced = reproduce_bug()

    print("=" * 60)
    if bug_reproduced:
        print("✓ Bug successfully reproduced!")
        print("The issue is in BF16_Optimizer.destroy() indexing the empty bf16_groups list")
    else:
        print("✗ Bug was not reproduced - may have been fixed or conditions not met")
    print("=" * 60)

  2. Required packages and versions:

    • deepspeed == 0.17.5
    • torch >= 2.x
    • CUDA + NCCL available
  3. How to run the script:
    Run the reproduction script above in a single-GPU environment; BF16 is already enabled in its DeepSpeed config.

  4. Observe that calling engine.destroy() raises IndexError: list index out of range.
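The torch requirement above is only approximate (>= 2.x). A minimal sketch for recording the exact installed versions with the standard library's `importlib.metadata` is shown below; `installed_versions` is a hypothetical helper name. DeepSpeed's own `ds_report` command prints a fuller environment summary.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["deepspeed", "torch"]).items():
        print(f"{name}: {ver or 'not installed'}")
```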


Expected behavior
engine.destroy() should complete without error when BF16 is enabled and optimizer=None is used.


ds_report output
No ds_report output


Screenshots
Not applicable (exception traceback only).


System info:

  • OS: Ubuntu 20.04
  • GPU count and types: Single GPU on a single node
  • (if applicable) DeepSpeed-MII version: N/A
  • (if applicable) Hugging Face Transformers / Accelerate versions: N/A
  • Python version: 3.10
  • Any other relevant info about your setup: ZeRO stage 0, BF16 enabled

Docker context
Not using a custom Docker image.


Additional context
Root cause summary:

  • optimizer=None causes DeepSpeed to create DummyOptim
  • BF16_Optimizer.using_real_optimizer == False
  • bf16_groups is never populated
  • BF16_Optimizer.destroy() still indexes bf16_groups[i], leading to IndexError

A minimal fix would be to short-circuit BF16_Optimizer.destroy() when no real optimizer is used (e.g. return early when using_real_optimizer == False or when bf16_groups is empty).
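The short-circuit suggested above can be sketched with a toy class. `ToyBF16Optimizer` and the `hp_mapping` attribute are hypothetical; only `using_real_optimizer`, `bf16_groups`, and the per-param-group iteration come from this report. This is not DeepSpeed's actual `BF16_Optimizer` implementation.

```python
from types import SimpleNamespace


class ToyBF16Optimizer:
    """Hypothetical stand-in exposing only the attributes named in the report."""

    def __init__(self, param_groups, bf16_groups, using_real_optimizer):
        self.param_groups = param_groups  # one entry per optimizer group
        self.bf16_groups = bf16_groups    # parallel list of BF16 params per group
        self.using_real_optimizer = using_real_optimizer

    def destroy(self):
        # Short-circuit proposed in the report: with a dummy optimizer the
        # BF16 groups were never built, so there is nothing to tear down.
        if not self.using_real_optimizer or not self.bf16_groups:
            return
        for i, _ in enumerate(self.param_groups):
            for param in self.bf16_groups[i]:
                # Release per-parameter bookkeeping (hypothetical attribute name).
                param.hp_mapping = None


# DummyOptim-like case: destroy() returns instead of raising IndexError.
ToyBF16Optimizer([{"params": []}], [], using_real_optimizer=False).destroy()

# Real-optimizer case: cleanup still runs for every BF16 parameter.
p = SimpleNamespace(hp_mapping="mapping")
ToyBF16Optimizer([{"params": [p]}], [[p]], using_real_optimizer=True).destroy()
print(p.hp_mapping)  # None
```

Iterating over `zip(self.param_groups, self.bf16_groups)` would also avoid the IndexError, but it would silently skip groups whenever the two lists differ in length; the explicit guard only skips the case where nothing was set up. If the real method performs other teardown as well, that part may need to stay outside the early return; this sketch does not model it.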
