Describe the bug
My training run crashes during model initialization when I use SwiGLU together with Transformer Engine. I do not observe this issue when using the default GELU activation or when I set te_layernorm_mlp to false.
NeoXArgs.configure_distributed_args() using world size: 8 and model-parallel size: 1
> building HFTokenizer tokenizer ...
> padded vocab (size: 50277) with 27 dummy tokens (new size: 50304)
Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, or directly from https://github.com/state-spaces/mamba
For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3
For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer
WARNING:root:Outstanding DeepSpeed issue means that pp>0, zero1, and bf16 will break without fp32 grads
> setting up tensorboard ...
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
Loading extension module scaled_upper_triang_masked_softmax_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /workspace/gpt-neox/megatron/fused_kernels/build/build.ninja...
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
Building extension module scaled_masked_softmax_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module scaled_masked_softmax_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /workspace/gpt-neox/megatron/fused_kernels/build/build.ninja...
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
Building extension module fused_rotary_positional_embedding...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_rotary_positional_embedding...
> initializing torch distributed ...
[2025-03-13 02:31:25,996] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:25,996] [INFO] [comm.py:689:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
[2025-03-13 02:31:26,691] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:26,829] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:26,998] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,011] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,022] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,023] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,029] [INFO] [comm.py:658:init_distributed] cdb=None
> initializing model parallel with size 1
MPU DP: [0, 1, 2, 3, 4, 5, 6, 7]
MPU PP: [0]
MPU PP: [1]
MPU PP: [2]
MPU PP: [3]
MPU PP: [4]
MPU PP: [5]
MPU PP: [6]
MPU PP: [7]
MPU MP: [0]
MPU MP: [1]
MPU MP: [2]
MPU MP: [3]
MPU MP: [4]
MPU MP: [5]
MPU MP: [6]
MPU MP: [7]
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1125:configure] Activation Checkpointing Information
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1126:configure] ----Partition Activations True, CPU CHECKPOINTING False
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1127:configure] ----contiguous Memory Checkpointing False with 32 total layers
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1128:configure] ----Synchronization True
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1129:configure] ----Profiling time in checkpointing False
> setting random seeds to 1234 ...
[2025-03-13 02:31:27,042] [INFO] [checkpointing.py:229:model_parallel_cuda_manual_seed] > initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
make: Entering directory '/workspace/gpt-neox/megatron/data'
g++ -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color -I/usr/include/python3.12 -I/usr/local/lib/python3.12/dist-packages/pybind11/include helpers.cpp -o helpers.cpython-312-x86_64-linux-gnu.so
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
make: Leaving directory '/workspace/gpt-neox/megatron/data'
> building train, validation, and test datasets ...
reading sizes...
reading pointers...
reading document index...
creating numpy buffer of mmap...
creating memory view of numpy buffer...
> dataset split:
train:
document indices in [0, 409935486) total of 409935486 documents
validation:
document indices in [409935486, 409935486) total of 0 documents
test:
document indices in [409935486, 409935486) total of 0 documents
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
counts = torch.cuda.LongTensor([1])
> loading doc-idx mapping from /data/pretraining-mix/pretraining-mix_text_document_train_indexmap_61035008ns_2048sl_1234s_packedpi_ac_doc_idx.npy
> loading sample-idx mapping from /data/pretraining-mix/pretraining-mix_text_document_train_indexmap_61035008ns_2048sl_1234s_packedpi_ac_sample_idx.npy
> loading shuffle-idx mapping from /data/pretraining-mix/pretraining-mix_text_document_train_indexmap_61035008ns_2048sl_1234s_packedpi_ac_shuffle_idx.npy
loaded indexed file in 0.009 seconds
total number of samples: 244340789
total number of epochs: 1
building GPT2 model ...
SEED_LAYERS=False BASE_SEED=1234 SEED_FN=None
Using topology: {ProcessCoord(pipe=0, data=0, model=0): 0, ProcessCoord(pipe=0, data=1, model=0): 1, ProcessCoord(pipe=0, data=2, model=0): 2, ProcessCoord(pipe=0, data=3, model=0): 3, ProcessCoord(pipe=0, data=4, model=0): 4, ProcessCoord(pipe=0, data=5, model=0): 5, ProcessCoord(pipe=0, data=6, model=0): 6, ProcessCoord(pipe=0, data=7, model=0): 7}
[2025-03-13 02:31:47,341] [INFO] [module.py:398:_partition_layers] Partitioning pipeline stages with method type:transformer|mlp
stage=0 layers=37
0: EmbeddingPipe
1: _pre_transformer_block
2: ParallelTransformerLayerPipe
3: ParallelTransformerLayerPipe
4: ParallelTransformerLayerPipe
5: ParallelTransformerLayerPipe
6: ParallelTransformerLayerPipe
7: ParallelTransformerLayerPipe
8: ParallelTransformerLayerPipe
9: ParallelTransformerLayerPipe
10: ParallelTransformerLayerPipe
11: ParallelTransformerLayerPipe
12: ParallelTransformerLayerPipe
13: ParallelTransformerLayerPipe
14: ParallelTransformerLayerPipe
15: ParallelTransformerLayerPipe
16: ParallelTransformerLayerPipe
17: ParallelTransformerLayerPipe
18: ParallelTransformerLayerPipe
19: ParallelTransformerLayerPipe
20: ParallelTransformerLayerPipe
21: ParallelTransformerLayerPipe
22: ParallelTransformerLayerPipe
23: ParallelTransformerLayerPipe
24: ParallelTransformerLayerPipe
25: ParallelTransformerLayerPipe
26: ParallelTransformerLayerPipe
27: ParallelTransformerLayerPipe
28: ParallelTransformerLayerPipe
29: ParallelTransformerLayerPipe
30: ParallelTransformerLayerPipe
31: ParallelTransformerLayerPipe
32: ParallelTransformerLayerPipe
33: ParallelTransformerLayerPipe
34: _post_transformer_block
35: NormPipe
36: ParallelLinearPipe
loss: partial
[rank1]: Traceback (most recent call last):
[rank1]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank1]: main()
[rank1]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank1]: pretrain(neox_args=neox_args)
[rank1]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank1]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank1]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank1]: model = GPT2ModelPipe(
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank1]: super().__init__(
[rank1]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank1]: self._build()
[rank1]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank1]: module = layer.build()
[rank1]: ^^^^^^^^^^^^^
[rank1]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank1]: return self.typename(*self.module_args, **self.module_kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank1]: self.mlp = get_te_lnmlp()
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank1]: return TELayerNormMLP(
[rank1]: ^^^^^^^^^^^^^^^
[rank1]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank1]: self.activation_func = Gated_Activation(self.activation_func)
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank1]: super().__setattr__(name, value)
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank1]: raise AttributeError(
[rank1]: AttributeError: cannot assign module before Module.__init__() call
[rank3]: Traceback (most recent call last):
[rank3]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank3]: main()
[rank3]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank3]: pretrain(neox_args=neox_args)
[rank3]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank3]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank3]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank3]: model = GPT2ModelPipe(
[rank3]: ^^^^^^^^^^^^^^
[rank3]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank3]: super().__init__(
[rank3]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank3]: self._build()
[rank3]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank3]: module = layer.build()
[rank3]: ^^^^^^^^^^^^^
[rank3]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank3]: return self.typename(*self.module_args, **self.module_kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank3]: self.mlp = get_te_lnmlp()
[rank3]: ^^^^^^^^^^^^^^
[rank3]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank3]: return TELayerNormMLP(
[rank3]: ^^^^^^^^^^^^^^^
[rank3]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank3]: self.activation_func = Gated_Activation(self.activation_func)
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank3]: super().__setattr__(name, value)
[rank3]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank3]: raise AttributeError(
[rank3]: AttributeError: cannot assign module before Module.__init__() call
[rank6]: Traceback (most recent call last):
[rank6]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank6]: main()
[rank6]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank6]: pretrain(neox_args=neox_args)
[rank6]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank6]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank6]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank6]: model = GPT2ModelPipe(
[rank6]: ^^^^^^^^^^^^^^
[rank6]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank6]: super().__init__(
[rank6]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank6]: self._build()
[rank6]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank6]: module = layer.build()
[rank6]: ^^^^^^^^^^^^^
[rank6]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank6]: return self.typename(*self.module_args, **self.module_kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank6]: self.mlp = get_te_lnmlp()
[rank6]: ^^^^^^^^^^^^^^
[rank6]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank6]: return TELayerNormMLP(
[rank6]: ^^^^^^^^^^^^^^^
[rank6]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank6]: self.activation_func = Gated_Activation(self.activation_func)
[rank6]: ^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank6]: super().__setattr__(name, value)
[rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank6]: raise AttributeError(
[rank6]: AttributeError: cannot assign module before Module.__init__() call
[rank5]: Traceback (most recent call last):
[rank5]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank5]: main()
[rank5]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank5]: pretrain(neox_args=neox_args)
[rank5]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank5]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank5]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank5]: model = GPT2ModelPipe(
[rank5]: ^^^^^^^^^^^^^^
[rank5]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank5]: super().__init__(
[rank5]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank5]: self._build()
[rank5]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank5]: module = layer.build()
[rank5]: ^^^^^^^^^^^^^
[rank5]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank5]: return self.typename(*self.module_args, **self.module_kwargs)
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank5]: self.mlp = get_te_lnmlp()
[rank5]: ^^^^^^^^^^^^^^
[rank5]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank5]: return TELayerNormMLP(
[rank5]: ^^^^^^^^^^^^^^^
[rank5]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank5]: self.activation_func = Gated_Activation(self.activation_func)
[rank5]: ^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank5]: super().__setattr__(name, value)
[rank5]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank5]: raise AttributeError(
[rank5]: AttributeError: cannot assign module before Module.__init__() call
[rank4]: Traceback (most recent call last):
[rank4]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank4]: main()
[rank4]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank4]: pretrain(neox_args=neox_args)
[rank4]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank4]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank4]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank4]: model = GPT2ModelPipe(
[rank4]: ^^^^^^^^^^^^^^
[rank4]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank4]: super().__init__(
[rank4]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank4]: self._build()
[rank4]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank4]: module = layer.build()
[rank4]: ^^^^^^^^^^^^^
[rank4]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank4]: return self.typename(*self.module_args, **self.module_kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank4]: self.mlp = get_te_lnmlp()
[rank4]: ^^^^^^^^^^^^^^
[rank4]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank4]: return TELayerNormMLP(
[rank4]: ^^^^^^^^^^^^^^^
[rank4]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank4]: self.activation_func = Gated_Activation(self.activation_func)
[rank4]: ^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank4]: super().__setattr__(name, value)
[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank4]: raise AttributeError(
[rank4]: AttributeError: cannot assign module before Module.__init__() call
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank0]: main()
[rank0]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank0]: pretrain(neox_args=neox_args)
[rank0]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank0]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank0]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank0]: model = GPT2ModelPipe(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank0]: super().__init__(
[rank0]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank0]: self._build()
[rank0]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank0]: module = layer.build()
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank0]: return self.typename(*self.module_args, **self.module_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank0]: self.mlp = get_te_lnmlp()
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank0]: return TELayerNormMLP(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank0]: self.activation_func = Gated_Activation(self.activation_func)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank0]: super().__setattr__(name, value)
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank0]: raise AttributeError(
[rank0]: AttributeError: cannot assign module before Module.__init__() call
[rank7]: Traceback (most recent call last):
[rank7]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank7]: main()
[rank7]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank7]: pretrain(neox_args=neox_args)
[rank7]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank7]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank7]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank7]: model = GPT2ModelPipe(
[rank7]: ^^^^^^^^^^^^^^
[rank7]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank7]: super().__init__(
[rank7]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank7]: self._build()
[rank7]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank7]: module = layer.build()
[rank7]: ^^^^^^^^^^^^^
[rank7]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank7]: return self.typename(*self.module_args, **self.module_kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank7]: self.mlp = get_te_lnmlp()
[rank7]: ^^^^^^^^^^^^^^
[rank7]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank7]: return TELayerNormMLP(
[rank7]: ^^^^^^^^^^^^^^^
[rank7]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank7]: self.activation_func = Gated_Activation(self.activation_func)
[rank7]: ^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank7]: super().__setattr__(name, value)
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank7]: raise AttributeError(
[rank7]: AttributeError: cannot assign module before Module.__init__() call
[rank2]: Traceback (most recent call last):
[rank2]: File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank2]: main()
[rank2]: File "/workspace/gpt-neox/train.py", line 31, in main
[rank2]: pretrain(neox_args=neox_args)
[rank2]: File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank2]: model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank2]: model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank2]: model = GPT2ModelPipe(
[rank2]: ^^^^^^^^^^^^^^
[rank2]: File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank2]: super().__init__(
[rank2]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank2]: self._build()
[rank2]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank2]: module = layer.build()
[rank2]: ^^^^^^^^^^^^^
[rank2]: File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank2]: return self.typename(*self.module_args, **self.module_kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank2]: self.mlp = get_te_lnmlp()
[rank2]: ^^^^^^^^^^^^^^
[rank2]: File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank2]: return TELayerNormMLP(
[rank2]: ^^^^^^^^^^^^^^^
[rank2]: File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank2]: self.activation_func = Gated_Activation(self.activation_func)
[rank2]: ^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank2]: super().__setattr__(name, value)
[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank2]: raise AttributeError(
[rank2]: AttributeError: cannot assign module before Module.__init__() call
[rank3]:[W313 02:31:55.225809539 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank4]:[W313 02:31:55.475404132 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W313 02:31:55.557783357 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W313 02:31:55.805672537 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W313 02:31:55.823858162 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank5]:[W313 02:31:55.828077019 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W313 02:31:55.923131940 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank7]:[W313 02:31:55.017956611 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[2025-03-13 02:31:56,920] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 443
[2025-03-13 02:31:57,175] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 444
[2025-03-13 02:31:57,430] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 445
[2025-03-13 02:31:57,885] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 446
[2025-03-13 02:31:57,885] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 447
[2025-03-13 02:31:57,886] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 448
[2025-03-13 02:31:58,299] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 449
[2025-03-13 02:31:58,299] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 450
To Reproduce
Steps to reproduce the behavior:
The issue should be reproducible with the config I provide below.
Expected behavior
Training to begin as normal.
Proposed solution
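The traceback ends in `torch/nn/modules/module.py` with `cannot assign module before Module.__init__() call`, raised from the assignment `self.activation_func = Gated_Activation(self.activation_func)` in `TELayerNormMLP.__init__`. PyTorch raises this when an `nn.Module` instance is assigned as an attribute before the parent `__init__` has run. The sketch below reproduces that behaviour outside GPT-NeoX and Transformer Engine. It assumes `Gated_Activation` is an `nn.Module` subclass (the error message suggests so); `GatedActivation`, `BrokenMLP`, `FixedMLP`, and `try_build` are hypothetical names for illustration, not the actual GPT-NeoX code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedActivation(nn.Module):
    """Hypothetical stand-in for the Gated_Activation wrapper in the traceback."""

    def __init__(self, activation_func):
        super().__init__()
        self.activation_func = activation_func

    def forward(self, x):
        # Split the last dimension in half and gate one half with the other.
        x, gate = x.chunk(2, dim=-1)
        return x * self.activation_func(gate)


class BrokenMLP(nn.Module):
    def __init__(self):
        # An nn.Module is assigned before Module.__init__() has run, so
        # nn.Module.__setattr__ raises AttributeError.
        self.activation_func = GatedActivation(F.silu)
        super().__init__()


class FixedMLP(nn.Module):
    def __init__(self):
        # Initialise the parent first, then register the submodule.
        super().__init__()
        self.activation_func = GatedActivation(F.silu)


def try_build(cls):
    """Return the AttributeError message raised by cls(), or None on success."""
    try:
        cls()
        return None
    except AttributeError as err:
        return str(err)


if __name__ == "__main__":
    print("BrokenMLP:", try_build(BrokenMLP))
    print("FixedMLP:", try_build(FixedMLP))
```

If the same ordering applies in `megatron/model/transformer_engine.py`, moving the `Gated_Activation` wrapping to after the `super().__init__(...)` call may avoid the crash. A plain callable can be assigned before `Module.__init__()` without error, which could explain why the GELU default does not hit this path. Neither point has been verified against the GPT-NeoX source.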
Environment (please complete the following information):
- GPUs: 8xH100
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 |
| N/A 24C P0 68W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 |
| N/A 23C P0 68W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 |
| N/A 23C P0 69W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 |
| N/A 23C P0 67W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 |
| N/A 25C P0 69W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 |
| N/A 22C P0 68W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 |
| N/A 25C P0 71W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 |
| N/A 22C P0 69W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
- Configs:
{
# Tokens
"data_path": "/data/pretraining-mix/pretraining-mix_text_document",
"vocab_file": "/data/neox_tokenizer/tokenizer.json",
"tokenizer_type": "HFTokenizer",
"data_impl": "mmap",
# Logging
"checkpoint_validation_with_forward_pass": False,
"tensorboard_dir": "tensorboard",
"log_dir": "logs",
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
"use_wandb": True,
"wandb_host": "https://api.wandb.ai",
"wandb_project": "AISI",
"wandb_team": "eleutherai",
"wandb_run_name": "pretraining_baseline",
# Distributed Training
"hostfile": "/workspace/hostfile",
"deepspeed_mpi": True,
"launcher": "openmpi",
"deepspeed_extra_args": { "ssh_port": 2222 },
"pipe_parallel_size": 1,
"model_parallel_size": 1,
# Training Duration
# 500B (tokens) / (1 (grad acc) * 32 (world size) * 32 (micro batch size) * 2048 (seq length))
"train_iters": 238419,
"lr_decay_iters": 238419,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"eval_interval": 238419,
"eval_iters": 0,
"split": "100,0,0",
# Architecture
"num_layers": 32,
"hidden_size": 4096,
"num_attention_heads": 32,
"seq_length": 2048,
"max_position_embeddings": 2048,
"norm": "layernorm",
"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,
"output_layer_parallelism": "column",
"attention_config": [[["flash"], 32]],
"scaled_upper_triang_masked_softmax_fusion": true,
"activation": "swiglu",
"precision": "bfloat16",
# Transformer Engine
"te_columnparallel": false,
"te_rowparallel": false,
"te_layernorm_mlp": true,
"te_mha": true,
"te_fp8_format": "hybrid",
"te_fp8_wgrad": false,
"te_fp8_amax_history_len": 1,
"te_fp8_amax_compute_algo": "most_recent",
"te_fp8_margin": 0,
"te_fp8_mha": false,
# Optimization
# 0.0003 is OLMo 2's peak learning rate
"optimizer":
{
"type": "Adam",
"params": { "lr": 0.0003, "betas": [0.9, 0.95], "eps": 1.0e-8 },
},
"min_lr": 0.000012,
"zero_optimization":
{
"stage": 1,
"allgather_partitions": true,
"allgather_bucket_size": 1260000000,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1260000000,
"contiguous_gradients": true,
"cpu_offload": false,
},
"train_micro_batch_size_per_gpu": 32,
"gradient_accumulation_steps": 1,
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,
# Checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,
"checkpoint_factor": 1000,
"save": "/checkpoints/pretraining_baseline",
"load": "/checkpoints/pretraining_baseline",
}
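The `train_iters` value in the config follows from the comment under "Training Duration": a 500B-token budget divided by the tokens consumed per iteration. A small sketch of that arithmetic, with a hypothetical helper name `train_iters_for` and the values taken from the config comment:

```python
def train_iters_for(token_budget, grad_acc, world_size, micro_batch, seq_length):
    """Number of iterations needed to consume token_budget tokens (rounded up)."""
    tokens_per_iter = grad_acc * world_size * micro_batch * seq_length
    # Ceiling division using integers only.
    return -(-token_budget // tokens_per_iter)


# Values from the config comment: 1 grad acc, world size 32,
# micro batch size 32, sequence length 2048.
iters = train_iters_for(500_000_000_000, 1, 32, 32, 2048)

if __name__ == "__main__":
    print(iters)
```

Note that the comment assumes a world size of 32, while the log at the top of this report shows world size 8, so the number of tokens per iteration in the failing run differs from the one used to derive `train_iters`.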
Additional context