
Unable to Use SwiGLU with Transformer Engine #1347

Open
@Kyle1668

Description


Describe the bug

My training run crashes during initialization when I attempt to use SwiGLU with Transformer Engine. I do not observe this issue when using the default GELU activation or when I set te_layernorm_mlp to false. The full log is below; a minimal sketch of the root cause follows it.

NeoXArgs.configure_distributed_args() using world size: 8 and model-parallel size: 1
> building HFTokenizer tokenizer ...
 > padded vocab (size: 50277) with 27 dummy tokens (new size: 50304)
Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt,     or directly from https://github.com/state-spaces/mamba
For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3
For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer
WARNING:root:Outstanding DeepSpeed issue means that pp>0, zero1, and bf16 will break without fp32 grads
> setting up tensorboard ...
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Loading extension module scaled_upper_triang_masked_softmax_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /workspace/gpt-neox/megatron/fused_kernels/build/build.ninja...
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module scaled_masked_softmax_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module scaled_masked_softmax_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /workspace/gpt-neox/megatron/fused_kernels/build/build.ninja...
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module fused_rotary_positional_embedding...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_rotary_positional_embedding...
> initializing torch distributed ...
[2025-03-13 02:31:25,996] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:25,996] [INFO] [comm.py:689:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2070: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
[2025-03-13 02:31:26,691] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:26,829] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:26,998] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,011] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,022] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,023] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-13 02:31:27,029] [INFO] [comm.py:658:init_distributed] cdb=None
> initializing model parallel with size 1
MPU DP: [0, 1, 2, 3, 4, 5, 6, 7]
MPU PP: [0]
MPU PP: [1]
MPU PP: [2]
MPU PP: [3]
MPU PP: [4]
MPU PP: [5]
MPU PP: [6]
MPU PP: [7]
MPU MP: [0]
MPU MP: [1]
MPU MP: [2]
MPU MP: [3]
MPU MP: [4]
MPU MP: [5]
MPU MP: [6]
MPU MP: [7]
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1125:configure] Activation Checkpointing Information
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1126:configure] ----Partition Activations True, CPU CHECKPOINTING False
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1127:configure] ----contiguous Memory Checkpointing False with 32 total layers
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1128:configure] ----Synchronization True
[2025-03-13 02:31:27,040] [INFO] [checkpointing.py:1129:configure] ----Profiling time in checkpointing False
> setting random seeds to 1234 ...
[2025-03-13 02:31:27,042] [INFO] [checkpointing.py:229:model_parallel_cuda_manual_seed] > initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
make: Entering directory '/workspace/gpt-neox/megatron/data'
g++ -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color -I/usr/include/python3.12 -I/usr/local/lib/python3.12/dist-packages/pybind11/include helpers.cpp -o helpers.cpython-312-x86_64-linux-gnu.so
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
make: Leaving directory '/workspace/gpt-neox/megatron/data'
> building train, validation, and test datasets ...
    reading sizes...
    reading pointers...
    reading document index...
    creating numpy buffer of mmap...
    creating memory view of numpy buffer...
 > dataset split:
    train:
     document indices in [0, 409935486) total of 409935486 documents
    validation:
     document indices in [409935486, 409935486) total of 0 documents
    test:
     document indices in [409935486, 409935486) total of 0 documents
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
/workspace/gpt-neox/megatron/data/gpt2_dataset.py:373: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:78.)
  counts = torch.cuda.LongTensor([1])
 > loading doc-idx mapping from /data/pretraining-mix/pretraining-mix_text_document_train_indexmap_61035008ns_2048sl_1234s_packedpi_ac_doc_idx.npy
 > loading sample-idx mapping from /data/pretraining-mix/pretraining-mix_text_document_train_indexmap_61035008ns_2048sl_1234s_packedpi_ac_sample_idx.npy
 > loading shuffle-idx mapping from /data/pretraining-mix/pretraining-mix_text_document_train_indexmap_61035008ns_2048sl_1234s_packedpi_ac_shuffle_idx.npy
    loaded indexed file in 0.009 seconds
    total number of samples: 244340789
    total number of epochs: 1
building GPT2 model ...
SEED_LAYERS=False BASE_SEED=1234 SEED_FN=None
Using topology: {ProcessCoord(pipe=0, data=0, model=0): 0, ProcessCoord(pipe=0, data=1, model=0): 1, ProcessCoord(pipe=0, data=2, model=0): 2, ProcessCoord(pipe=0, data=3, model=0): 3, ProcessCoord(pipe=0, data=4, model=0): 4, ProcessCoord(pipe=0, data=5, model=0): 5, ProcessCoord(pipe=0, data=6, model=0): 6, ProcessCoord(pipe=0, data=7, model=0): 7}
[2025-03-13 02:31:47,341] [INFO] [module.py:398:_partition_layers] Partitioning pipeline stages with method type:transformer|mlp
stage=0 layers=37
     0: EmbeddingPipe
     1: _pre_transformer_block
     2: ParallelTransformerLayerPipe
     3: ParallelTransformerLayerPipe
     4: ParallelTransformerLayerPipe
     5: ParallelTransformerLayerPipe
     6: ParallelTransformerLayerPipe
     7: ParallelTransformerLayerPipe
     8: ParallelTransformerLayerPipe
     9: ParallelTransformerLayerPipe
    10: ParallelTransformerLayerPipe
    11: ParallelTransformerLayerPipe
    12: ParallelTransformerLayerPipe
    13: ParallelTransformerLayerPipe
    14: ParallelTransformerLayerPipe
    15: ParallelTransformerLayerPipe
    16: ParallelTransformerLayerPipe
    17: ParallelTransformerLayerPipe
    18: ParallelTransformerLayerPipe
    19: ParallelTransformerLayerPipe
    20: ParallelTransformerLayerPipe
    21: ParallelTransformerLayerPipe
    22: ParallelTransformerLayerPipe
    23: ParallelTransformerLayerPipe
    24: ParallelTransformerLayerPipe
    25: ParallelTransformerLayerPipe
    26: ParallelTransformerLayerPipe
    27: ParallelTransformerLayerPipe
    28: ParallelTransformerLayerPipe
    29: ParallelTransformerLayerPipe
    30: ParallelTransformerLayerPipe
    31: ParallelTransformerLayerPipe
    32: ParallelTransformerLayerPipe
    33: ParallelTransformerLayerPipe
    34: _post_transformer_block
    35: NormPipe
    36: ParallelLinearPipe
  loss: partial
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank1]:     main()
[rank1]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank1]:     pretrain(neox_args=neox_args)
[rank1]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank1]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank1]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank1]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank1]:     model = GPT2ModelPipe(
[rank1]:             ^^^^^^^^^^^^^^
[rank1]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank1]:     super().__init__(
[rank1]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank1]:     self._build()
[rank1]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank1]:     module = layer.build()
[rank1]:              ^^^^^^^^^^^^^
[rank1]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank1]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank1]:     self.mlp = get_te_lnmlp()
[rank1]:                ^^^^^^^^^^^^^^
[rank1]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank1]:     return TELayerNormMLP(
[rank1]:            ^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank1]:     self.activation_func = Gated_Activation(self.activation_func)
[rank1]:     ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank1]:     super().__setattr__(name, value)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank1]:     raise AttributeError(
[rank1]: AttributeError: cannot assign module before Module.__init__() call
[rank3]: Traceback (most recent call last):
[rank3]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank3]:     main()
[rank3]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank3]:     pretrain(neox_args=neox_args)
[rank3]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank3]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank3]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank3]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank3]:     model = GPT2ModelPipe(
[rank3]:             ^^^^^^^^^^^^^^
[rank3]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank3]:     super().__init__(
[rank3]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank3]:     self._build()
[rank3]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank3]:     module = layer.build()
[rank3]:              ^^^^^^^^^^^^^
[rank3]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank3]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank3]:     self.mlp = get_te_lnmlp()
[rank3]:                ^^^^^^^^^^^^^^
[rank3]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank3]:     return TELayerNormMLP(
[rank3]:            ^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank3]:     self.activation_func = Gated_Activation(self.activation_func)
[rank3]:     ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank3]:     super().__setattr__(name, value)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank3]:     raise AttributeError(
[rank3]: AttributeError: cannot assign module before Module.__init__() call
[rank6]: Traceback (most recent call last):
[rank6]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank6]:     main()
[rank6]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank6]:     pretrain(neox_args=neox_args)
[rank6]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank6]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank6]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank6]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank6]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank6]:     model = GPT2ModelPipe(
[rank6]:             ^^^^^^^^^^^^^^
[rank6]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank6]:     super().__init__(
[rank6]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank6]:     self._build()
[rank6]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank6]:     module = layer.build()
[rank6]:              ^^^^^^^^^^^^^
[rank6]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank6]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank6]:     self.mlp = get_te_lnmlp()
[rank6]:                ^^^^^^^^^^^^^^
[rank6]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank6]:     return TELayerNormMLP(
[rank6]:            ^^^^^^^^^^^^^^^
[rank6]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank6]:     self.activation_func = Gated_Activation(self.activation_func)
[rank6]:     ^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank6]:     super().__setattr__(name, value)
[rank6]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank6]:     raise AttributeError(
[rank6]: AttributeError: cannot assign module before Module.__init__() call
[rank5]: Traceback (most recent call last):
[rank5]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank5]:     main()
[rank5]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank5]:     pretrain(neox_args=neox_args)
[rank5]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank5]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank5]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank5]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank5]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank5]:     model = GPT2ModelPipe(
[rank5]:             ^^^^^^^^^^^^^^
[rank5]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank5]:     super().__init__(
[rank5]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank5]:     self._build()
[rank5]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank5]:     module = layer.build()
[rank5]:              ^^^^^^^^^^^^^
[rank5]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank5]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank5]:     self.mlp = get_te_lnmlp()
[rank5]:                ^^^^^^^^^^^^^^
[rank5]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank5]:     return TELayerNormMLP(
[rank5]:            ^^^^^^^^^^^^^^^
[rank5]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank5]:     self.activation_func = Gated_Activation(self.activation_func)
[rank5]:     ^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank5]:     super().__setattr__(name, value)
[rank5]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank5]:     raise AttributeError(
[rank5]: AttributeError: cannot assign module before Module.__init__() call
[rank4]: Traceback (most recent call last):
[rank4]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank4]:     main()
[rank4]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank4]:     pretrain(neox_args=neox_args)
[rank4]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank4]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank4]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank4]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank4]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank4]:     model = GPT2ModelPipe(
[rank4]:             ^^^^^^^^^^^^^^
[rank4]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank4]:     super().__init__(
[rank4]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank4]:     self._build()
[rank4]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank4]:     module = layer.build()
[rank4]:              ^^^^^^^^^^^^^
[rank4]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank4]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank4]:     self.mlp = get_te_lnmlp()
[rank4]:                ^^^^^^^^^^^^^^
[rank4]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank4]:     return TELayerNormMLP(
[rank4]:            ^^^^^^^^^^^^^^^
[rank4]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank4]:     self.activation_func = Gated_Activation(self.activation_func)
[rank4]:     ^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank4]:     super().__setattr__(name, value)
[rank4]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank4]:     raise AttributeError(
[rank4]: AttributeError: cannot assign module before Module.__init__() call
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank0]:     main()
[rank0]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank0]:     pretrain(neox_args=neox_args)
[rank0]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank0]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank0]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank0]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank0]:     model = GPT2ModelPipe(
[rank0]:             ^^^^^^^^^^^^^^
[rank0]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank0]:     super().__init__(
[rank0]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank0]:     self._build()
[rank0]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank0]:     module = layer.build()
[rank0]:              ^^^^^^^^^^^^^
[rank0]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank0]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank0]:     self.mlp = get_te_lnmlp()
[rank0]:                ^^^^^^^^^^^^^^
[rank0]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank0]:     return TELayerNormMLP(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank0]:     self.activation_func = Gated_Activation(self.activation_func)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank0]:     super().__setattr__(name, value)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank0]:     raise AttributeError(
[rank0]: AttributeError: cannot assign module before Module.__init__() call
[rank7]: Traceback (most recent call last):
[rank7]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank7]:     main()
[rank7]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank7]:     pretrain(neox_args=neox_args)
[rank7]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank7]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank7]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank7]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank7]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank7]:     model = GPT2ModelPipe(
[rank7]:             ^^^^^^^^^^^^^^
[rank7]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank7]:     super().__init__(
[rank7]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank7]:     self._build()
[rank7]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank7]:     module = layer.build()
[rank7]:              ^^^^^^^^^^^^^
[rank7]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank7]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank7]:     self.mlp = get_te_lnmlp()
[rank7]:                ^^^^^^^^^^^^^^
[rank7]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank7]:     return TELayerNormMLP(
[rank7]:            ^^^^^^^^^^^^^^^
[rank7]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank7]:     self.activation_func = Gated_Activation(self.activation_func)
[rank7]:     ^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank7]:     super().__setattr__(name, value)
[rank7]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank7]:     raise AttributeError(
[rank7]: AttributeError: cannot assign module before Module.__init__() call
[rank2]: Traceback (most recent call last):
[rank2]:   File "/workspace/gpt-neox/train.py", line 35, in <module>
[rank2]:     main()
[rank2]:   File "/workspace/gpt-neox/train.py", line 31, in main
[rank2]:     pretrain(neox_args=neox_args)
[rank2]:   File "/workspace/gpt-neox/megatron/training.py", line 252, in pretrain
[rank2]:     model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
[rank2]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/gpt-neox/megatron/training.py", line 1256, in setup_model_and_optimizer
[rank2]:     model = get_model(neox_args=neox_args, use_cache=use_cache)
[rank2]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/gpt-neox/megatron/training.py", line 978, in get_model
[rank2]:     model = GPT2ModelPipe(
[rank2]:             ^^^^^^^^^^^^^^
[rank2]:   File "/workspace/gpt-neox/megatron/model/gpt2_model.py", line 131, in __init__
[rank2]:     super().__init__(
[rank2]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 214, in __init__
[rank2]:     self._build()
[rank2]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank2]:     module = layer.build()
[rank2]:              ^^^^^^^^^^^^^
[rank2]:   File "/workspace/DeeperSpeed/deepspeed/runtime/pipe/module.py", line 74, in build
[rank2]:     return self.typename(*self.module_args, **self.module_kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1073, in __init__
[rank2]:     self.mlp = get_te_lnmlp()
[rank2]:                ^^^^^^^^^^^^^^
[rank2]:   File "/workspace/gpt-neox/megatron/model/transformer.py", line 1056, in get_te_lnmlp
[rank2]:     return TELayerNormMLP(
[rank2]:            ^^^^^^^^^^^^^^^
[rank2]:   File "/workspace/gpt-neox/megatron/model/transformer_engine.py", line 198, in __init__
[rank2]:     self.activation_func = Gated_Activation(self.activation_func)
[rank2]:     ^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 486, in __setattr__
[rank2]:     super().__setattr__(name, value)
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1965, in __setattr__
[rank2]:     raise AttributeError(
[rank2]: AttributeError: cannot assign module before Module.__init__() call
[rank3]:[W313 02:31:55.225809539 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank4]:[W313 02:31:55.475404132 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W313 02:31:55.557783357 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W313 02:31:55.805672537 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W313 02:31:55.823858162 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank5]:[W313 02:31:55.828077019 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W313 02:31:55.923131940 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank7]:[W313 02:31:55.017956611 ProcessGroupNCCL.cpp:1487] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[2025-03-13 02:31:56,920] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 443
[2025-03-13 02:31:57,175] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 444
[2025-03-13 02:31:57,430] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 445
[2025-03-13 02:31:57,885] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 446
[2025-03-13 02:31:57,885] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 447
[2025-03-13 02:31:57,886] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 448
[2025-03-13 02:31:58,299] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 449
[2025-03-13 02:31:58,299] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 450
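
Every rank fails with the same AttributeError from torch.nn.Module.__setattr__: the NeoX TELayerNormMLP wrapper assigns a sub-module to self.activation_func before the parent Module.__init__() has run, so the registries that __setattr__ needs do not exist yet. A minimal sketch of that PyTorch behavior (illustrative code, not GPT-NeoX source):

import torch.nn as nn

class GatedActivation(nn.Module):  # stand-in for NeoX's Gated_Activation
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

class BrokenMLP(nn.Module):
    def __init__(self):
        # Assigning a sub-module before Module.__init__() has created the
        # internal _modules registry reproduces the error in the traceback.
        self.activation_func = GatedActivation(nn.GELU())
        super().__init__()

BrokenMLP()  # AttributeError: cannot assign module before Module.__init__() call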

To Reproduce
Steps to reproduce the behavior:

One should be able to reproduce this issue with the config I provide below.
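
For reference, the run is launched with the standard GPT-NeoX entry point, roughly as follows (the config path is a placeholder for my local file):

python ./deepy.py train.py /path/to/pretraining_baseline.yml

The crash happens during model construction; switching "activation" back to "gelu" or setting "te_layernorm_mlp" to false avoids it.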

Expected behavior

Training should begin as normal.

Proposed solution
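
A possible fix (untested sketch; the classes below are stand-ins, not the actual code in megatron/model/transformer_engine.py) is to defer wrapping the activation in Gated_Activation until after the Transformer Engine parent constructor has run, since nn.Module only allows sub-module assignment once its own __init__ has executed:

import torch.nn as nn

class GatedActivation(nn.Module):          # stand-in for Gated_Activation
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

class ParentLayerNormMLP(nn.Module):       # stand-in for te.pytorch.LayerNormMLP
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)

class TELayerNormMLP(ParentLayerNormMLP):
    def __init__(self, hidden_size, ffn_hidden_size, activation_func, gated=False):
        # Run the parent constructor (and therefore nn.Module.__init__) first ...
        super().__init__(hidden_size, ffn_hidden_size)
        # ... then it is safe to register the gated activation as a sub-module.
        self.activation_func = GatedActivation(activation_func) if gated else activation_func

TELayerNormMLP(4096, 16384, nn.SiLU(), gated=True)  # builds without error

Alternatively, if the installed Transformer Engine version supports gated activations natively (recent releases accept values such as "swiglu" for LayerNormMLP's activation argument), the Gated_Activation wrapper might not be needed at all; I have not verified which versions do.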


Environment (please complete the following information):

  • GPUs: 8xH100
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
| N/A   24C    P0             68W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
| N/A   23C    P0             68W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   23C    P0             69W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
| N/A   23C    P0             67W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   25C    P0             69W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
| N/A   22C    P0             68W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
| N/A   25C    P0             71W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
| N/A   22C    P0             69W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
  • Configs:
{
  # Tokens
  "data_path": "/data/pretraining-mix/pretraining-mix_text_document",
  "vocab_file": "/data/neox_tokenizer/tokenizer.json",
  "tokenizer_type": "HFTokenizer",
  "data_impl": "mmap",

  # Logging
  "checkpoint_validation_with_forward_pass": False,
  "tensorboard_dir": "tensorboard",
  "log_dir": "logs",
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,
  "use_wandb": True,
  "wandb_host": "https://api.wandb.ai",
  "wandb_project": "AISI",
  "wandb_team": "eleutherai",
  "wandb_run_name": "pretraining_baseline",

  # Distributed Training
  "hostfile": "/workspace/hostfile",
  "deepspeed_mpi": True,
  "launcher": "openmpi",
  "deepspeed_extra_args": { "ssh_port": 2222 },
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  # Training Duration
  # 500B (tokens) / (1 (grad acc) * 32 (world size) * 32 (micro batch size) * 2048 (seq length))
  "train_iters": 238419,
  "lr_decay_iters": 238419,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,
  "eval_interval": 238419,
  "eval_iters": 0,
  "split": "100,0,0",

  # Architecture
  "num_layers": 32,
  "hidden_size": 4096,
  "num_attention_heads": 32,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "layernorm",
  "pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,
  "output_layer_parallelism": "column",
  "attention_config": [[["flash"], 32]],
  "scaled_upper_triang_masked_softmax_fusion": true,
  "activation": "swiglu",
  "precision": "bfloat16",

  # Transformer Engine
  "te_columnparallel": false,
  "te_rowparallel": false,
  "te_layernorm_mlp": true,
  "te_mha": true,
  "te_fp8_format": "hybrid",
  "te_fp8_wgrad": false,
  "te_fp8_amax_history_len": 1,
  "te_fp8_amax_compute_algo": "most_recent",
  "te_fp8_margin": 0,
  "te_fp8_mha": false,

  # Optimization
  # 0.0003 is OLMo 2's peak learning rate
  "optimizer":
    {
      "type": "Adam",
      "params": { "lr": 0.0003, "betas": [0.9, 0.95], "eps": 1.0e-8 },
    },
  "min_lr": 0.000012,
  "zero_optimization":
    {
      "stage": 1,
      "allgather_partitions": true,
      "allgather_bucket_size": 1260000000,
      "overlap_comm": true,
      "reduce_scatter": true,
      "reduce_bucket_size": 1260000000,
      "contiguous_gradients": true,
      "cpu_offload": false,
    },
  "train_micro_batch_size_per_gpu": 32,
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # Checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,
  "checkpoint_factor": 1000,
  "save": "/checkpoints/pretraining_baseline",
  "load": "/checkpoints/pretraining_baseline",
}

