
Error when saving optimizer states #17

@oncleJules

Description


Following the training script for OLMo 1B, I encountered the following error when saving the optimizer states:
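(For reference, the job was launched with four processes roughly as follows; the config path is the stock OLMo 1B one and may not match the COAT example exactly, and any COAT-specific FP8 flags from the README are omitted here:)

torchrun --nproc_per_node=4 scripts/train.py configs/official/OLMo-1B.yaml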

... (some normal training output until Step 1000)

2025-08-14 10:05:05.871 0d171996b0c8:2 olmo.util:165 CRITICAL Uncaught AssertionError:
Traceback (most recent call last):
File "/root/COAT/examples/OLMo/scripts/train.py", line 377, in
main(cfg)
File "/root/COAT/examples/OLMo/scripts/train.py", line 346, in main
trainer.fit()
File "/root/COAT/examples/OLMo/olmo/train.py", line 1319, in fit
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
File "/root/COAT/examples/OLMo/olmo/train.py", line 603, in save_checkpoint
result = self.save_sharded_checkpoint()
File "/root/COAT/examples/OLMo/olmo/train.py", line 511, in save_sharded_checkpoint
result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
File "/root/COAT/examples/OLMo/olmo/train.py", line 471, in _save_checkpoint
checkpointer.save_checkpoint(
File "/root/COAT/examples/OLMo/olmo/checkpoint.py", line 1006, in save_checkpoint
"optim": FSDP.optim_state_dict(dist_model, optim),
File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1890, in optim_state_dict
return FullyShardedDataParallel._optim_state_dict_impl(
File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1301, in _optim_state_dict_impl
return _optim_state_dict(
File "/root/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1976, in _optim_state_dict
fsdp_osd_state = convert_fn(
File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1799, in _convert_state_with_orig_params
_gather_all_orig_param_state(
File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1693, in _gather_all_orig_param_state
output_states = _allgather_orig_param_states(
File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1523, in _allgather_orig_param_states
dtype, state_buffers = _convert_all_state_info(
File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1382, in _convert_all_state_info
assert dtype == info.dtype
AssertionError
(The same AssertionError traceback was raised on the other ranks as well, at 10:05:05.926 on 0d171996b0c8:1 and at 10:05:05.960 on 0d171996b0c8:3.)
W0814 10:05:07.576000 622756 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 622846 closing signal SIGTERM
W0814 10:05:07.577000 622756 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 622847 closing signal SIGTERM
W0814 10:05:07.578000 622756 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 622849 closing signal SIGTERM
E0814 10:05:08.544000 622756 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 2 (pid: 622848) of binary: /opt/conda/envs/coat/bin/python3.10
Traceback (most recent call last):
File "/root/.local/bin/torchrun", line 7, in
sys.exit(main())
File "/root/.local/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/init.py", line 355, in wrapper
return f(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
run(args)
File "/root/.local/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
elastic_launch(
File "/root/.local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in call
return launch_agent(self._config, self._entrypoint, list(args))
File "/root/.local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

scripts/train.py FAILED

Failures:
<NO_OTHER_FAILURES>

Root Cause (first observed failure):
[0]:
time : 2025-08-14_10:05:07
host : 0d171996b0c8
rank : 2 (local_rank: 2)
exitcode : 1 (pid: 622848)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Is this expected? I would greatly appreciate it if anyone could help me with a fix.
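For what it's worth, the assertion that fails (assert dtype == info.dtype in torch's _convert_all_state_info) seems to require that the gathered optimizer-state tensors for a given parameter all share one dtype, so my guess is that some of the optimizer states end up in a different precision than the rest. Here is a minimal sketch of how one could check this before the checkpoint call; dump_optim_state_dtypes is a hypothetical helper of mine, and optim stands for the training optimizer built in scripts/train.py:

import torch

def dump_optim_state_dtypes(optim: torch.optim.Optimizer) -> None:
    # Print the dtype of every tensor held in the optimizer state so that
    # entries with mismatched precisions stand out.
    for group_idx, group in enumerate(optim.param_groups):
        for param in group["params"]:
            for key, value in optim.state.get(param, {}).items():
                if torch.is_tensor(value):
                    print(f"group={group_idx} key={key} "
                          f"shape={tuple(value.shape)} dtype={value.dtype}")

Calling this on each rank right before the FSDP.optim_state_dict(dist_model, optim) line in olmo/checkpoint.py (line 1006 in the traceback) should show which state entries disagree in dtype.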
