I was following the training script for OLMo 1B, and when the optimizer state was being saved I hit the following error:
... (some normal training output until Step 1000)
2025-08-14 10:05:05.871 0d171996b0c8:2 olmo.util:165 CRITICAL Uncaught AssertionError:
Traceback (most recent call last):
  File "/root/COAT/examples/OLMo/scripts/train.py", line 377, in <module>
    main(cfg)
  File "/root/COAT/examples/OLMo/scripts/train.py", line 346, in main
    trainer.fit()
  File "/root/COAT/examples/OLMo/olmo/train.py", line 1319, in fit
    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
  File "/root/COAT/examples/OLMo/olmo/train.py", line 603, in save_checkpoint
    result = self.save_sharded_checkpoint()
  File "/root/COAT/examples/OLMo/olmo/train.py", line 511, in save_sharded_checkpoint
    result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
  File "/root/COAT/examples/OLMo/olmo/train.py", line 471, in _save_checkpoint
    checkpointer.save_checkpoint(
  File "/root/COAT/examples/OLMo/olmo/checkpoint.py", line 1006, in save_checkpoint
    "optim": FSDP.optim_state_dict(dist_model, optim),
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1890, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1301, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/root/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1976, in _optim_state_dict
    fsdp_osd_state = convert_fn(
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1799, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1693, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1523, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1382, in _convert_all_state_info
    assert dtype == info.dtype
AssertionError
(Ranks 0d171996b0c8:1 and 0d171996b0c8:3 raised the identical AssertionError traceback at 10:05:05.926 and 10:05:05.960; the duplicates are omitted for brevity.)
W0814 10:05:07.576000 622756 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 622846 closing signal SIGTERM
W0814 10:05:07.577000 622756 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 622847 closing signal SIGTERM
W0814 10:05:07.578000 622756 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 622849 closing signal SIGTERM
E0814 10:05:08.544000 622756 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 2 (pid: 622848) of binary: /opt/conda/envs/coat/bin/python3.10
Traceback (most recent call last):
  File "/root/.local/bin/torchrun", line 7, in <module>
    sys.exit(main())
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/.local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
scripts/train.py FAILED
Failures:
<NO_OTHER_FAILURES>
Root Cause (first observed failure):
[0]:
time : 2025-08-14_10:05:07
host : 0d171996b0c8
rank : 2 (local_rank: 2)
exitcode : 1 (pid: 622848)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
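Since the root-cause entry above shows error_file: <N/A>, here is a minimal sketch of how the entrypoint of scripts/train.py could be wrapped with torch's record decorator, as the linked error-handling page describes, so the per-rank traceback gets written to an error file next time (the function body below is only a placeholder, not the real training code):

from torch.distributed.elastic.multiprocessing.errors import record

@record  # write any uncaught exception on this rank to an error file that torchrun then reports
def main() -> None:
    # the real training logic from scripts/train.py would go here;
    # this placeholder just simulates an uncaught failure
    raise AssertionError("placeholder failure")

if __name__ == "__main__":
    main()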
Is this expected? I would greatly appreciate it if anyone could help me with a fix.
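For context, the assertion in _convert_all_state_info appears to fire when the optimizer state tensors gathered for a parameter do not all share a single dtype, which made me suspect COAT's FP8 optimizer states. Below is a small diagnostic sketch (my own, not part of the repo) that dumps the dtype of every optimizer state tensor; calling it on the optim object right before the FSDP.optim_state_dict(dist_model, optim) line in olmo/checkpoint.py should show whether the states mix, for example, low-precision buffers with fp32 scales:

import torch

def dump_optim_state_dtypes(optim: torch.optim.Optimizer) -> None:
    # Print the dtype (and shape) of every tensor held in the optimizer state,
    # so mixed-dtype states are easy to spot before FSDP gathers them.
    for group_idx, group in enumerate(optim.param_groups):
        for param in group["params"]:
            state = optim.state.get(param, {})
            for key, value in state.items():
                if torch.is_tensor(value):
                    print(
                        f"group {group_idx}, state '{key}': "
                        f"dtype={value.dtype}, shape={tuple(value.shape)}"
                    )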