
[BUG] count_used_parameters_in_backward does not work with PyTorch 2.1.2 #7756

@riyadhrazzaq


Describe the bug
When training with ZeRO stage 2, the following assertion error is thrown: "count_used_parameters_in_backward requires internal PyTorch APIs that are not available in this PyTorch build."

To Reproduce
The following is my ds_config.json:

{
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "steps_per_print": 100,
    "bf16": {
      "enabled": true
    },
    "zero_optimization": {
      "stage": 2,
      "overlap_comm": true,
      "contiguous_gradients": true,
      "reduce_bucket_size": 5e8
    },
    "gradient_clipping": 1.0,
    "optimizer": {
      "type": "AdamW",
      "params": {
        "lr": "auto",
        "betas": "auto",
        "eps": "auto",
        "weight_decay": "auto"
      }
    },
    "scheduler": {
      "type": "WarmupCosineLR",
      "params": {
        "warmup_min_ratio": 0.0,
        "warmup_num_steps": "auto",
        "cos_min_ratio": 0.0,
        "total_num_steps": "auto"
      }
    },
    "activation_checkpointing": {
      "partition_activations": true,
      "contiguous_memory_optimization": true,
      "cpu_checkpointing": false
    },
    "aio": {
      "block_size": 1048576,
      "queue_depth": 8,
      "single_submit": false,
      "overlap_events": true
    },
    "zero_allow_untested_optimizer": true,
    "wall_clock_breakdown": false
  }
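
For a more minimal reproduction, the sketch below goes through the same engine.backward() path shown in the traceback, without the Trainer/Accelerate stack. This is illustrative only, not my actual script: the toy Linear model stands in for the real one, it assumes the "auto" fields in ds_config.json are replaced with concrete numbers, and it must be launched via torchrun or the deepspeed CLI so distributed init succeeds. I have not verified that this standalone snippet triggers the assertion, but it exercises the same ZeRO-2 backward code path.

# Minimal sketch of the failing path (illustrative; model and shapes are placeholders).
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)

# deepspeed.initialize builds the ZeRO stage-2 engine from the JSON config above.
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)

# bf16 is enabled in the config, so the engine expects bf16 inputs.
x = torch.randn(8, 1024, dtype=torch.bfloat16, device=engine.device)
loss = engine(x).float().pow(2).mean()

engine.backward(loss)  # the same engine.backward() call that ends up in count_used_parameters_in_backward
engine.step()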

The following is the complete error message (all four ranks fail with the same assertion):

Traceback (most recent call last):
  File "/scratch/username/llama_omni_asr_tts/omni_speech/train/stage1.py", line 194, in <module>
    train_model(args)
  File "/scratch/username/llama_omni_asr_tts/omni_speech/train/stage1.py", line 155, in train_model
    trainer.train()
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/transformers/trainer.py", line 3349, in training_step
    self.accelerator.backward(loss, **kwargs)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/accelerate/accelerator.py", line 2151, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 166, in backward
    self.engine.backward(loss, **kwargs)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2501, in backward
    loss.backward(**backward_kwargs)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1003, in grad_handling_hook
    self._remaining_grad_acc_hooks = count_used_parameters_in_backward(
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/deepspeed/runtime/utils.py", line 1445, in count_used_parameters_in_backward
    assert check_internal_apis_for_count_used_parameters(), (
AssertionError: count_used_parameters_in_backward requires internal PyTorch APIs that are not available in this PyTorch build.

  0%|          | 0/5000 [00:02<?, ?it/s]
[2026-01-01 11:19:01,409] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 580879) of binary: /scratch/username/.conda/envs/omni2/bin/python3.10
Traceback (most recent call last):
  File "/scratch/username/.conda/envs/omni2/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
omni_speech/train/stage1.py FAILED
------------------------------------------------------------

Expected behavior
The same script worked in a previous conda environment where I had installed DeepSpeed via pip.

ds_report output

DeepSpeed general environment info:
torch install path ............... ['/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/torch']
torch version .................... 2.1.2+cu121
deepspeed install path ........... ['/scratch/username/.conda/envs/omni2/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.18.3, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.8
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 503.62 GB
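
In case it helps triage, the same versions as reported by the packages themselves (trivial check, nothing beyond the two imports):

import torch, deepspeed
print(torch.__version__)      # 2.1.2+cu121
print(deepspeed.__version__)  # 0.18.3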

System info (please complete the following information):

  • OS: [e.g. Ubuntu 18.04]
  • GPU count and types: 4 x A100
  • Interconnects (if applicable)
  • Python version 3.10

Additional context
Downgrading to v0.18.0 fixes the problem for me.
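
For anyone hitting the same thing before a proper fix lands, the workaround on my side was just a version pin in a fresh environment, i.e. pip install deepspeed==0.18.0 (the exact pin may differ for your setup).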
