From 416de68d5a8fce45826624772112f2da22ca142d Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Thu, 14 Aug 2025 21:42:21 -0500 Subject: [PATCH 01/14] add small llm pretraining address review from previous PR Credit co-authors for prior squash Co-authored-by: ZixianWangAMD Co-authored-by: Michal Marcinkiewicz Co-authored-by: Lukasz Pierscieniewski --- small_llm_pretraining/nemo/Dockerfile.h200 | 134 ++++ small_llm_pretraining/nemo/Dockerfile.mi325 | 89 +++ small_llm_pretraining/nemo/README.md | 182 ++++++ small_llm_pretraining/nemo/callbacks.py | 244 ++++++++ .../nemo/config_H100_1x8x4_8b.sh | 97 +++ .../nemo/config_H200_1x8x1_8b.sh | 95 +++ .../nemo/config_MI325X_1x8x1_8b.sh | 157 +++++ .../nemo/patches/nemo_v2_1_0.patch | 64 ++ .../nemo/pretrain_llama31.py | 579 ++++++++++++++++++ small_llm_pretraining/nemo/requirements.txt | 9 + small_llm_pretraining/nemo/run_llama31.sh | 152 +++++ .../nemo/utils/consolidate_data.sh | 42 ++ .../utils/parallel_compress_json_to_gz.sh | 16 + .../nemo/utils/preprocess.sh | 34 + 14 files changed, 1894 insertions(+) create mode 100644 small_llm_pretraining/nemo/Dockerfile.h200 create mode 100644 small_llm_pretraining/nemo/Dockerfile.mi325 create mode 100644 small_llm_pretraining/nemo/README.md create mode 100644 small_llm_pretraining/nemo/callbacks.py create mode 100644 small_llm_pretraining/nemo/config_H100_1x8x4_8b.sh create mode 100644 small_llm_pretraining/nemo/config_H200_1x8x1_8b.sh create mode 100644 small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh create mode 100644 small_llm_pretraining/nemo/patches/nemo_v2_1_0.patch create mode 100644 small_llm_pretraining/nemo/pretrain_llama31.py create mode 100644 small_llm_pretraining/nemo/requirements.txt create mode 100644 small_llm_pretraining/nemo/run_llama31.sh create mode 100644 small_llm_pretraining/nemo/utils/consolidate_data.sh create mode 100644 small_llm_pretraining/nemo/utils/parallel_compress_json_to_gz.sh create mode 100644 small_llm_pretraining/nemo/utils/preprocess.sh diff --git a/small_llm_pretraining/nemo/Dockerfile.h200 b/small_llm_pretraining/nemo/Dockerfile.h200 new file mode 100644 index 000000000..5cf32b704 --- /dev/null +++ b/small_llm_pretraining/nemo/Dockerfile.h200 @@ -0,0 +1,134 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:25.01-py3 +FROM ${FROM_IMAGE_NAME} + +# Document build setup +ARG FROM_IMAGE_NAME +ENV CUSTOM_FROM_IMAGE_NAME ${FROM_IMAGE_NAME} + +# Custom libraries version +WORKDIR /workspace/ + +ARG GIT_COMMIT_ID +ENV GIT_COMMIT_ID=$GIT_COMMIT_ID + +RUN git config --global user.name "a" && \ + git config --global user.email "a" + +WORKDIR /workspace/ + +RUN pip install numcodecs==0.13.1 + +## 1. Apex +ARG APEX_REVISION=SKIP +ENV CUSTOM_APEX_REVISION ${APEX_REVISION} +ARG APEX_MAX_JOBS=4 + +RUN if [ "${APEX_REVISION}" != SKIP ]; then \ + git clone https://github.com/NVIDIA/apex && \ + cd apex && \ + echo APEX_REVISION=${APEX_REVISION} && \ + git checkout ${APEX_REVISION} && \ + echo APEX_COMMIT_HASH=$(git rev-parse HEAD) && \ + MAX_JOBS=${APEX_MAX_JOBS} NVCC_APPEND_FLAGS="--threads 8" pip install -v --no-build-isolation --no-cache-dir --disable-pip-version-check --config-settings "--build-option=--cpp_ext --cuda_ext --bnp --xentropy --deprecated_fused_adam --deprecated_fused_lamb --fast_multihead_attn --distributed_lamb --fast_layer_norm --transducer --distributed_adam --fmha --fast_bottleneck --nccl_p2p --peer_memory --permutation_search --focal_loss --fused_conv_bias_relu --index_mul_2d --cudnn_gbn --group_norm" . \ + ; fi + + + +## 2. Transformer Engine +ARG TE_REVISION=SKIP +ENV CUSTOM_TE_REVISION ${TE_REVISION} + +RUN if [ "${TE_REVISION}" != SKIP ]; then \ + pip uninstall -y transformer-engine && \ + git clone https://github.com/NVIDIA/TransformerEngine.git transformerengine && \ + cd transformerengine && \ + git checkout ${TE_REVISION} && \ + echo TE_COMMIT_HASH=$(git rev-parse HEAD) && \ + echo $(git rev-parse HEAD) > /TE_COMMIT_HASH.env && \ + git submodule init && git submodule update && \ + NVTE_CUDA_ARCHS="90;100" NVTE_UB_WITH_MPI=1 NVTE_FRAMEWORK=pytorch MPI_HOME=/usr/local/mpi pip install --force-reinstall --no-deps . \ + ; fi + + +## 3. NeMo +ARG NEMO_REVISION=v2.1.0 +ENV CUSTOM_NEMO_REVISION ${NEMO_REVISION} + +# Clone and checkout NeMo at specified version +RUN git clone https://github.com/NVIDIA/NeMo.git && \ + cd NeMo && \ + git checkout ${NEMO_REVISION} && \ + echo NEMO_COMMIT_HASH=$(git rev-parse HEAD) && \ + echo $(git rev-parse HEAD) > /NEMO_COMMIT_HASH.env && \ + pip uninstall -y nemo-toolkit sacrebleu && \ + # Only keep edits that are necessary (remove AMD-specific workarounds if upstream doesn't need them) + sed -i "/mamba-ssm/d" requirements/requirements_nlp.txt && \ + sed -i 's/tensorstore<0.1.46/tensorstore/g' requirements/requirements_nlp.txt && \ + sed -i 's/protobuf==3.20.3/protobuf/g' requirements/requirements.txt && \ + pip install "cython<3.0.0" && \ + pip install -e ".[llm]" && \ + pip install -e ".[nlp]" + + +## 3.1 NeMo-Run +ARG NEMORUN_REVISION=v0.4.0 +ENV CUSTOM_NEMORUN_REVISION ${NEMORUN_REVISION} + +RUN git clone https://github.com/NVIDIA/NeMo-Run.git && \ + cd NeMo-Run && \ + git checkout ${NEMORUN_REVISION} && \ + echo NEMORUN_COMMIT_HASH=$(git rev-parse HEAD) && \ + pip install -e . + + +# Python deps +# Important this should be done after NeMo, otherwise the pinned transformers==4.40.2 version will be overwritten +COPY requirements.txt requirements.txt +RUN pip3 install -r requirements.txt + + +# 4. Megatron-core +ARG MCORE_REVISION=core_r0.11.0 +ARG MCORE_REPO=https://github.com/NVIDIA/Megatron-LM.git +ENV CUSTOM_MCORE_REVISION ${MCORE_REVISION} + +RUN if [ "${MCORE_REVISION}" != SKIP ]; then \ + pip uninstall -y megatron-core && \ + git clone ${MCORE_REPO} Megatron-LM && \ + cd Megatron-LM && \ + git checkout ${MCORE_REVISION} && \ + echo MCORE_COMMIT_HASH=$(git rev-parse HEAD) && \ + echo $(git rev-parse HEAD) > /MCORE_COMMIT_HASH.env && \ + pip install . && \ + cd megatron/core/datasets && \ + make \ + ; fi + +ENV PYTHONPATH "${PYTHONPATH}:/workspace/Megatron-LM" + + +WORKDIR /workspace/code + +# Copy the current state of the code inside the image +COPY . . \ No newline at end of file diff --git a/small_llm_pretraining/nemo/Dockerfile.mi325 b/small_llm_pretraining/nemo/Dockerfile.mi325 new file mode 100644 index 000000000..b287a4917 --- /dev/null +++ b/small_llm_pretraining/nemo/Dockerfile.mi325 @@ -0,0 +1,89 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# MIT License + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +FROM rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.6.0 + +WORKDIR /workspace + +RUN pip install pybind11 +RUN pip install ninja +RUN pip install packaging +RUN /usr/bin/python3 -m pip install pyYAML + +# Install library dependencies +WORKDIR /workspace/deps + +# FlashAttention +RUN git clone https://github.com/ROCm/flash-attention/ flash_attention \ + # latest stable commit of ck_tile/fa3 branch + && cd flash_attention && git checkout cace3592812640486b04196a209bb85d12267b4c \ + && git submodule update --init --recursive \ + && PYTORCH_ROCM_ARCH='gfx942' GPU_ARCHS="gfx942" MAX_JOBS=64 pip install --no-build-isolation -e . + +ADD patches /workspace/deps/patches + +# Megatron-core +RUN git clone --recursive https://github.com/ROCm/Megatron-LM.git megatron_lm +RUN pip uninstall -y megatron-core +# dev branch commit +RUN cd megatron_lm && git checkout megatron_190213a_mlperf \ + && pip install -e . && cd megatron/core/datasets && make + +ENV PYTHONPATH "${PYTHONPATH}:/workspace/deps/megatron_lm" + +# mambe dependency required for NeMo +RUN git clone https://github.com/state-spaces/mamba.git mamba_ssm \ + && cd mamba_ssm \ + && git checkout v2.2.2 \ + && export HIP_ARCHITECTURES="gfx942" \ + && pip install --no-cache-dir --verbose . + +# NeMo +RUN git clone https://github.com/NVIDIA/NeMo nemo \ + && cd nemo && git checkout v2.1.0 +RUN cd /workspace/deps/nemo \ + && git apply /workspace/deps/patches/nemo_v2_1_0.patch \ + && pip install --no-build-isolation -e ".[nlp]" + +# NeMo-Run +RUN pip install git+https://github.com/NVIDIA/NeMo-Run.git@v0.4.0 + +# Python deps +# Important this should be done after NeMo, otherwise the pinned transformers==4.40.2 version will be overwritten +COPY requirements.txt requirements.txt +RUN pip3 install -r requirements.txt + +# Transformer Engine +ARG TE_COMMIT=te_v1.9_mlperf_llama2 +RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git \ + # dev branch commit + && cd TransformerEngine && git checkout $TE_COMMIT && git submodule update --init --recursive \ + # Workaround logging debug info to the console + && sed -i 's/self.logger.info/self.logger.debug/g' /workspace/deps/TransformerEngine/transformer_engine/pytorch/attention.py \ + && sed -i 's/warnings.warn/if False: warnings.warn/g' /workspace/deps/TransformerEngine/transformer_engine/pytorch/attention.py \ + && sed -i '/.*\"window_size should be.*/d' /workspace/deps/TransformerEngine/transformer_engine/common/fused_attn_rocm/fused_attn.cpp \ + && NVTE_FUSED_ATTN_AOTRITON=0 NVTE_ROCM_ARCH='gfx942' NVTE_FRAMEWORK='pytorch' NVTE_USE_HIPBLASLT=1 MAX_JOBS=128 PYTORCH_ROCM_ARCH='gfx942' GPU_ARCHS='gfx942' pip install -e . + +WORKDIR /workspace/code + +# Copy the current state of the code inside the image +COPY . . \ No newline at end of file diff --git a/small_llm_pretraining/nemo/README.md b/small_llm_pretraining/nemo/README.md new file mode 100644 index 000000000..f27409697 --- /dev/null +++ b/small_llm_pretraining/nemo/README.md @@ -0,0 +1,182 @@ +# 1. Problem + +Small Language Model pretraining - Llama 3.1 8B + +# 2. Docker Setup +To build the docker image: +```bash +docker build -t -f Dockerfile . +``` + +To launch the docker container: +``` +docker run -it --rm \ + --net=host --uts=host \ + +``` + + +# 3. Dataset and Model + +The current codebase is using the c4/en/3.0.1 dataset from [HuggingFace/AllenAI](https://huggingface.co/datasets/allenai/c4) for train and evaluation. + + + +### Raw data downloading + +We use [AllenAI C4](https://huggingface.co/datasets/allenai/c4) dataset for this benchmark. The original zipped **`json.gz`** files can be downloaded by following AllenAI C4's instruction, and you can download our zipped customized validation dataset from the MLCommons S3 bucket by running the following command: + + +```bash +export C4_PATH="" + +# download the full C4 files, including all raw train and validations +rclone copy mlc-training:mlcommons-training-wg-public/common/datasets/c4/original/en_json/3.0.1 $C4_PATH -P +``` +After downloading, run the following command to process them to zip them into `.gz` format before running the data preprocessing. + +``` +bash utils/parallel_compress_json_to_gz.sh +``` + +Run the following commands to merge all 1024 training files into 8 `json.gz` files, all 8 validation files into a single `json.gz` file, as well as generate our customized validation dataset. Each of the `json.gz` files will subsequently be preprocessed into a pair of megatron dataset files (`.bin` and `.idx`) by our preprocess.sh script. + +```bash +export C4_PATH="" +export MERGED_C4_PATH="" +# more information about this knob can be found in consolidate_data.sh +export N_VALIDATION_SAMPLES=91205 + +bash utils/consolidate_data.sh +``` + +### Tokenizer +We are using the Llama 3.1 8B tokenizer. To download it, you can run the following commands: +```bash +export TOKENIZER_PATH="" +huggingface-cli login +huggingface-cli download meta-llama/Llama-3.1-8B --local-dir $TOKENIZER_PATH +``` + +After the data consolidation is done, we can perform preprocessing using the following commands: + +```bash +# pass in the folder path that contains the Llama tokenizer here +# please refer to the tokenizer section above for more details +export TOKENIZER_PATH="" +# pass in the merged file path here +export MERGED_C4_PATH="" +# this path is used for storing the preprocessed .bin and .idx files +export PREPROCESSED_PATH="" + +for index in {0..7}; do + # please specify the right path to nemo + python3 /nemo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \ + --input "${MERGED_C4_PATH}/c4-train.en_${index}.json.gz" \ + --output-prefix "${PREPROCESSED_PATH}/c4-train.en_${index}" \ + --tokenizer-library huggingface --tokenizer-type ${TOKENIZER_PATH} \ + --dataset-impl mmap --workers 128 & +done + # please specify the right path to nemo + python3 /nemo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \ + --input "${MERGED_C4_PATH}/c4-validation-91205-samples.en.json.gz" \ + --output-prefix "${PREPROCESSED_PATH}/c4-validation-91205-samples.en" \ + --tokenizer-library huggingface --tokenizer-type ${TOKENIZER_PATH} \ + --dataset-impl mmap --workers 128 & +wait + +``` + +After the download is complete, you should see files with the following naming conventions under `PREPROCESSED_PATH`, ending with both `.idx` and `.bin`: +- Training partitions: `c4-train.en__text_document` +- Validation partitions: `c4-validation-91205-samples.en_text_document` + +#### Training and test data separation + +We use the default split from the C4 dataset. This means that we use `c4-train.-of-01024.json.gz` files (where `768 <= x <= 1023`) for training, and we use our customized `c4-validation-91205-samples.en.json.gz`, which contains the first 91205 samples from the unshuffled C4 validation dataset, for evaluation. + +Notice here that we are using the first 1024 sequences (8,388,608 tokens) from the validation dataset to perform the validation. According to our experiments, the first 91205 samples from the unshuffled C4 dataset yields 47,186,855 tokens, which is the smallest amount of samples needed to yield 47,185,920 tokens. Thus, we have chosen the first 91205 samples as our validation dataset. + +#### Training data order + +We randomly shuffle the **last 256 of 1024 shards** for the benchmarking area. + +#### Test data order + +We use the first 1024 sequences in the validation dataset for validation. We **do not shuffle** the validation dataset. + +# 4. Model +### Publication/Attribution + +The model largely follows the Llama 3.1 8B [paper](https://arxiv.org/abs/2407.21783). + +### Model details + +| Config | Value | +| :-- | :-- | +| Embedding | RoPE + parameter adjustments | +| # Layers | 32 | +| Attention Type | GQA | +| # Attn Heads | 32 | +| Key/Value Heads | 8 | +| Model Dimension | 4096 | +| FFN Dimension | 14336 | +| Activation | SwiGLU | +| Normalization | RMSNorm | +| Tokenizer | Llama tokenizer | +| Vocab size | 128,000 | +| Context Length | 8192 | + + +#### Saving and restoring a checkpoint + +Large runs might need to span across multiple Slurm jobs, and we need to save and load checkpoints with contexts so that training can resume between jobs. To support this, we have added some environment variables. Please refer to `config.sh` for more details. + +### Optimizer spec + +1. Optimizer type: **AdamW** +2. Warmup steps computed as 10% of the total allocated steps. + +# 5. Quality +### Quality metric + +Validation loss + +### Quality target + +Validation log perplexity = 3.3 + +### Evaluation frequency + +We perform evaluation every **12288** sequences. + +### Evaluation thoroughness + +We evaluate using **1024** sequences from our customized validation dataset. + +# 6. Launch a training run + +To train Llama 3.1 8B, we need to fill out all fields in `config.sh`. This file contains all configurations for Slurm cluster access and job submission configurations, directory mappings, containers, and model configurations. + +Once the `config.sh` is properly filled, we launch a training run using the following commands: + +```bash +source config.sh +bash run_llama31.sh +``` \ No newline at end of file diff --git a/small_llm_pretraining/nemo/callbacks.py b/small_llm_pretraining/nemo/callbacks.py new file mode 100644 index 000000000..8ce5669ed --- /dev/null +++ b/small_llm_pretraining/nemo/callbacks.py @@ -0,0 +1,244 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +### MLLogger +from mlperf_logging import mllog +from mlperf_logging.mllog import constants +import torch.distributed as dist + +def is_dist_avail_and_initialized(): + return (dist.is_available() and dist.is_initialized()) + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + +def barrier(): + if not is_dist_avail_and_initialized(): + return + + dist.barrier() + +class MLLogger: + def __init__(self, filepath="/mlperf-outputs/mlperf_llama31_8b.log", default_stack_offset=2): + self.logger = mllog.get_mllogger() + mllog.config(default_stack_offset=default_stack_offset, filename=filepath) + + def start(self, **kwargs): + if get_rank() == 0: + self.logger.start(**kwargs) + + def end(self, **kwargs): + if get_rank() == 0: + self.logger.end(**kwargs) + + def event(self, **kwargs): + if get_rank() == 0: + self.logger.event(**kwargs) + + def submission_info(self): + self.event(key=constants.SUBMISSION_BENCHMARK, value="llama31_8b") + self.event(key=constants.SUBMISSION_ORG, value="reference_implementation") + self.event(key=constants.SUBMISSION_DIVISION, value=constants.CLOSED) + self.event(key=constants.SUBMISSION_STATUS, value=constants.ONPREM) + self.event(key=constants.SUBMISSION_PLATFORM, value="DGX-H100") + self.event(key=constants.SUBMISSION_POC_NAME, value="Yunzhou Liu") + self.event(key=constants.SUBMISSION_POC_EMAIL, value="yunzhoul@nvidia.com") + +mllogger = MLLogger() + +### Preemptive checkpoint callbacks +import lightning.pytorch as pl +from nemo.utils import logging + +class PreemptiveStop(pl.Callback): + """Preemptively stop training at a given global step. Allows stopping training before reaching + the max steps. Useful for testing checkpoint save and resume. + + Args: + stop_on_step (int): Stop training when trainer.global_step reaches this value. + Checked at the start of every step. + """ + + def __init__(self, stop_on_step: int): + self.stop_on_step = stop_on_step + + def on_train_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx + ) -> None: + if trainer.global_step >= self.stop_on_step: + logging.info(f"Global step {trainer.global_step} >= {self.stop_on_step}, signaling Trainer to stop.") + trainer.should_stop = True + # skip EarlyStopping validation unless val_check_interval met + if trainer.global_step % trainer.val_check_interval != 0: + trainer.limit_val_batches = 0 + + +### Metrics Logger +from pytorch_lightning.loggers import Logger +from pytorch_lightning.utilities import rank_zero_only + +class MetricsLogger(Logger): + def __init__( + self, + init_global_step, global_batch_size, seq_length, + target_log_ppl, + train_loss_key = "reduced_train_loss", + val_loss_key = "val_loss", + train_step_time_in_s = "train_step_timing in s", + train_step_time_atol=7200, + ): + super().__init__() + + self.init_global_step = init_global_step + self.gbs = global_batch_size + self.seq_len = seq_length + + self.target = target_log_ppl + self.train_loss_key = train_loss_key + self.val_loss_key = val_loss_key + self.is_target_reached = False + + self.train_step_time_in_s = train_step_time_in_s + self.train_step_time_atol = train_step_time_atol + + def log_metrics(self, metrics, step): + if self.val_loss_key in metrics: + self.log_validation_loss(metrics, step) + + if self.train_step_time_in_s in metrics: + step_time = metrics[self.train_step_time_in_s] + assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol}). " + + def log_validation_loss(self, metrics, step): + consumed_samples = step * self.gbs + + loss = metrics[self.val_loss_key] + + mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={constants.SAMPLES_COUNT: consumed_samples}) + + if not self.is_target_reached and loss <= self.target: + self.is_target_reached = True + + @rank_zero_only + def log_hyperparams(self, params, *args, **kwargs): + pass + + @property + def name(self): + return 'mlperf-metrics' + + @property + def version(self): + return 1 + +### MLPerf callbacks +def compute_consumed_mllog_samples(trainer, init_global_step, global_batch_size, seq_length): + consumed_samples = ( + trainer.global_step * global_batch_size + ) + return int(consumed_samples) # we log the epoch numbers in sequences, not tokens + +class MLPerfCallback(pl.Callback): + def __init__( + self, + global_batch_size, + micro_batch_size, + sequence_length, + init_global_step, + eval_every, + configs={} + ): + mllogger.event(key=constants.CACHE_CLEAR, value=True) + mllogger.start(key=constants.INIT_START) + super().__init__() + + self.init_global_step = init_global_step + self.gbs = global_batch_size + self.mbs = micro_batch_size + self.seq_len = sequence_length + self.eval_every = eval_every + + self.is_target_reached = False + self.status = constants.ABORTED + self.configs = configs + + def consumed_samples(self, trainer): + return compute_consumed_mllog_samples(trainer, self.init_global_step, self.gbs, self.seq_len) + + def set_success_status(self): + self.status = constants.SUCCESS + self.is_target_reached = True + + @rank_zero_only + def on_train_epoch_start(self, trainer, pl_module): + mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) + mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) + + return super().on_train_epoch_start(trainer, pl_module) + + @rank_zero_only + def on_train_epoch_end(self, trainer, pl_module): + mllogger.end(key=constants.EPOCH_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) + return super().on_train_epoch_end(trainer, pl_module) + + def on_train_end(self, trainer, pl_module): + # for every occurrences, run on all ranks to allow sync + barrier() + mllogger.end(key=constants.RUN_STOP, metadata={"status": self.status}) + mllogger.event(key="train_samples", value=self.consumed_samples(trainer)) + return super().on_train_end(trainer, pl_module) + + @rank_zero_only + def log_eval_start(self, trainer, pl_module): + mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) + mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) + + + def on_validation_start(self, trainer, pl_module): + trainer.val_check_interval = self.eval_every + trainer.val_check_batch = self.eval_every + self.log_eval_start(trainer, pl_module) + + def on_validation_end(self, trainer, pl_module): + mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) + + for logger in trainer.loggers: + if isinstance(logger, MetricsLogger): + if logger.is_target_reached: + trainer.should_stop = True + self.set_success_status() + + if not trainer.should_stop: + mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) + + return super().on_validation_end(trainer, pl_module) + + @rank_zero_only + def load_state_dict(self, state_dict): + print(f":::MLLOG Weight initialization: {state_dict.keys()}") + return super().load_state_dict(state_dict) + + def on_train_start(self, trainer, pl_module): + # run on all ranks to allow synchronization + barrier() + mllogger.submission_info() + + for key, value in self.configs.items(): + mllogger.event(key=key, value=value) + + mllogger.end(key=constants.INIT_STOP) + mllogger.start(key=constants.RUN_START) diff --git a/small_llm_pretraining/nemo/config_H100_1x8x4_8b.sh b/small_llm_pretraining/nemo/config_H100_1x8x4_8b.sh new file mode 100644 index 000000000..5f48b1784 --- /dev/null +++ b/small_llm_pretraining/nemo/config_H100_1x8x4_8b.sh @@ -0,0 +1,97 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SSH: username that connects to the remote cluster +export USER="DUMMY" +# SSH: remote cluster URL +export HOST="DUMMY" +# Slurm: account for job submission +export ACCOUNT="DUMMY" +# Slurm: partition for job submission +export PARTITION="DUMMY" +# Slurm: job time limit, defaults to 8 hours +export TIME="08:00:00" +# Slurm: --nodes arguments, default to use 288 nodes +export NNODES=1 +# Slurm: --gpus_per_node and --ntasks_per_node argument, defaults to 8 GPUs per node +export GPUS_PER_NODE=8 +# Slurm: max job retries for transient job failures, defaults to retry 3 times +export MAX_RETRIES=1 + +# Folder mapping: +# Output directory that holds logs, any path that you like. +export JOB_DIR="/workspace/code/logs" +# Image / container path, either local cache file or remote URL +export IMAGE="DUMMY" +# Dataset: C4 dataset location that contains the dataset after preprocessing +# export ORIGINAL_C4_PATH="/data/data/C4" + +# This corresponds to the PREPROCESSED_PATH in README section 3's dataset download part +export PREPROCESSED_PATH="/data/llama3_8b/data/C4_processed" +export MERGED_C4_PATH="/data/llama3_8b/data/C4_merged" +# Dataset: Numpy index working directory, contains shuffled dataset +# This path must be able to hold >400GB data +export TMP_NPY_INDEX="/data/npy_indices" +# Dataset: Tokenizer path +# This corresponds to the TOKENIZER_PATH in README section 3's tokenizer download part +export TOKENIZER_PATH="/data/llama3_8b/model/Llama-3.1-8B" +# export TOKENIZER_PATH="/data/llama3_405b_ref/tokenizer" + +# export MODEL_CKPT="None" +# Model: Continual checkpoint directory to write and resume +# This is the directory to hold all intermediate checkpoints. +# Once a run is complete and we specify to save checkpoints, +# we should see a checkpoint written in this folder +# with name `checkpoint-par-x-y-steps` +# Inside this directory, there should be a `checkpoint` directory that holds context and weights +# which is the "actual checkpoint". +# Notice that this path must be able to hold at least 5.2TB data since each checkpoint is 5.2TB. +export CONTINUAL_CKPT="/data/model/saved_ckpts" +# Model: Whether we want to save a checkpoint. Must be 1 if NPAR > 1. If 1, then we save a checkpoint at the end. +export SAVE_CKPT=0 + +# Training Configs: +# Model: size, to choose from 8b, 70b, 405b +export SIZE="8b" +# Dataloader: Global batch size +export GBS=32 +# Dataloader: Micro batch size +export MBS=1 +export MAX_LR="5e-4" +# Dataloader: Max run N batches, optional +# If an empty string is provided (""), then the training will continue until time limit +# If we want to save a checkpoint, then this value must be set +# export MAX_STEPS=1200000 # Fixed max_steps=1200000 in pretrain_llama31.py +export WARMUP_STEPS=512 # 16384 // GBS +export EVAL_EVERY=12288 +export START_EVAL_AT=0 + +export TENSOR_PARALLEL_SIZE=4 + +# Experiment: starting steps +# This is the starting "offset" step from the checkpoint. +# For instance, if you are resuming from a checkpoint folder `checkpoint-par-0-20-steps/checkpoint`, +# which means that the model is trained for 20 steps to generate the checkpoint, +# then the value 20 is needed here. +export START_STEPS="0" +# Experiment manager: Number of experiments to launch +export NEXP=1 +# Experiment manager: how many consecutive jobs we want for each experiment +export NPAR=1 +# Experiment manager: provides seeds to the launched experiments, use space as delimiter, such as "1234 1235 1236" +# The training script will discard all excessive seeds, and generate seeds if given seeds < NEXP. +# To preserve randomness, we recommend not to set this value so that each time seeds can be randomly generated. + + +export DGXSYSTEM=$(basename $(readlink -f ${BASH_SOURCE[0]}) | sed 's/^config_//' | sed 's/\.sh$//' ) diff --git a/small_llm_pretraining/nemo/config_H200_1x8x1_8b.sh b/small_llm_pretraining/nemo/config_H200_1x8x1_8b.sh new file mode 100644 index 000000000..7134b50c7 --- /dev/null +++ b/small_llm_pretraining/nemo/config_H200_1x8x1_8b.sh @@ -0,0 +1,95 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SSH: username that connects to the remote cluster +export USER="DUMMY" +# SSH: remote cluster URL +export HOST="DUMMY" +# Slurm: account for job submission +export ACCOUNT="DUMMY" +# Slurm: partition for job submission +export PARTITION="DUMMY" +# Slurm: job time limit, defaults to 8 hours +export TIME="08:00:00" +# Slurm: --nodes arguments, default to use 288 nodes +export NNODES=1 +# Slurm: --gpus_per_node and --ntasks_per_node argument, defaults to 8 GPUs per node +export GPUS_PER_NODE=8 +# Slurm: max job retries for transient job failures, defaults to retry 3 times +export MAX_RETRIES=1 + +# Folder mapping: +# Output directory that holds logs, any path that you like. +export JOB_DIR="/workspace/code/logs" +# Image / container path, either local cache file or remote URL +export IMAGE="DUMMY" +# Dataset: C4 dataset location that contains the dataset after preprocessing +# export ORIGINAL_C4_PATH="/data/data/C4" + +# This corresponds to the PREPROCESSED_PATH in README section 3's dataset download part +export PREPROCESSED_PATH="/data/llama3_8b/data/C4_processed" +export MERGED_C4_PATH="/data/llama3_8b/data/C4_merged" +# Dataset: Numpy index working directory, contains shuffled dataset +# This path must be able to hold >400GB data +export TMP_NPY_INDEX="/data/npy_indices" +# Dataset: Tokenizer path +# This corresponds to the TOKENIZER_PATH in README section 3's tokenizer download part +export TOKENIZER_PATH="/data/llama3_8b/model/Llama-3.1-8B" +# export TOKENIZER_PATH="/data/llama3_405b_ref/tokenizer" + +# Model: Continual checkpoint directory to write and resume +# This is the directory to hold all intermediate checkpoints. +# Once a run is complete and we specify to save checkpoints, +# we should see a checkpoint written in this folder +# with name `checkpoint-par-x-y-steps` +# Inside this directory, there should be a `checkpoint` directory that holds context and weights +# which is the "actual checkpoint". +# Notice that this path must be able to hold at least 5.2TB data since each checkpoint is 5.2TB. +export CONTINUAL_CKPT="/data/model/saved_ckpts" +# Model: Whether we want to save a checkpoint. Must be 1 if NPAR > 1. If 1, then we save a checkpoint at the end. +export SAVE_CKPT=0 + +# Training Configs: +# Model: size, to choose from 8b, 70b, 405b +export SIZE="8b" +# Dataloader: Global batch size +export GBS=32 +# Dataloader: Micro batch size +export MBS=2 +export MAX_LR="5e-4" +# Dataloader: Max run N batches, optional +# If an empty string is provided (""), then the training will continue until time limit +# If we want to save a checkpoint, then this value must be set +# export MAX_STEPS=1200000 # Fixed max_steps=1200000 in pretrain_llama31.py +export WARMUP_STEPS=512 # 16384 // GBS +export EVAL_EVERY=12288 +export START_EVAL_AT=0 + +export TENSOR_PARALLEL_SIZE=1 + +# Experiment: starting steps +# This is the starting "offset" step from the checkpoint. +# For instance, if you are resuming from a checkpoint folder `checkpoint-par-0-20-steps/checkpoint`, +# which means that the model is trained for 20 steps to generate the checkpoint, +# then the value 20 is needed here. +export START_STEPS="0" +# Experiment manager: Number of experiments to launch +export NEXP=1 +# Experiment manager: how many consecutive jobs we want for each experiment +export NPAR=1 +# Experiment manager: provides seeds to the launched experiments, use space as delimiter, such as "1234 1235 1236" +# The training script will discard all excessive seeds, and generate seeds if given seeds < NEXP. +# To preserve randomness, we recommend not to set this value so that each time seeds can be randomly generated. + +export DGXSYSTEM=$(basename $(readlink -f ${BASH_SOURCE[0]}) | sed 's/^config_//' | sed 's/\.sh$//' ) diff --git a/small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh b/small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh new file mode 100644 index 000000000..000465616 --- /dev/null +++ b/small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh @@ -0,0 +1,157 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# MIT License + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SSH: username that connects to the remote cluster +export USER="DUMMY" +# SSH: remote cluster URL +export HOST="DUMMY" +# Slurm: account for job submission +export ACCOUNT="DUMMY" +# Slurm: partition for job submission +export PARTITION="DUMMY" +# Slurm: job time limit, defaults to 8 hours +export TIME="08:00:00" +# Slurm: --nodes arguments, default to use 288 nodes +export NNODES=1 +# Slurm: --gpus_per_node and --ntasks_per_node argument, defaults to 8 GPUs per node +export GPUS_PER_NODE=8 +# Slurm: max job retries for transient job failures, defaults to retry 3 times +export MAX_RETRIES=1 + +# Folder mapping: +# Output directory that holds logs, any path that you like. +export JOB_DIR="/workspace/code/logs" +# Image / container path, either local cache file or remote URL +export IMAGE="DUMMY" +# Dataset: C4 dataset location that contains the dataset after preprocessing +# export ORIGINAL_C4_PATH="/data/data/C4" + +# This corresponds to the PREPROCESSED_PATH in README section 3's dataset download part +export PREPROCESSED_PATH="/data/llama31_8b/data/C4_processed" +export MERGED_C4_PATH="/data/llama31_8b/data/C4_merged" +# Dataset: Numpy index working directory, contains shuffled dataset +# This path must be able to hold >400GB data +export TMP_NPY_INDEX="/data/npy_indices" +# Dataset: Tokenizer path +# This corresponds to the TOKENIZER_PATH in README section 3's tokenizer download part +export TOKENIZER_PATH="/data/llama31_8b/model/Llama-3.1-8B-ref/" +# export TOKENIZER_PATH="/data/llama3_405b_ref/tokenizer" + +# Model: Continual checkpoint directory to write and resume +# This is the directory to hold all intermediate checkpoints. +# Once a run is complete and we specify to save checkpoints, +# we should see a checkpoint written in this folder +# with name `checkpoint-par-x-y-steps` +# Inside this directory, there should be a `checkpoint` directory that holds context and weights +# which is the "actual checkpoint". +# Notice that this path must be able to hold at least 5.2TB data since each checkpoint is 5.2TB. +export CONTINUAL_CKPT="/data/model/saved_ckpts" +# Model: Whether we want to restore from MODEL_CKPT path. If 0, then we are not restoring. +export USE_CKPT=0 +# Model: Whether we are resuming from a NeMo-formatted HuggingFace checkpoint (weights only). +# If set to 1, then checkpoint resuming code will not try to load the optimizer states. +export FROM_HF=1 +# Model: Whether we want to save a checkpoint. Must be 1 if NPAR > 1. If 1, then we save a checkpoint at the end. +export SAVE_CKPT=0 + +# Training Configs: +# Model: size, to choose from 8b, 70b, 405b +export SIZE="8b" +# Dataloader: Global batch size +export GBS=32 +# Dataloader: Micro batch size +export MBS=4 +export MAX_LR="5e-4" +# Dataloader: Max run N batches, optional +# If an empty string is provided (""), then the training will continue until time limit +# If we want to save a checkpoint, then this value must be set +# Fixed max_steps=1200000 in pretrain_llama31.py +export WARMUP_STEPS=512 # 16384 // GBS +export EVAL_EVERY=12288 +export START_EVAL_AT=0 + +export TENSOR_PARALLEL_SIZE=1 +# Experiment: starting steps +# This is the starting "offset" step from the checkpoint. +# For instance, if you are resuming from a checkpoint folder `checkpoint-par-0-20-steps/checkpoint`, +# which means that the model is trained for 20 steps to generate the checkpoint, +# then the value 20 is needed here. +export START_STEPS="0" +# Experiment manager: Number of experiments to launch +export NEXP=1 +# Experiment manager: how many consecutive jobs we want for each experiment +export NPAR=1 +# Experiment manager: provides seeds to the launched experiments, use space as delimiter, such as "1234 1235 1236" +# The training script will discard all excessive seeds, and generate seeds if given seeds < NEXP. +# To preserve randomness, we recommend not to set this value so that each time seeds can be randomly generated. +# export SEEDS=7963 +# export SEEDS=4786 + +export DGXSYSTEM=$(basename $(readlink -f ${BASH_SOURCE[0]}) | sed 's/^config_//' | sed 's/\.sh$//' ) + +# Extra configs +export NCCL_MIN_P2P_NCHANNELS=32; +export NCCL_MIN_CTAS=32; +export NCCL_NCHANNELS_PER_NET_PEER=32; +export NCCL_NVLS_ENABLE=0 + +export TP_COMM_OVERLAP=False +export MC_TP_OVERLAP_AG=False +export MC_TP_OVERLAP_RS=False +export MC_TP_OVERLAP_RS_DGRAD=False + +export CUBLAS_FORCE_XMMA_KERNEL_INIT=DEVICE + +export NVTE_RS_STRIDED_ATOMIC=2 +export NVTE_FP8_DPA_BWD=1 +export NVTE_FUSED_ATTN=1 +export NVTE_FUSED_ATTN_CK=1 +export NVTE_FUSED_ATTN_AOTRITON=1 +export NVTE_DEBUG=0 +export NVTE_DEBUG_LEVEL=0 +export NVTE_USE_HIPBLASLT=1 +export NVTE_USE_CAST_TRANSPOSE_TRITON=1 +export NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE=0 +export USE_TE_SWIGLU=1 + +# FAv3 +export NVTE_CK_USES_BWD_V3=1 # enable dqdkdv bwd kernel +export NVTE_CK_USES_FWD_V3=1 +export NVTE_CK_IS_V3_ATOMIC_FP32=0 # 16bit atomics + +# ck logging +export CK_FUSED_ATTN_LOG_CONFIG=0 # Diable logging for CK fused attn. Enabled for debugging onl + +export CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=0 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +export FUSED_SOFTMAX=0 +export RMSNORM_CAST=0 + +export PT_TENSOR_VALIDATION=0 +export PROFILE_RPD=0 + +export USE_HIPBLASLT=1 +export TORCH_BLAS_PREFER_HIPBLASLT=1 + +export NVTE_USE_RMSNORM_TRITON=1 +export ENABLE_TRANSPOSE_CACHE=0 \ No newline at end of file diff --git a/small_llm_pretraining/nemo/patches/nemo_v2_1_0.patch b/small_llm_pretraining/nemo/patches/nemo_v2_1_0.patch new file mode 100644 index 000000000..590256d55 --- /dev/null +++ b/small_llm_pretraining/nemo/patches/nemo_v2_1_0.patch @@ -0,0 +1,64 @@ +diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py +index 646540e88..fad8e4e0a 100644 +--- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py ++++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py +@@ -143,7 +143,13 @@ def zero_module(module): + + + def Normalize(in_channels, num_groups=32, act=""): +- return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act) ++ try: ++ from apex.contrib.group_norm import GroupNorm as GroupNorm ++ return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act) ++ except ImportError: ++ print("Using torch.nn.GroupNorm. Hip/Cuda kernel could not be imported from Apex") ++ import torch.nn.GroupNorm as GroupNorm ++ return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + + class LinearAttention(nn.Module): +@@ -595,4 +601,4 @@ class SpatialTransformer(nn.Module): + x = x.transpose(1, 2).view(b, c, h, w) # b (h w) c -> b c h w + if not self.use_linear: + x = self.proj_out(x) +- return x_in + x ++ return x_in + x +\ No newline at end of file +diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py +index 69700a436..bfd21ab7e 100644 +--- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py ++++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py +@@ -257,7 +257,13 @@ def mean_flat(tensor): + + + def normalization(in_channels, act="", gn_groups=32): +- return GroupNorm(num_groups=gn_groups, num_channels=in_channels, eps=1e-5, affine=True, act=act) ++ try: ++ from apex.contrib.group_norm import GroupNorm as GroupNorm ++ return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act) ++ except ImportError: ++ print("Using torch.nn.GroupNorm. Hip/Cuda kernel could not be imported from Apex") ++ import torch.nn.GroupNorm as GroupNorm ++ return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + + # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +@@ -361,4 +367,4 @@ def exists(x): + def default(val, d): + if exists(val): + return val +- return d() if isfunction(d) else d ++ return d() if isfunction(d) else d +\ No newline at end of file +diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt +index 6a86dacbf..a6f60380e 100644 +--- a/requirements/requirements_nlp.txt ++++ b/requirements/requirements_nlp.txt +@@ -11,7 +11,6 @@ jieba + mamba-ssm==2.2.2; sys_platform == 'linux' + markdown2 + matplotlib>=3.3.2 +-#megatron_core>0.6.0 # add back once mcore on pypi is compatible again + nltk>=3.6.5 + numpy<2 # tensorstore has an implicit compiled dependency on numpy<2 + opencc diff --git a/small_llm_pretraining/nemo/pretrain_llama31.py b/small_llm_pretraining/nemo/pretrain_llama31.py new file mode 100644 index 000000000..8b039a3bd --- /dev/null +++ b/small_llm_pretraining/nemo/pretrain_llama31.py @@ -0,0 +1,579 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import math +import argparse +from typing import Optional + +import torch +import wandb +from lightning.pytorch.loggers import WandbLogger + +from nemo.collections import llm +from nemo.collections.common.tokenizers import AutoTokenizer +from nemo import lightning as nl +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +import nemo_run as run +from nemo.lightning.run import plugins +from nemo.collections.llm.gpt.data import build_pretraining_datamodule + +from callbacks import PreemptiveStop, MLPerfCallback, MetricsLogger + + +def local_executor( + custom_env_vars: Optional[dict[str, str]] = None, + devices: int = 8, + retries: int = 0, +) -> run.LocalExecutor: + env_vars = { + "TRANSFORMERS_OFFLINE": "1", + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + "NCCL_NVLS_ENABLE": "0", + "NVTE_DP_AMAX_REDUCE_INTERVAL": "0", + "NVTE_ASYNC_AMAX_REDUCTION": "1", + "TOKENIZERS_PARALLELISM": "false", + } + + if custom_env_vars: + env_vars |= custom_env_vars + + executor = run.LocalExecutor() + executor.launcher = 'torchrun' + executor.env_vars = env_vars + executor.retries = retries + executor.nodes = 1 + executor.ntasks_per_node = devices + + return executor + +def slurm_executor( + user: str, + host: str, + remote_job_dir: str, + account: str, + partition: str, + nodes: int, + devices: int, + time: str = "01:00:00", + custom_mounts: Optional[list[str]] = None, + custom_env_vars: Optional[dict[str, str]] = None, + container_image: str = "nvcr.io/nvidia/nemo:dev", + dependencies: list[str] = [], + retries: int = 0, +) -> run.SlurmExecutor: + if not (user and host and remote_job_dir and account and partition and nodes and devices): + raise RuntimeError( + "Please set user, host, remote_job_dir, account, partition, nodes and devices args for using this function." + ) + + mounts = [] + if custom_mounts: + mounts.extend(custom_mounts) + + env_vars = { + "TRANSFORMERS_OFFLINE": "1", + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + "NCCL_NVLS_ENABLE": "0", + "NVTE_DP_AMAX_REDUCE_INTERVAL": "0", + "NVTE_ASYNC_AMAX_REDUCTION": "1", + "NVTE_FUSED_ATTN": "1", + "TOKENIZERS_PARALLELISM": "false", + } + if custom_env_vars: + env_vars |= custom_env_vars + + executor = run.SlurmExecutor( + account=account, + partition=partition, + tunnel=run.SSHTunnel( + user=user, + host=host, + job_dir=remote_job_dir, + ), + exclusive=True, + gres="gpu:8", + nodes=nodes, + ntasks_per_node=devices, + mem="0", + packager=run.GitArchivePackager(subpath="small_language_model_pretraining/nemo", ref="HEAD"), + dependencies=dependencies, + ) + + executor.launcher = None + executor.container_image = container_image + executor.container_mounts = mounts + executor.env_vars = env_vars + executor.retries = retries + executor.time = time + + return executor + +def get_pretrain( + size: str, + nnodes: int, + ngpus_per_node: int, + max_steps: int, + warmup_steps: int, + data_module: run.Config, + max_lr: float = 1e-4, + eval_every: Optional[int] = None, + start_eval_at: Optional[int] = None, + eval_batches: Optional[int] = None, +) -> run.Partial: + + exp_name = size + + pretrain = llm.llama3_8b.pretrain_recipe( + dir="/outputs", + name=exp_name, + num_nodes=nnodes, + num_gpus_per_node=ngpus_per_node + ) + + llama31_config = run.Config(llm.gpt.model.llama.Llama31Config8B) + llama31_config.seq_length = 8192 + pretrain.model.config = llama31_config + + pretrain.trainer.strategy.tensor_model_parallel_size = 1 + pretrain.trainer.strategy.pipeline_model_parallel_size = 1 + pretrain.trainer.strategy.virtual_pipeline_model_parallel_size = 1 # set it back to 7? + pretrain.trainer.strategy.context_parallel_size = 1 + + # Code tracing shows that this is AdamW + pretrain.optim = distributed_fused_adam_with_cosine_annealing( + max_lr=max_lr, + warmup_steps=warmup_steps, + min_lr=max_lr * 0.1 + ) + + precision = run.Config( + nl.MegatronMixedPrecision, + precision="bf16-mixed", + params_dtype=torch.bfloat16, + pipeline_dtype=torch.bfloat16, + autocast_enabled=True, + grad_reduce_in_fp32=False, + fp8="hybrid", + fp8_amax_history_len=4, + fp8_amax_compute_algo='most_recent', + fp8_params=True, + fp8_dot_product_attention=False, + ) + + pretrain.trainer.plugins = precision + + # sets up everything else + pretrain.trainer.max_steps = max_steps #1200000 # Hardcoded to fix max_steps for this benchmark + + pretrain.data = data_module + pretrain.trainer.val_check_interval = eval_every / int (os.getenv ("GBS")) + pretrain.trainer.limit_val_batches = eval_batches + pretrain.trainer.limit_test_batches = eval_batches + + pretrain.log.tensorboard = None + pretrain.log.ckpt.every_n_train_steps = None + pretrain.log.ckpt.save_top_k = -1 + pretrain.log.ckpt.save_last = False + pretrain.log.ckpt.always_save_context = False + pretrain.log.ckpt.save_weights_only = False + pretrain.log.ckpt.save_optim_on_train_end = False + pretrain.log.ckpt.save_on_train_epoch_end = False + pretrain.log.ckpt.monitor = "consumed_samples" + pretrain.log.ckpt.mode = "max" + + return exp_name, pretrain + +def get_data( + gbs: int = 288, + mbs: int = 4, + seq_length: Optional[int] = 8192, + tokenizer_path: Optional[str] = "", + seed: Optional[int] = 1234, + use_full_dataset: Optional[bool] = False, +) -> run.Config: + tokenizer = run.Config(AutoTokenizer, pretrained_model_name=tokenizer_path) + + print (f'tokenizer: {tokenizer}') + print (f'use_full_dataset: {use_full_dataset}') + + train_datasets = None + + dataset_path = dataset_path = os.getenv("PREPROCESSED_PATH") + + if use_full_dataset: + train_datasets = sum([["12.5", f"{dataset_path}/c4-train.en_{idx}_text_document"] for idx in range(8)], []) + else: + train_datasets = sum([["10", f"{dataset_path}/c4-train.en_{idx}_text_document"] for idx in [6]], []) + + data_paths = { + "train": train_datasets, + "validation": [ + f"{dataset_path}/c4-validation-91205-samples.en_text_document" + ], + "test": [ + f"{dataset_path}/c4-validation-91205-samples.en_text_document" + ], + } + + return run.Config( + llm.PreTrainingDataModule, + tokenizer=tokenizer, + paths=data_paths, + num_workers=128, + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + index_mapping_dir="/npy_index", + seed=seed, + + # Option to reset the position IDs in the dataset at an interval. + reset_position_ids=False, + # Option to reset the attention mask from the dataset. + reset_attention_mask=False, + # Option to enable the EOD mask loss. + eod_mask_loss=False, + # Rampup batch size, should be in format of [start_global_batch_size, batch_size_increment, ramup_samples]. + rampup_batch_size=None, + ) + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Llama3.1 Pretraining") + parser.add_argument("--tag", type=str, help="Optional experiment tag", required=False, default="") + + # Slurm and executor related + slurm_group = parser.add_argument_group("Slurm executor arguments") + slurm_group.add_argument('--user', type=str, required=True, help="Remote cluster SSH user name") + slurm_group.add_argument("--host", type=str, required=True, help="Remote cluster host address") + slurm_group.add_argument("--job_dir", type=str, required=True, help="Remote job directory") + + slurm_group.add_argument("--account", type=str, required=True, help="Account to be used for Slurm job submission") + slurm_group.add_argument("--partition", type=str, required=True, help="Partition to be used for Slurm job submission") + slurm_group.add_argument("--nodes", type=int, required=True, help="Number of nodes to be used") + slurm_group.add_argument("--gpus_per_node", type=int, required=True, help="Number of GPUs per node") + slurm_group.add_argument("--time", type=str, required=True, help="Time limit for the job") + slurm_group.add_argument("--dependencies", nargs="*", help="list of dependencies for the job, dependency type as 'afterok'") # not useful for now + slurm_group.add_argument("--max_retries", type=int, default=0) + slurm_group.add_argument("--run_slurm", action="store_true", help="run in slurm executor instead of locally") + + slurm_group.add_argument( + "--mounts", + type=str, + required=True, + help=( + "Custom mount paths, formatted as a string of :[,:], " + + "and should contain " + + "one path for /output, " + + "NeMo mounted on /opt/NeMo, " + + "dataset path: /workspace/llm/tokenizer.model, /preproc_data, /npy_index" + )) + slurm_group.add_argument("--envvars", type=str, help="Environment variables to be added", default=None) + slurm_group.add_argument("--image", type=str, required=True, help="Container image path, either remote or local") + + model_group = parser.add_argument_group("Model arguments") + model_group.add_argument( + "--size", + type=str, + default="8b", + help="Choose the model to be trained", + choices=[ + "8b", # Llama 3 8B config + ]) + + model_group.add_argument("--initial_ckpt_path", type=str, default=None) + model_group.add_argument("--use_ckpt", action="store_true", help="If set, then resume from the initial checkpoint path") + model_group.add_argument("--resume_from_hf", action="store_true", help="Setting this knob indicates that we are resuming from a weight-only checkpoint") + model_group.add_argument("--ckpt_start_step", type=int, default=0, help="Sets this value to how many steps the resumed checkpoint is already trained on") + model_group.add_argument("--continual_ckpt_path", type=str, default=None, help="Sets this to the path that saves the checkpoint") + model_group.add_argument("--save_ckpt", action="store_true", help="If set, then we save the checkpoint at the end of the experiment") + model_group.add_argument("--tensor_parallel_size", type=int, default=None, help="Set tensor parallelism to the model") + + data_group = parser.add_argument_group("Dataset arguments") + + data_group.add_argument("--gbs", type=int, default=1152, help="Global batch size, should be divisible by PP") + data_group.add_argument("--mbs", type=int, default=1, help="Micro batch size") + data_group.add_argument("--max_lr", type=float, default=1e-4, help="Peak learning rate. Min LR will be 0.1 of max_lr") + data_group.add_argument("--eval_every", type=int, default=12288, help="Evaluate at least every N training sequences") + data_group.add_argument("--start_eval_at", type=int, default=0, help="Start evaluation at N training sequences") + data_group.add_argument("--eval_tokens", type=int, default=1024, help="Evaluate using at least N evaluation sequences") + data_group.add_argument('--max_steps', type=int, default=1200000, help="Maximum number of steps that each experiment partition will train on. None means no restriction on max steps. ") + data_group.add_argument('--warmup_steps', type=int, default=None, help="Number of steps for LR warmup") + data_group.add_argument("--use_full_dataset", action="store_true", help="If set, then we use the full dataset, instead of the last 256/1024 shards") + data_group.add_argument("--tokenizer_path", type=str, help="Tokenizer path that's used to tokenize the dataset") + + experiment_group = parser.add_argument_group("Experiment management arguments") + experiment_group.add_argument("--dryrun", action="store_true", help="Whether we are launching dryrun or actual runs") + experiment_group.add_argument("--seeds", type=int, nargs="*", default=[], help="random seeds") + experiment_group.add_argument("--num_exps", type=int, default=1) + experiment_group.add_argument("--num_pars", type=int, default=1) + experiment_group.add_argument("--target_log_ppl", type=float, default=5.6) + experiment_group.add_argument("--step_time_atol", type=int, default=1600, help="train step time atol") + + return parser + + +if __name__ == "__main__": + args = get_parser().parse_args() + if args.tag and not args.tag.startswith("-"): + args.tag = "-" + args.tag + + assert not (args.num_pars == 1 and args.continual_ckpt_path is None), "NPar > 1 but a shared checkpoint path is not found" + assert not (not args.save_ckpt and args.num_pars > 1), "multiple experiments are specified but checkpoint is not saved" + + if args.run_slurm: + executor = slurm_executor( + user=args.user, + host=args.host, + remote_job_dir=args.job_dir, + account=args.account, + partition=args.partition, + nodes=args.nodes, + devices=args.gpus_per_node, + time=args.time, + custom_mounts=list(args.mounts.split(",")), + custom_env_vars=({envvar.split("=")[0]: envvar.split("=")[1] for envvar in args.envvars.split(",")} if args.envvars is not None else None), + container_image=args.image, + dependencies=args.dependencies, + retries=args.max_retries, + ) + else: + executor = local_executor( + custom_env_vars=({envvar.split("=")[0]: envvar.split("=")[1] for envvar in args.envvars.split(",")} if args.envvars is not None else None), + devices=args.gpus_per_node, + retries=args.max_retries, + ) + + seq_length = 8192 + + data = get_data( + gbs=args.gbs, + mbs=args.mbs, + seq_length=seq_length, + tokenizer_path=args.tokenizer_path, + seed=1234, # overwritten in each experiments + use_full_dataset=args.use_full_dataset, + ) + + eval_every_n_batches = math.ceil(args.eval_every / (args.gbs)) + eval_batches = math.ceil(args.eval_tokens / (args.gbs)) + if args.start_eval_at == 0: + start_eval_at = math.ceil(args.start_eval_at / args.gbs) + else: + start_eval_at = eval_every_n_batches + + exp_prefix, pretrain = get_pretrain( + max_lr=args.max_lr, + size=args.size, + nnodes=args.nodes, + ngpus_per_node=args.gpus_per_node, + max_steps=args.max_steps, + warmup_steps=args.warmup_steps, + data_module=data, + eval_every=eval_every_n_batches, + start_eval_at=start_eval_at, + eval_batches=eval_batches, + ) + + + # assert args.gbs % pretrain.trainer.strategy.pipeline_model_parallel_size == 0, f"GBS({args.gbs}) should be divisible by PP({pretrain.trainer.strategy.pipeline_model_parallel_size})" + + # Collect all HP configs + from mlperf_logging.mllog import constants + tp = args.tensor_parallel_size or pretrain.trainer.strategy.tensor_model_parallel_size + pp = pretrain.trainer.strategy.pipeline_model_parallel_size + cp = pretrain.trainer.strategy.context_parallel_size + dp = (pretrain.trainer.num_nodes * pretrain.trainer.devices) // (tp * pp * cp) + mini_batch_size = (args.gbs // dp) + grad_accumulation_steps = mini_batch_size // args.mbs + print(f"Parallel settings: {tp=} {pp=} {cp=} {dp=} {mini_batch_size=} {grad_accumulation_steps=}") + # assert False, f"Parallel settings: {tp=} {pp=} {cp=} {pretrain.trainer.num_nodes=} {pretrain.trainer.devices=} {dp=} {args.gbs=} {mini_batch_size=} {grad_accumulation_steps=}" + + configs = { + # HPs + constants.GLOBAL_BATCH_SIZE: args.gbs, + constants.GRADIENT_ACCUMULATION_STEPS: grad_accumulation_steps, + constants.MAX_SEQUENCE_LENGTH: 8192, + constants.EVAL_SAMPLES: args.eval_tokens, + + # Optimizers + constants.OPT_NAME: "adamw", + constants.OPT_BASE_LR: pretrain.optim.config.lr, + constants.OPT_ADAMW_BETA_1: pretrain.optim.config.adam_beta1, + constants.OPT_ADAMW_BETA_2: pretrain.optim.config.adam_beta2, + constants.OPT_ADAMW_EPSILON: pretrain.optim.config.adam_eps, + constants.OPT_ADAMW_WEIGHT_DECAY: pretrain.optim.config.weight_decay, + constants.OPT_GRADIENT_CLIP_NORM: pretrain.optim.config.clip_grad, + + # Schedulers + constants.OPT_END_LR: pretrain.optim.lr_scheduler.min_lr, + constants.OPT_LR_WARMUP_STEPS: pretrain.optim.lr_scheduler.warmup_steps, + constants.OPT_LR_DECAY_STEPS: pretrain.trainer.max_steps - pretrain.optim.lr_scheduler.warmup_steps, + constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmup", + } + + # Override config for MLPerf + pretrain.trainer.num_sanity_val_steps = 0 + + run_plugins = [ + plugins.PerfEnvPlugin(), + ] + + exp_prefix = f"{exp_prefix}{args.tag}" + + # Pretrain data index builder + # max steps + pretrain.data.num_train_samples = pretrain.trainer.max_steps * pretrain.data.global_batch_size + print (f'{pretrain.trainer.max_steps=}\n{pretrain.data.global_batch_size=}\n{pretrain.data.num_train_samples=}') + datamodule = pretrain.data.clone() + datamodule.num_dataset_builder_threads = 64 + build_data_index = run.Partial( + build_pretraining_datamodule, + datamodule=datamodule, + trainer_max_steps=pretrain.trainer.max_steps, + trainer_val_check_interval=pretrain.trainer.val_check_interval, + trainer_limit_val_batches=pretrain.trainer.limit_val_batches, + trainer_limit_test_batches=pretrain.trainer.limit_test_batches, + ) + data_index_executor = executor.clone() + data_index_executor.launcher = 'torchrun' + data_index_executor.nodes = 1 + data_index_executor.ntasks_per_node = 1 + data_index_executor.retries = 1 + + static_read_from_path = args.initial_ckpt_path if args.use_ckpt else None + static_write_to_path = args.continual_ckpt_path + static_max_steps = args.max_steps if args.max_steps is not None else None + # Enable this to make static_max_steps not None to enable PreemptiveStop overwrite PL in Callback + static_max_steps = pretrain.trainer.max_steps if static_max_steps is None else static_max_steps + + print (f'{static_max_steps=}') + if not args.save_ckpt: + print (f'Not saving checkpoints') + pretrain.trainer.enable_checkpointing = False + else: + print (f'Saving checkpoints') + + original_callbacks = pretrain.trainer.callbacks + + random_seeds = args.seeds + if len(random_seeds) < args.num_exps: + import random + random_seeds = random_seeds + [random.randint(0, 32767) for _ in range(args.num_exps - len(random_seeds))] + print(f"Missing {args.num_exps - len(random_seeds)} seeds, padding the random seeds to {random_seeds}") + + random_seeds = random_seeds[:args.num_exps] + + for index, seed in enumerate(random_seeds): + # sets the seeds + pretrain.data.seed = seed + build_data_index.datamodule.seed = seed + configs[constants.SEED] = seed + + exp_name = f"{exp_prefix}_{index}_seed_{seed}" + experiment_read_from_path = static_read_from_path + experiment_write_to_path = static_write_to_path + experiment_max_steps = args.ckpt_start_step + + with run.Experiment(exp_name) as exp: + exp.add(build_data_index, executor=data_index_executor, name=f"build_data_index") + + for j in range(args.num_pars): + ending_steps = "" + starting_steps = f"{experiment_max_steps}" + if static_max_steps is not None: + ending_steps = f"-{experiment_max_steps + static_max_steps}-steps" + + print (f'experiment_max_steps: {experiment_max_steps}') + checkpoint_name = "checkpoint" + f"-seed-{seed}-par-{j}{ending_steps}" + experiment_write_to_path = static_write_to_path + "/" + checkpoint_name + + if not (args.resume_from_hf and j == 0): + pretrain.resume = run.Config( + nl.AutoResume, + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_from_path = experiment_read_from_path, + resume_from_directory = experiment_read_from_path, + ) + else: + pretrain.resume = run.Config(nl.AutoResume, restore_config = run.Config(nl.RestoreConfig, path=experiment_read_from_path)) + pretrain.log.ckpt.train_time_interval = None + + if args.save_ckpt: + pretrain.log.ckpt.dirpath = experiment_write_to_path + pretrain.log.ckpt.filename = "checkpoint" + + if static_max_steps is not None: + start_step = experiment_max_steps + experiment_max_steps += static_max_steps + print (f'static_max_steps: {static_max_steps}') + print (f'stop_on_step=experiment_max_steps={experiment_max_steps}') + configs[constants.INIT_CHECKPOINT_STEP] = start_step + pretrain.trainer.callbacks = ( + original_callbacks + [ + run.Config(PreemptiveStop, stop_on_step=experiment_max_steps), + run.Config( + MLPerfCallback, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + sequence_length=8192, + eval_every=eval_every_n_batches, + init_global_step=start_step, + configs=configs, + ), + ] + ) + + if args.save_ckpt: + pretrain.log.ckpt.every_n_train_steps = experiment_max_steps + pretrain.log.ckpt.save_on_train_epoch_end = False + + try: + print ("control C to skip") + login_info = wandb.login() + print("WandB is logged in.") + pretrain.log.extra_loggers = [ + run.Config( + WandbLogger, + project='llama3.1_8b_training', + name=f'{checkpoint_name}-gbs={args.gbs}-lr={pretrain.optim.config.lr}', + + ), + ] + except: + print("WandB is NOT logged in.") + pretrain.log.extra_loggers = [ + run.Config( + MetricsLogger, + init_global_step=start_step, + global_batch_size=args.gbs, + seq_length=8192, + target_log_ppl=args.target_log_ppl, + train_step_time_atol=args.step_time_atol, + ), + ] + if args.save_ckpt: + pretrain.log.ckpt.every_n_train_steps = experiment_max_steps + pretrain.log.ckpt.save_on_train_epoch_end = False + experiment_read_from_path = experiment_write_to_path + "/checkpoint" + + exp.add( + pretrain, executor=executor, + name=f"{exp_name}_{j}_{starting_steps}{ending_steps}", + plugins=run_plugins + ) + + if args.dryrun: + exp.dryrun() + else: + exp.run(sequential=True, detach=True) diff --git a/small_llm_pretraining/nemo/requirements.txt b/small_llm_pretraining/nemo/requirements.txt new file mode 100644 index 000000000..3dd355195 --- /dev/null +++ b/small_llm_pretraining/nemo/requirements.txt @@ -0,0 +1,9 @@ +git+https://github.com/mlcommons/logging.git@5.0.0-rc2 +git+https://github.com/NVIDIA/mlperf-common.git@68cf1d0d5e3de3351e66abb696d0e2d011aabf47 +huggingface_hub==0.24.0 +transformers==4.43.2 +numpy==1.26.4 +plotly==6.0.0 +nbformat==5.10.4 +kaleido==0.2.1 +redis==5.2.1 \ No newline at end of file diff --git a/small_llm_pretraining/nemo/run_llama31.sh b/small_llm_pretraining/nemo/run_llama31.sh new file mode 100644 index 000000000..5f5ea57b1 --- /dev/null +++ b/small_llm_pretraining/nemo/run_llama31.sh @@ -0,0 +1,152 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +#git config --global --add safe.directory /workspace/llama31 + +# Vars without defaults +# Slurm settings +: "${USER:?USER not set}" +: "${HOST:?HOST not set}" +: "${ACCOUNT:?ACCOUNT not set}" +: "${PARTITION:?PARTITION not set}" +: "${REMOTE:=0}" + +# Job settings +: "${JOB_DIR:?JOB_DIR not set}" +: "${IMAGE:?IMAGE not set}" + +# Dataset settings +: "${PREPROCESSED_PATH:?PREPROCESSED_PATH not set}" +: "${TOKENIZER_PATH:?TOKENIZER_PATH not set}" + +# Model settings +: "${USE_CKPT:="0"}" +: "${FROM_HF:?FROM_HF not set}" +: "${CONTINUAL_CKPT:?CONTINUAL_CKPT not set}" + +# Vars with defaults +# Slurm settings +: "${TIME:="04:00:00"}" +: "${NNODES:=1}" +: "${GPUS_PER_NODE:=8}" +: "${DEPENDENCIES:=""}" + +# Job settings +: "${NEMO_DIR:=""}" # Provide customized NeMo path here +: "${NEMO_RUN_DIR:=""}" # Provide customized NeMo-Run path here +: "${TMP_NPY_INDEX:=""}" # Provide temporary NNumpy Index saving directory +: "${MAX_RETRIES:=0}" + +# Model settings +: "${SIZE:="8b"}" +: "${GBS:=4}" +: "${MBS:=1}" +: "${START_STEPS:=0}" + +# Dataloader settings +: "${MAX_STEPS:="1200000"}" + +# Experiment settings +: "${SEEDS:=""}" +IFS=" " read -ra seeds <<< $SEEDS +: "${NEXP:=1}" +: "${NPAR:=1}" +: "${SAVE_CKPT:=0}" +: "${TAG:=""}" +: "${TARGET:="3.3"}" +: "${STEP_TIME_ATOL:="18000"}" # maximum tolerable step time, setting to 5hr by default + +# Run + +MOUNTS="${JOB_DIR}:/output,${JOB_DIR}:/mlperf-outputs,${PREPROCESSED_PATH}:/preproc_data,${TOKENIZER_PATH}:/tokenizer,${CONTINUAL_CKPT}:/continual" + +CKPT_OPTION="" + +CMD_SUFFIX="" + +if [ $USE_CKPT -gt 0 ]; then + CMD_SUFFIX="${CMD_SUFFIX} --use_ckpt" + if [ $FROM_HF -gt 0 ]; then + CMD_SUFFIX="${CMD_SUFFIX} --resume_from_hf" + fi +fi + +if [ $SAVE_CKPT -gt 0 ]; then + CMD_SUFFIX="${CMD_SUFFIX} --save_ckpt" +fi + +if [ ! $NEMO_DIR = "" ]; then + MOUNTS="${MOUNTS},${NEMO_DIR}:/opt/NeMo" +fi + +if [ ! $NEMO_RUN_DIR = "" ]; then + MOUNTS="${MOUNTS},${NEMO_RUN_DIR}:/opt/NeMo-Run" +fi + +if [ ! $TMP_NPY_INDEX = "" ]; then + MOUNTS="${MOUNTS},${TMP_NPY_INDEX}:/npy_index" +fi + +if [ ! $DEPENDENCIES = "" ]; then + CMD_SUFFIX="${CMD_SUFFIX} --dependencies ${DEPENDENCIES}" +fi + +if [ ! $MAX_STEPS = "" ]; then + CMD_SUFFIX="${CMD_SUFFIX} --max_steps ${MAX_STEPS}" +fi + +if [ ! $TAG = "" ]; then + CMD_SUFFIX="${CMD_SUFFIX} --tag ${TAG}" +fi + +if [ $REMOTE -gt 0 ]; then + CMD_SUFFIX="${CMD_SUFFIX} --run_slurm" +fi + +if [ $TENSOR_PARALLEL_SIZE -gt 0 ]; then + CMD_SUFFIX="${CMD_SUFFIX} --tensor_parallel_size ${TENSOR_PARALLEL_SIZE}" +fi + +# Allows MLLogger objects to be constructed locally +if [ ! -d /mlperf-outputs ]; then mkdir /mlperf-outputs; fi + +set -x + +python3 pretrain_llama31.py \ +--user $USER --host $HOST \ +--job_dir $JOB_DIR \ +--account $ACCOUNT --partition $PARTITION \ +--nodes $NNODES --gpus_per_node $GPUS_PER_NODE \ +--time $TIME \ +--max_retries $MAX_RETRIES \ +--mounts $MOUNTS \ +--image $IMAGE \ +--size $SIZE \ +--gbs $GBS --mbs $MBS \ +--seeds ${seeds[@]} \ +--num_exps $NEXP \ +--num_pars $NPAR \ +--continual_ckpt_path $CONTINUAL_CKPT \ +--tokenizer_path $TOKENIZER_PATH \ +--target_log_ppl $TARGET \ +--step_time_atol $STEP_TIME_ATOL \ +--ckpt_start_step $START_STEPS \ +--warmup_steps $WARMUP_STEPS \ +--eval_every $EVAL_EVERY \ +--start_eval_at $START_EVAL_AT \ +$CMD_SUFFIX diff --git a/small_llm_pretraining/nemo/utils/consolidate_data.sh b/small_llm_pretraining/nemo/utils/consolidate_data.sh new file mode 100644 index 000000000..e1a9b637a --- /dev/null +++ b/small_llm_pretraining/nemo/utils/consolidate_data.sh @@ -0,0 +1,42 @@ +set -e + +: "${C4_PATH:?C4_PATH not set}" +: "${MERGED_C4_PATH:?MERGED_C4_PATH not set}" +: "${N_VALIDATION_SAMPLES:=91205}" +# defaults the N_VALIDATION_SAMPLES to 91205 +# C4 validation dataset: each sample on average tokenizes to 518 tokens +# thus, to reach 47,185,920 validation tokens, we need to use at least 91205 samples, +# which, after tokenization, will yield 47,186,855 tokens. + +# create softlinks to store each shard before merging +mkdir -p softlinks +for shard in {0..7}; do + start=$((shard * 128)) + end=$((shard * 128 + 127)) + mkdir -p softlinks/en_$shard + for ind in $(seq -f "%05g" $start $end); do + src=${C4_PATH}/c4-train.${ind}-of-01024.json.gz + if [ -f "$src" ]; then + ln -s "$src" softlinks/en_${shard}/ + else + echo "Warning: missing file $src — skipping" >&2 + fi + done +done + +mkdir -p softlinks/en_validation +start=0 +end=7 +for ind in $(seq -f "%05g" $start $end); do + ln -s ${C4_PATH}/c4-validation.${ind}-of-00008.json.gz softlinks/en_validation/c4-validation.${ind}-of-00008.json.gz +done + +# merge +for shard in {0..7}; do + cat softlinks/en_${shard}/*gz > ${MERGED_C4_PATH}/c4-train.en_${shard}.json.gz +done + +cat softlinks/en_validation/*gz > ${MERGED_C4_PATH}/c4-validation.en.json.gz + +# select the first N_VALIDATION_SAMPLES number of samples +zcat ${MERGED_C4_PATH}/c4-validation.en.json.gz | head -n $N_VALIDATION_SAMPLES | gzip > ${MERGED_C4_PATH}/c4-validation-${N_VALIDATION_SAMPLES}-samples.en.json.gz \ No newline at end of file diff --git a/small_llm_pretraining/nemo/utils/parallel_compress_json_to_gz.sh b/small_llm_pretraining/nemo/utils/parallel_compress_json_to_gz.sh new file mode 100644 index 000000000..2c506708f --- /dev/null +++ b/small_llm_pretraining/nemo/utils/parallel_compress_json_to_gz.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -e + +: "${C4_PATH:?C4_PATH not set}" + +echo "Starting parallel compression in $C4_PATH..." + +# Use 50% of available CPU cores (adjust -j as needed) +find "$C4_PATH" -maxdepth 1 -name '*.json' | \ + parallel -j$(nproc) ' + echo "Compressing {}" + gzip "{}" +' + +echo "Parallel compression complete!" + diff --git a/small_llm_pretraining/nemo/utils/preprocess.sh b/small_llm_pretraining/nemo/utils/preprocess.sh new file mode 100644 index 000000000..604377c46 --- /dev/null +++ b/small_llm_pretraining/nemo/utils/preprocess.sh @@ -0,0 +1,34 @@ +#!/bin/bash +#SBATCH -N 9 +#SBATCH --gpus-per-node 1 +#SBATCH -t 04:00:00 +#SBATCH --mem=0 + +set -e + +: "${CONT_IMAGE_URL:?CONT_IMAGE_URL not set}" +: "${TOKENIZER_PATH:?TOKENIZER_PATH not set}" +: "${MERGED_C4_PATH:?MERGED_C4_PATH not set}" +: "${PREPROCESSED_PATH:?PREPROCESSED_PATH not set}" + +container_maps="${TOKENIZER_PATH}:/tokenizer,${MERGED_C4_PATH}:/dataset,${PREPROCESSED_PATH}:/outputs" + +for index in {0..7}; do + srun --nodes=1 --ntasks-per-node=1 \ + --container-image=$CONT_IMAGE_URL --container-mounts $container_maps --no-container-entrypoint \ + python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \ + --input "/dataset/c4-train.en_${index}.json.gz" \ + --output-prefix "/outputs/c4-train.en_${index}" \ + --tokenizer-library huggingface --tokenizer-type /tokenizer \ + --dataset-impl mmap --workers 128 & +done + +srun --nodes=1 --ntasks-per-node=1 \ + --container-image=$CONT_IMAGE_URL --container-mounts $container_maps --no-container-entrypoint \ + --output preprocess_outputs/dataset_preprocess_validation.out \ + python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \ + --input "/dataset/c4-validation-91205-samples.en.json.gz" \ + --output-prefix "/outputs/c4-validation-91205-samples.en" \ + --tokenizer-library huggingface --tokenizer-type /tokenizer \ + --dataset-impl mmap --workers 128 & +wait From 9c98fa60541e1528a5a35ea1331e9df6725334f9 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Fri, 15 Aug 2025 14:59:37 -0400 Subject: [PATCH 02/14] Update pretrain_llama31.py disable async save and save intermediate checkpoint --- small_llm_pretraining/nemo/pretrain_llama31.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/small_llm_pretraining/nemo/pretrain_llama31.py b/small_llm_pretraining/nemo/pretrain_llama31.py index 8b039a3bd..60f0b2d30 100644 --- a/small_llm_pretraining/nemo/pretrain_llama31.py +++ b/small_llm_pretraining/nemo/pretrain_llama31.py @@ -183,7 +183,7 @@ def get_pretrain( pretrain.log.tensorboard = None pretrain.log.ckpt.every_n_train_steps = None - pretrain.log.ckpt.save_top_k = -1 + pretrain.log.ckpt.save_top_k = 0 pretrain.log.ckpt.save_last = False pretrain.log.ckpt.always_save_context = False pretrain.log.ckpt.save_weights_only = False @@ -191,6 +191,7 @@ def get_pretrain( pretrain.log.ckpt.save_on_train_epoch_end = False pretrain.log.ckpt.monitor = "consumed_samples" pretrain.log.ckpt.mode = "max" + pretrain.trainer.strategy.async_save = False return exp_name, pretrain From 08d765e37940cc967bf23d94c667fcbe6575a462 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Sat, 16 Aug 2025 00:36:25 -0500 Subject: [PATCH 03/14] update README instruction, minor change to callback and set default target log perplexity to be 3.3 for consistency purposes --- small_llm_pretraining/nemo/README.md | 23 +++++++++++++++---- small_llm_pretraining/nemo/callbacks.py | 1 + .../nemo/pretrain_llama31.py | 16 +------------ 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/small_llm_pretraining/nemo/README.md b/small_llm_pretraining/nemo/README.md index f27409697..c4add5a5c 100644 --- a/small_llm_pretraining/nemo/README.md +++ b/small_llm_pretraining/nemo/README.md @@ -20,7 +20,7 @@ docker run -it --rm \ The current codebase is using the c4/en/3.0.1 dataset from [HuggingFace/AllenAI](https://huggingface.co/datasets/allenai/c4) for train and evaluation. - +You can then navigate in the terminal to your desired download directory and run the following commands to download the dataset and tokenizer: + +```bash +# Dataset: +rclone copy -P --stats 5s mlc-upload:mlcommons-training-wg-intake/dataset/llama31_8b -P + +# Tokenizer: +rclone copy -P --stats 5s mlc-upload:mlcommons-training-wg-intake/tokenizer/llama31_8b -P +``` -### Raw data downloading +## Raw data downloading [Optional] We use [AllenAI C4](https://huggingface.co/datasets/allenai/c4) dataset for this benchmark. The original zipped **`json.gz`** files can be downloaded by following AllenAI C4's instruction, and you can download our zipped customized validation dataset from the MLCommons S3 bucket by running the following command: diff --git a/small_llm_pretraining/nemo/callbacks.py b/small_llm_pretraining/nemo/callbacks.py index 8ce5669ed..680e2b673 100644 --- a/small_llm_pretraining/nemo/callbacks.py +++ b/small_llm_pretraining/nemo/callbacks.py @@ -212,6 +212,7 @@ def on_validation_start(self, trainer, pl_module): trainer.val_check_interval = self.eval_every trainer.val_check_batch = self.eval_every self.log_eval_start(trainer, pl_module) + return super().on_validation_start(trainer, pl_module) def on_validation_end(self, trainer, pl_module): mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)}) diff --git a/small_llm_pretraining/nemo/pretrain_llama31.py b/small_llm_pretraining/nemo/pretrain_llama31.py index 60f0b2d30..8c48f7e8d 100644 --- a/small_llm_pretraining/nemo/pretrain_llama31.py +++ b/small_llm_pretraining/nemo/pretrain_llama31.py @@ -317,7 +317,7 @@ def get_parser() -> argparse.ArgumentParser: experiment_group.add_argument("--seeds", type=int, nargs="*", default=[], help="random seeds") experiment_group.add_argument("--num_exps", type=int, default=1) experiment_group.add_argument("--num_pars", type=int, default=1) - experiment_group.add_argument("--target_log_ppl", type=float, default=5.6) + experiment_group.add_argument("--target_log_ppl", type=float, default=3.3) experiment_group.add_argument("--step_time_atol", type=int, default=1600, help="train step time atol") return parser @@ -539,20 +539,6 @@ def get_parser() -> argparse.ArgumentParser: pretrain.log.ckpt.every_n_train_steps = experiment_max_steps pretrain.log.ckpt.save_on_train_epoch_end = False - try: - print ("control C to skip") - login_info = wandb.login() - print("WandB is logged in.") - pretrain.log.extra_loggers = [ - run.Config( - WandbLogger, - project='llama3.1_8b_training', - name=f'{checkpoint_name}-gbs={args.gbs}-lr={pretrain.optim.config.lr}', - - ), - ] - except: - print("WandB is NOT logged in.") pretrain.log.extra_loggers = [ run.Config( MetricsLogger, From 7b6544d3e90656be9d7f166bedd1090453cfad7f Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Sun, 17 Aug 2025 23:04:35 -0500 Subject: [PATCH 04/14] set LR for GB32 --- small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh b/small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh index 000465616..746b86f8e 100644 --- a/small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh +++ b/small_llm_pretraining/nemo/config_MI325X_1x8x1_8b.sh @@ -80,7 +80,7 @@ export SIZE="8b" export GBS=32 # Dataloader: Micro batch size export MBS=4 -export MAX_LR="5e-4" +export MAX_LR="1e-3" # Dataloader: Max run N batches, optional # If an empty string is provided (""), then the training will continue until time limit # If we want to save a checkpoint, then this value must be set @@ -154,4 +154,4 @@ export USE_HIPBLASLT=1 export TORCH_BLAS_PREFER_HIPBLASLT=1 export NVTE_USE_RMSNORM_TRITON=1 -export ENABLE_TRANSPOSE_CACHE=0 \ No newline at end of file +export ENABLE_TRANSPOSE_CACHE=0 From 2b66faac030d143b21f98027f33ef5a03fec2dd5 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Wed, 20 Aug 2025 11:59:36 -0500 Subject: [PATCH 05/14] update README with download instructions from https://training.mlcommons-storage.org/index.html --- small_llm_pretraining/nemo/README.md | 32 ++++++++-------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/small_llm_pretraining/nemo/README.md b/small_llm_pretraining/nemo/README.md index c4add5a5c..e69c20c5a 100644 --- a/small_llm_pretraining/nemo/README.md +++ b/small_llm_pretraining/nemo/README.md @@ -22,33 +22,19 @@ The current codebase is using the c4/en/3.0.1 dataset from [HuggingFace/AllenAI] ## Preprocessed data download -The pre-tokenized dataset and the tokenizer are available to download from the S3 bucket. You can download this data from the bucket using RClone as follows: - -To run Rclone on Windows, you can download the executable [here](https://rclone.org/install/#windows). To install Rclone on Linux/macOS/BSD systems, run: - -``` -sudo -v ; curl https://rclone.org/install.sh | sudo bash -``` - -Once Rclone is installed, run the following command to authenticate with the bucket: +The pre-tokenized dataset and the tokenizer are available to download. More instructions to download on Windows are available [here](https://training.mlcommons-storage.org/index.html). You can download using the following commands: ```bash -rclone config create mlc-upload s3 \ - provider=Cloudflare \ - access_key_id=b4da5f7caeb19316de92d9e3a03346f6 \ - secret_access_key=ff25d851e53229ae1383c35cef05921a9f484a6d2183d7be62d630715a8105a9 \ - endpoint=https://c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com \ - no_check_bucket=true -``` - -You can then navigate in the terminal to your desired download directory and run the following commands to download the dataset and tokenizer: +# data +# go to the path where you want the data to be downloaded +# use the same path in config when exporting PREPROCESSED_PATH +bash <(curl -s https://raw.githubusercontent.com/mlcommons/r2-downloader/refs/heads/main/mlc-r2-downloader.sh) -d llama3_1_8b_preprocessed_c4_dataset https://training.mlcommons-storage.org/metadata/llama-3-1-8b-preprocessed-c4-dataset.uri ```bash -# Dataset: -rclone copy -P --stats 5s mlc-upload:mlcommons-training-wg-intake/dataset/llama31_8b -P - -# Tokenizer: -rclone copy -P --stats 5s mlc-upload:mlcommons-training-wg-intake/tokenizer/llama31_8b -P +# tokenizer +# go to the path where you want the tokenizer to be downloaded +# use the same path in config when exporting TOKENIZER_PATH +bash <(curl -s https://raw.githubusercontent.com/mlcommons/r2-downloader/refs/heads/main/mlc-r2-downloader.sh) -d llama3_1_8b_tokenizer https://training.mlcommons-storage.org/metadata/llama-3-1-8b-tokenizer.uri ``` ## Raw data downloading [Optional] From fed1bb42ee5380c5a37c98e8f712aa4890b4cb8f Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Wed, 20 Aug 2025 12:25:37 -0500 Subject: [PATCH 06/14] update link to https://github.com/mlcommons/r2-downloader --- small_llm_pretraining/nemo/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/small_llm_pretraining/nemo/README.md b/small_llm_pretraining/nemo/README.md index e69c20c5a..66aaf180c 100644 --- a/small_llm_pretraining/nemo/README.md +++ b/small_llm_pretraining/nemo/README.md @@ -22,7 +22,7 @@ The current codebase is using the c4/en/3.0.1 dataset from [HuggingFace/AllenAI] ## Preprocessed data download -The pre-tokenized dataset and the tokenizer are available to download. More instructions to download on Windows are available [here](https://training.mlcommons-storage.org/index.html). You can download using the following commands: +The pre-tokenized dataset and the tokenizer are available to download. More instructions to download on Windows are available [here]( https://github.com/mlcommons/r2-downloader). You can download using the following commands: ```bash # data @@ -178,4 +178,4 @@ Once the `config.sh` is properly filled, we launch a training run using the foll ```bash source config.sh bash run_llama31.sh -``` \ No newline at end of file +``` From 7f8f3ec365a6782b5adbbeb9c8f04ef91d1ede2e Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Wed, 20 Aug 2025 15:22:46 -0500 Subject: [PATCH 07/14] Revert "update link to https://github.com/mlcommons/r2-downloader" This reverts commit fed1bb42ee5380c5a37c98e8f712aa4890b4cb8f. --- small_llm_pretraining/nemo/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/small_llm_pretraining/nemo/README.md b/small_llm_pretraining/nemo/README.md index 66aaf180c..e69c20c5a 100644 --- a/small_llm_pretraining/nemo/README.md +++ b/small_llm_pretraining/nemo/README.md @@ -22,7 +22,7 @@ The current codebase is using the c4/en/3.0.1 dataset from [HuggingFace/AllenAI] ## Preprocessed data download -The pre-tokenized dataset and the tokenizer are available to download. More instructions to download on Windows are available [here]( https://github.com/mlcommons/r2-downloader). You can download using the following commands: +The pre-tokenized dataset and the tokenizer are available to download. More instructions to download on Windows are available [here](https://training.mlcommons-storage.org/index.html). You can download using the following commands: ```bash # data @@ -178,4 +178,4 @@ Once the `config.sh` is properly filled, we launch a training run using the foll ```bash source config.sh bash run_llama31.sh -``` +``` \ No newline at end of file From c5bff492f49963079ab2f420a4630be49f8ab868 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 19 Jan 2026 13:27:27 -0600 Subject: [PATCH 08/14] initial commit for gpt-oss-20b --- .gitignore | 4 + gpt-oss-20b/primus/Dockerfile | 26 +++ gpt-oss-20b/primus/Dockerfile.nvidia | 31 ++++ gpt-oss-20b/primus/README.md | 165 ++++++++++++++++++ .../conf/gpt_oss_20B-pretrain-nvidia.yaml | 106 +++++++++++ .../primus/conf/gpt_oss_20B-pretrain.yaml | 121 +++++++++++++ gpt-oss-20b/primus/config_B200_1x8x1.sh | 57 ++++++ gpt-oss-20b/primus/config_MI355X_1x8x1.sh | 58 ++++++ .../primus_mllog-0.1.0-py3-none-any.whl | Bin 0 -> 5831 bytes gpt-oss-20b/primus/requirements.txt | 1 + gpt-oss-20b/primus/run_and_time.sh | 50 ++++++ gpt-oss-20b/primus/run_with_docker.sh | 134 ++++++++++++++ gpt-oss-20b/primus/src/train.py | 98 +++++++++++ 13 files changed, 851 insertions(+) create mode 100644 gpt-oss-20b/primus/Dockerfile create mode 100644 gpt-oss-20b/primus/Dockerfile.nvidia create mode 100644 gpt-oss-20b/primus/README.md create mode 100644 gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml create mode 100644 gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml create mode 100755 gpt-oss-20b/primus/config_B200_1x8x1.sh create mode 100755 gpt-oss-20b/primus/config_MI355X_1x8x1.sh create mode 100644 gpt-oss-20b/primus/primus_mllog-0.1.0-py3-none-any.whl create mode 100644 gpt-oss-20b/primus/requirements.txt create mode 100755 gpt-oss-20b/primus/run_and_time.sh create mode 100755 gpt-oss-20b/primus/run_with_docker.sh create mode 100644 gpt-oss-20b/primus/src/train.py diff --git a/.gitignore b/.gitignore index dc314947c..240309f05 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ __pycache__/ *.py[cod] *$py.class single_stage_detector/mlcube/workspace/* + +# Dev folder +dev/ +output/ \ No newline at end of file diff --git a/gpt-oss-20b/primus/Dockerfile b/gpt-oss-20b/primus/Dockerfile new file mode 100644 index 000000000..acf900d75 --- /dev/null +++ b/gpt-oss-20b/primus/Dockerfile @@ -0,0 +1,26 @@ +# ARG BASE_IMAGE=rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0 +ARG BASE_IMAGE=docker.io/rocm/primus:v25.11 +FROM ${BASE_IMAGE} + +# Set working directory +WORKDIR /workspace/deps + +# Clone Primus repository (remove existing if present in base image) +RUN rm -rf Primus && \ + git clone --recursive https://github.com/AMD-AIG-AIMA/Primus.git && \ + cd Primus && \ + # main branch with evaluation bug fix + git checkout 85c51c0da12f7d9b819f944eba9ffeb313795b9a && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt + +# Build Megatron C++ extensions (required for dataset helpers) +RUN cd /workspace/deps/Primus/third_party/Megatron-LM && \ + pip install -e . --no-deps + +WORKDIR /workspace/code +# Copy the current state of the code inside the image +COPY . . + +# Install primus-mllog from local wheel +RUN pip install primus_mllog-0.1.0-py3-none-any.whl \ No newline at end of file diff --git a/gpt-oss-20b/primus/Dockerfile.nvidia b/gpt-oss-20b/primus/Dockerfile.nvidia new file mode 100644 index 000000000..9107a3da0 --- /dev/null +++ b/gpt-oss-20b/primus/Dockerfile.nvidia @@ -0,0 +1,31 @@ +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:25.12-py3 +FROM ${BASE_IMAGE} + +WORKDIR /workspace + +RUN pip install --no-cache-dir \ + pyyaml \ + pybind11 \ + ninja \ + packaging \ + transformers + +WORKDIR /workspace/deps +RUN git clone --recursive https://github.com/AMD-AIG-AIMA/Primus.git && \ + cd Primus && \ + git checkout 6014259525111fa70e531631cf03d53a656855b9 && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt + +RUN cd /workspace/deps/Primus/third_party/Megatron-LM && \ + pip install -e . --no-deps + +ENV PYTHONPATH="/workspace/deps/Primus:/workspace/deps/Primus/third_party/Megatron-LM" + +WORKDIR /workspace/code +COPY . . + +# Install primus-mllog from local wheel +RUN pip install primus_mllog-0.1.0-py3-none-any.whl + +RUN pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 \ No newline at end of file diff --git a/gpt-oss-20b/primus/README.md b/gpt-oss-20b/primus/README.md new file mode 100644 index 000000000..c63c2447a --- /dev/null +++ b/gpt-oss-20b/primus/README.md @@ -0,0 +1,165 @@ +# GPT-OSS-20B Pretraining Benchmark + +GPT-OSS 20B (Mixture of Experts) + +## Overview + +This benchmark trains a 20B parameter GPT model with Mixture of Experts (MoE) architecture using the Primus framework on AMD GPUs. + +**Key Features:** +- 20B parameter MoE model +- Expert Parallelism (EP=8) +- FP8 hybrid precision training +- Primus Turbo optimizations (DeepEP, sync-free MoE) + +# 1. Setup Docker Image + +## Option 1: Pull Docker Image + +```bash +docker pull rocm/amd-mlperf:gpt_oss_20b_training_5.1 +``` + +## Option 2: Build Docker Image + +Run the following build command from this directory. The build process will take a while to complete. + +```bash +# From gpt-oss-20b/primus directory +docker build -t rocm/amd-mlperf:gpt_oss_20b_training_5.1 . +``` + +Or use the build script: + +```bash +bash dev/build_docker.sh +``` + +# 2. Prepare Dataset + +The current codebase uses the c4/en/3.0.1 dataset from [HuggingFace/AllenAI](https://huggingface.co/datasets/allenai/c4) for training and evaluation. + +## Download Preprocessed Data + +The pre-tokenized dataset is available for download. Navigate to your desired download directory and run the following commands: + +```bash +# Desired download directory +cd /data/gpt_oss_20b + +# Download training and validation data +bash <(curl -s https://raw.githubusercontent.com/mlcommons/r2-downloader/refs/heads/main/mlc-r2-downloader.sh) \ + -d data https://training.mlcommons-storage.org/metadata/llama-3-1-8b-preprocessed-c4-dataset.uri +``` + +After download, you should see files with the following naming conventions: +- Training: `c4-train.en_6_text_document.bin` and `.idx` +- Validation: `c4-validation-91205-samples.en_text_document.bin` and `.idx` + +The data directory is approximately **80 GB**. + +# 3. Run Training + +## Set Environment Variables + +Set the directory for data and results. Ensure `$LOGDIR` has write access. + +```bash +export DATADIR=/data/gpt_oss_20b/data +export LOGDIR=/data/gpt_oss_20b/results +export CONT=rocm/amd-mlperf:gpt_oss_20b_training_5.1 +export HF_TOKEN= + +# Create results directory +mkdir -p $LOGDIR +sudo chmod -R 777 $LOGDIR +``` + +## Set Configuration + +Set appropriate configuration and system-specific hyperparameters: + +| Config File | System | GPUs | +|-------------|--------|------| +| `config_MI355X_1x8x1.sh` | MI355X | 1 node × 8 GPUs | + +```bash +source config_MI355X_1x8x1.sh +``` + +## Launch Training + +### Single Run + +```bash +export NEXP=1 +bash run_with_docker.sh +``` + +### Multiple Runs (for submission) + +```bash +export NEXP=10 +bash run_with_docker.sh +``` + +After completion, logs will be available under `$LOGDIR`. + +# 4. Quality Metrics + +## Quality Metric + +Validation loss (log perplexity) + +## Quality Target + +Validation log perplexity = **3.3** + +## Evaluation Frequency + +Evaluation every **768 iterations** (12,288 samples with GBS=16) + +## Evaluation Thoroughness + +We evaluate using **1024 sequences** from the validation dataset. + +# 5. Model Architecture + +| Parameter | Value | +|-----------|-------| +| Model Size | 20B parameters | +| Architecture | GPT with Mixture of Experts | +| Sequence Length | 8192 | +| Precision | BF16 / FP8 hybrid | +| Expert Parallelism | 8 | + +# 6. Training Configuration + +| Hyperparameter | Value | +|----------------|-------| +| Micro Batch Size | 2 | +| Global Batch Size | 16 | +| Learning Rate | 8e-4 | +| LR Schedule | Cosine decay with warmup | +| Weight Decay | 0.1 | +| Adam β1, β2 | 0.9, 0.95 | +| Training Iterations | 20,000 | + +# 7. Directory Structure + +``` +gpt-oss-20b/primus/ +├── conf/ # Configuration files +│ └── gpt_oss_20B-pretrain.yaml +├── dev/ # Development scripts +│ ├── build_docker.sh +│ ├── run_docker.sh +│ └── run_and_time.sh +├── src/ # Training source code +│ └── train.py +├── config_MI355X_1x8x1.sh # System configuration +├── Dockerfile +└── requirements.txt # Python dependencies (includes primus-mllog) +``` + + diff --git a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml new file mode 100644 index 000000000..f16cbeab5 --- /dev/null +++ b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml @@ -0,0 +1,106 @@ +work_group: ${TEAM:nvidia} +user_name: ${USER:root} +exp_name: ${EXP_NAME:gpt_oss_20b_nvidia} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_20B}.yaml + overrides: + # log + wandb_project: "Primus_GPT_OSS_20B_NVIDIA" + stderr_sink_level: DEBUG + log_interval: 99999999 # Suppress console logs + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # profile + profile: false + use_pytorch_profiler: false + profile_step_end: 7 + profile_step_start: 6 + + # precision (mixed precision training) + # Using bf16 for B200 + bf16: true + fp16: false + fp8: null # Disabled - using bf16 instead + + # hyper parameters + train_iters: ${PRIMUS_TRAIN_ITERS:20000} + micro_batch_size: ${PRIMUS_MICRO_BATCH_SIZE:2} + global_batch_size: ${PRIMUS_GLOBAL_BATCH_SIZE:16} + seq_length: ${PRIMUS_SEQ_LENGTH:8192} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:8192} + lr: ${PRIMUS_LR:8.0e-4} + min_lr: 8.0e-5 # Set to 10% of max LR + lr_warmup_iters: ${PRIMUS_LR_WARMUP_ITERS:128} + lr_decay_iters: 1199872 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:1} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: false + train_data_path: "10 /data/c4-train.en_6_text_document" + valid_data_path: "/data/c4-validation-91205-samples.en_text_document" + test_data_path: "/data/c4-validation-91205-samples.en_text_document" + + # fusion (standard Megatron optimizations) + moe_permute_fusion: false + gradient_accumulation_fusion: false + moe_grouped_gemm: false # Disable grouped_gemm requirement + moe_use_legacy_grouped_gemm: false + moe_use_fused_router_with_aux_score: false + multi_latent_attention: false + apply_rope_fusion: false + + # MoE router configuration + moe_shared_expert_overlap: false + moe_router_dtype: fp32 + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + exit_on_missing_checkpoint: false + ckpt_format: torch + eval_iters: 64 # 64 iters × 2 MBS × 8 GPUs = 1024 eval samples + eval_interval: ${PRIMUS_EVAL_INTERVAL:768} + + # Turbo features disabled for NVIDIA + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + + use_turbo_deepep: false + turbo_deepep_num_cu: 0 + turbo_deepep_use_comm_stream: false + + turbo_sync_free_moe_stage: 0 + diff --git a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml new file mode 100644 index 000000000..d9371892e --- /dev/null +++ b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml @@ -0,0 +1,121 @@ +work_group: ${TEAM:amd} +user_name: ${USER:root} +exp_name: ${EXP_NAME:gpt_oss_20b} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_20B}.yaml + overrides: + # log + wandb_project: "Primus_DeepSeek_Pretrain" + stderr_sink_level: DEBUG + log_interval: 99999999 # Suppress console logs + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # profile + profile: false + use_pytorch_profiler: false + profile_step_end: 7 + profile_step_start: 6 + + # precision (mixed precision training) + bf16: true + fp16: false + fp8: null # Set to "e4m3" or "hybrid" for FP8 training + + # hyper parameters + train_iters: ${PRIMUS_TRAIN_ITERS:20} + micro_batch_size: ${PRIMUS_MICRO_BATCH_SIZE:4} + global_batch_size: ${PRIMUS_GLOBAL_BATCH_SIZE:640} + seq_length: ${PRIMUS_SEQ_LENGTH:8192} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:8192} + lr: ${PRIMUS_LR:1.0e-5} + min_lr: 1.0e-5 # Set to 10% of max LR like Llama 8B (8e-5 is 10% of 8e-4) + lr_warmup_iters: ${PRIMUS_LR_WARMUP_ITERS:128} # Match Llama 8B warmup + lr_decay_iters: 1199872 # Match Llama 8B decay schedule (75M samples / 640 GBS = 117187 iters, but using their value) + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:1} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: false + train_data_path: "10 /data/c4-train.en_6_text_document" + valid_data_path: "/data/c4-validation-91205-samples.en_text_document" + test_data_path: "/data/c4-validation-91205-samples.en_text_document" + + # fusion + # 20250321: need latest megatron docker image + moe_permute_fusion: false + # fused wgrad gemm and accumulation + gradient_accumulation_fusion: false + # recommend set `false` in fp8 + moe_use_legacy_grouped_gemm: true + # fused topk router with aux score + moe_use_fused_router_with_aux_score: false + # MLA + multi_latent_attention: false + # rope fusion + apply_rope_fusion: false + + # DeepEP does not support moe_shared_expert_overlap + moe_shared_expert_overlap: false + # DeepEP only supports float32 probs + moe_router_dtype: fp32 + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + exit_on_missing_checkpoint: false + ckpt_format: torch + eval_iters: 64 # 64 iters × 2 MBS × 8 GPUs = 1024 eval samples (matching Llama 8B) + eval_interval: ${PRIMUS_EVAL_INTERVAL:5} + + # Turbo + enable_primus_turbo: true + use_turbo_attention: true + use_turbo_grouped_mlp: true + + # deepep + use_turbo_deepep: true + + # 64 or 80 for ep8, 32 for ep16-64 is best practice + turbo_deepep_num_cu: 64 + turbo_deepep_use_comm_stream: false + + # sync-free moe support stage 0-3, 0 means not use sync-free moe + # stage 3 is completely no gpu-cpu sync in MoE, but cost more memory + # stage 2 is recommended for better performance + turbo_sync_free_moe_stage: 2 + + # Cross entropy flags + # cross_entropy_fusion_impl: "te" + # cross_entropy_loss_fusion: true + diff --git a/gpt-oss-20b/primus/config_B200_1x8x1.sh b/gpt-oss-20b/primus/config_B200_1x8x1.sh new file mode 100755 index 000000000..12f4ad2d1 --- /dev/null +++ b/gpt-oss-20b/primus/config_B200_1x8x1.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +export DGXSYSTEM=B200_1x8x1 +export GPUS_PER_NODE=8 +export NNODES=1 +export NODE_RANK=0 +export MASTER_ADDR=localhost +export MASTER_PORT=29501 + +export PRIMUS_PATH=/workspace/deps/Primus +export PYTHONPATH="${PRIMUS_PATH}:${PRIMUS_PATH}/third_party/Megatron-LM:${PYTHONPATH}" +export EXP=/workspace/code/conf/gpt_oss_20B-pretrain-nvidia.yaml +export DATA_PATH=/data + +export PRIMUS_MICRO_BATCH_SIZE=2 +export PRIMUS_GLOBAL_BATCH_SIZE=16 +export PRIMUS_LR=8e-4 +export PRIMUS_TRAIN_ITERS=20000 +export SEED=30279 + +# Evaluation frequency (sample-based, adjusts automatically with GBS) +export EVAL_SAMPLES_INTERVAL=12288 # Evaluate every 12,288 samples +export PRIMUS_EVAL_INTERVAL=$((EVAL_SAMPLES_INTERVAL / PRIMUS_GLOBAL_BATCH_SIZE)) # Auto-computed + +export PRIMUS_BF16=true +export PRIMUS_FP16=false +export PRIMUS_FP8=null + +export PRIMUS_TURBO_ENABLED=false +export USE_TURBO_ATTENTION=false +export USE_TURBO_GROUPED_MLP=false +export USE_TURBO_DEEPEP=false +export TURBO_DEEPEP_NUM_CU=0 +export TURBO_SYNC_FREE_MOE_STAGE=0 + +export PRIMUS_APPLY_ROPE_FUSION=false +export USE_ROCM_MEM_INFO=false + +export OVERLAP_GRAD_REDUCE=true +export OVERLAP_PARAM_GATHER=true +export GRADIENT_ACCUMULATION_FUSION=false + +export PRIMUS_TP=1 +export PRIMUS_PP=1 +export PRIMUS_EP=8 + +export ENABLE_MLLOG=1 +export MLLOG_OUTPUT_FILE=/results/mlperf_output.log +export MLLOG_TRAIN_LOSS_LOG_FREQ=1 +export MLLOG_TARGET_EVAL_LOSS=3.3 +export MLLOG_SUBMISSION_BENCHMARK=gpt-oss-20b +export MLLOG_SUBMISSION_DIVISION=closed +export MLLOG_SUBMISSION_ORG=NVIDIA +export MLLOG_SUBMISSION_PLATFORM=B200 + +export HF_TOKEN="${HF_TOKEN:-}" + diff --git a/gpt-oss-20b/primus/config_MI355X_1x8x1.sh b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh new file mode 100755 index 000000000..ee0706594 --- /dev/null +++ b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# ============================================================================= +# MLPerf GPT-OSS-20B Configuration for MI355X (1 node, 8 GPUs) +# ============================================================================= + +# ----------------------------------------------------------------------------- +# System Configuration +# ----------------------------------------------------------------------------- +export DGXSYSTEM=MI355X_1x8x1 +export GPUS_PER_NODE=8 +export NNODES=1 +export NODE_RANK=0 +export MASTER_ADDR=localhost +export MASTER_PORT=29501 + +# ----------------------------------------------------------------------------- +# Paths +# ----------------------------------------------------------------------------- +export PRIMUS_PATH=/workspace/deps/Primus +export PYTHONPATH="${PRIMUS_PATH}:${PRIMUS_PATH}/third_party/Megatron-LM:${PYTHONPATH}" +export EXP=/workspace/code/conf/gpt_oss_20B-pretrain.yaml +export DATA_PATH=/data + +# ----------------------------------------------------------------------------- +# Training Hyperparameters +# ----------------------------------------------------------------------------- +export PRIMUS_MICRO_BATCH_SIZE=2 +export PRIMUS_GLOBAL_BATCH_SIZE=16 +export PRIMUS_LR=8e-4 +export PRIMUS_TRAIN_ITERS=20000 # 20K iters × 16 GBS = 320K samples +export SEED=30279 + +# Evaluation frequency (sample-based, adjusts automatically with GBS) +export EVAL_SAMPLES_INTERVAL=12288 # Evaluate every 12,288 samples +export PRIMUS_EVAL_INTERVAL=$((EVAL_SAMPLES_INTERVAL / PRIMUS_GLOBAL_BATCH_SIZE)) # Auto-computed + +# ----------------------------------------------------------------------------- +# Optimizations +# ----------------------------------------------------------------------------- +export PRIMUS_APPLY_ROPE_FUSION=True +export PRIMUS_FP8_RECIPE=hybrid + +# ----------------------------------------------------------------------------- +# MLPerf Logging +# ----------------------------------------------------------------------------- +export ENABLE_MLLOG=1 +export MLLOG_OUTPUT_FILE=/results/mlperf_output.log +export MLLOG_TRAIN_LOSS_LOG_FREQ=100 +export MLLOG_TARGET_EVAL_LOSS=3.3 +export MLLOG_SUBMISSION_BENCHMARK=gpt-oss-20b +export MLLOG_SUBMISSION_DIVISION=closed +export MLLOG_SUBMISSION_ORG=AMD +export MLLOG_SUBMISSION_PLATFORM=MI355X + +# ----------------------------------------------------------------------------- +# TE Configuration +# ----------------------------------------------------------------------------- +export NVTE_ROCM_ENABLE_MXFP8=0 diff --git a/gpt-oss-20b/primus/primus_mllog-0.1.0-py3-none-any.whl b/gpt-oss-20b/primus/primus_mllog-0.1.0-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..75f0a035deea8fee6c50b72d7a1c256016f01c79 GIT binary patch literal 5831 zcmai&bySq=+Qx?%>5`Cc0f`|bML;?w2T2)18ipDqq=sYq9uD*BoM~n^t zP~NT1uC|WuZV*QY2PbP@2*eg@3xhyF&R*p)680T@Wbz(2;Smz~Kns-|3q-6H_O~u# zrvhvyse!!eS6Nayw_mI0WX5xK78pm?-(7jydXJMliZ;72wQt1nxTL>A+*rJot>LIF zk+v1WdOB_{6h(zJ*%TE;Jg00Eyw5waX71QXoGZ$p!J9%l+v==!%+hvwu{-%HC}t^& zh*&uu>~Q5JX8`Wx2{(;AyO%U}(6334#$-q+alAkf7Yn-IS#?9~c{vr%Goa}~Yc)=A z0-npdFr-ag>+HX*yH(-sD-H{)Ocl;C$u7owl;ozpf>Vj^^15Yuzq)d7cr5#3aR<8@ z=Lkph{}sB3d~D20vAVu_4*)oL007Yau8^aHv!$yQM@y7V2Lm_>+xAuFpu z3<0D}td0fr_t_n>8)QR>(Y8w@GP!H)Tcc0~pHhsy8*djrHaG3jVAg3xkeTnL^$bV$ ztbQCunw*V`%69O1Q#jP<7O#IA)jilLB6R&ihhw~x`^Zf`6uK4N0De6p>fF@3OiKj} z!%PD26;C_b<$fD&z0`Ll(mm`~%pqXofZ{vF0Z z6Pzq@8wy)-Y$ri#5c~Zcb`-em>^u0B23s=VMmd%v51oXh8;*}QkKoIQB+-Ez*IlyKSX;@`}RkB0L6?|`dcVgIaLvf)|_O#3api! zW=3K2AnLJH2FQ6U#osiku3de4tOm2QC?<*@g$HJRfz_Dd~+dlt8eEQS-UH(HF zqb}S}vUPXA98{@C56_?Tp*Xs#lQJ|005sREl0V;&Dcmd4JjAERe`9qa6(#1?Qq>BD_W0iG6L&^cz z*BB2UCza(AbbE6mANAaKCGez($JSCw<^Co;RH{}J1e){OoNfbCa)*{U^p_rzAE-_Ezgt%xNTZJZ8S2s7WGM$ zRC~AX`5N(HVDu&rf}g}JGtLhAnD5_B)^CbW-mT!6(YH2|d7212WXa4$w*qbNUL845&fuyNP753$^|kla3FlwMeWzDLFD19EVZ^SeuK^K zG{5%g_XR<`SL>-0(m{G79yro7yI<|ILt*Ze+x^VV;y}D)kEdz#lR=|OJA<;r{*q|h z<2KjJtM?Zp-y)sDjIxTqUYo`=t~Jz}F4=w|>9lBMb`^_o28d)^g&?TL0@Ic1rz_y_ zabR9#DBzgyonm+)-a!o@MX%e2ET8{7_qC`16Uzt8&q_cw+_8#s7Oh}G3hPkiP}XMQ zgo3{E3CH*hssI+;t&9boqM!?udBxCT%1S`?c}6Of+&XC|ZM<5-N^!G+VJ&_xi)ecM z$UN7r`)mt(;j%O655Hk`hpPp>{2-R8X2sx~nX2AB6oi&M{+a6b$yC(p=aPxOs)&Lb zAAMS^IstZB>6ypt@tuqyEO zN@oZ$(Dc=F>i7Y8}){@6U z)?7vWO{_jtp3>=)R$21kaI!ncXZ z^*7(G3N(I(J81P9lUDf!_DS)+>M+e13)#lGe^p%udpQq5%!bT*spm3^R7y=rqxsHW zS#`;xK5BC6ZSVNMywUa3cvJ%_q1fJbp0AW!F%zMXVPFr_nreCSszw!=j*uhg#gJ{E zPnD&R_*{*>GF_l_5;TAp_c2FOZM;0$mXLM8yh|^)7N30CD!WJWOmpWpndT}TIJeHZ zg?|1W(;dezu6P$qFo*hHR@7{{pGq){z+E@18PDBY5mFD6+hXCs)?(d?4*58kk&Mz< z@7FVCXCu&$o1$v8WKJW>d@p5gbHJci{am-1KiMJ&q#q4*>6>@g5Dgis;;ka>VmZ-P zXSU61#ckgd8brYzy;{u%ma~kO^|U}INNVdXren3(LYzM6*UG8B9uMuTg9xvsSUfkd zU43MfI5tA?@ED$OslK>)dSW!nd%z_!ecDyU_&Tl~Y%<%}tg!0zM}<7PtB`EUx^X<% z0KhOS0KoD;Dx|ZkB?RVbcGpipt^ZLc4UC*tg(>_{l}p}LRZlf19pJt=Q+S^%iU_7i zi$2!0n)5LV3hdjksB+3%rqivpplIV8;}UrjU@wTmxHD;5+?JO|Gej!A2^L}vlBJG3 zJ$OSOMD-Cpkh-2{A0M(vRdA49=ph;I)!5^Y)?ZHeAc)&5s2uu=1KS$78sk8UuQ?XZ zU;R0Ks8Y|ZdPr0`;1LgHn9fBVk^H-YXBYi=l=z4aR7A$l{iXwbYCn|`n@@Pa;V~K{ zCebP^%Ab-bz?Cl)<<$-B3*S-id?Wnr*g>OHhrUQWFI!{02=1MA;t8Q>lblP$;X=>t z+OJDz+>G2okh{g;7@AgTxOUeI_PuVR7+$$Y4_bSiU3TF0;>V8@?XtFKOSE#z`ef;) zieINse0;D6Xv836sL?R<jCkxDa|WztYBE@>ZXwFvq|L5qx#ei)6SP zx1(QxL6xtzOZixi|de5o-Jb;;L%#t@X z8_v4frKTs*o~`x?ux5KNMJ$tNW85gam04BT;3Ph(#c*lzY%;w95b6Gqh)(E-f|$a) z$V2F};c6JofeJD-I5BwpZSN$ySe#uGQ`upOT` zboOJR`bPp1h9WIygWtH4l92kl`}1}E4{DCT)sfr+FM7y5C(~{GG0(}ZELcvTc<+S)4gjzOO(-;O{(avI(jS8u1woC?a zQX}f+I8@dgjcXb+hBiKsht>NQWJu3L`PKZ57b-TT=)%s_g&P9RyjDGI&eejeeugCQ zX7)_Z)_ojDiA4MiM<#nr@2cXF-4gG5H-0sGX@~tJc=)XHy-{G~t{K+4nJ&JXJ#xVQ zk&d-68CNJNF{{3My(RTe%tnRe3)TUqYsRMyUkbc>J0y5Uo{#9by501v%O%Mqt>*TA z78G7cC34G>wl*07qv&^-~GSwMJx#@yDoV-#Pi(VowtVJRCBSQ6p!kOLQ= zVIYnWZ&a|(pCRb%Np=)Ec{IoDw0zsZcy&=7d%p%do+8FHwG6~%>TilgNb2csu5e74 zNgLyeGDHL4MOL%I^D)jc_fC~vvXgKq&)!5f8tV2LVwhCPQs!BnJW)Bv%#N%o5x@rV zAnK8n_x2HnBZ?7t;6;Qy2PTRWO_IuNQ9@NX{lWA)G&?F_2=bGUR}C0&Y4v| z+WqG_vMaa`#I<8EWBY>^m;+PZzb|fx79yKUXXQJV!^I!PPzmw*C?)sb{PNu8+wM1eFmOKkv``0zI09M#f%EJPmL9yNbPy*fqIq%7QC=J zpD|~|ub`*jV^XAhJw|8OfnQcsGc&&rJI@|pnzGln{`^To5mmh7-O#}e6O{RpJ1EsB z@65zi5w;}72Bq*IzyOBUH_dp|EKQ0}w`r;HrbG%kiC^bzxHI$_Ru8DKL`eKx5xhwQ z)W77VB(?e6zECR4NUBf@Ub%GG%r+EJo9alk^&os=-Wh=7=~3 z?n7ji_y(8a6%dGh*x-Aw%&GLIpx}jngLBp?*h$|qRfa^io=u9kiZ&h`AawkTx8vYf zZKV;e=lr;ZC})R3jj^-x`}N8j2-d9t(}N9UkP@mhHm6HN)Wlbqq5iC%GAE*5 z{6bRR#+yV{pJhWx+h-+Apr=w)**Ns0OJUtmf|R7Oj{{!C2Bu>PQdv4i)Qw{#r$YuL zf>-f66j^U@{U!g{x2Hy+M=KPjl({LrTRxh)``wO2sTVQQdArF*?ZRUuSse1Dlc0m9-$e$Y?j$ z;0BwC36|R&I(@~}-89=3Cz*P?kvlDQ*}84WQi0O+iAT(u*d+_s!gY|B0(OEJ(L*uT z8Sh<|uxn%uOO+#uy70bquaO__1Jh}0snCS zuFD8S0{|ZW)@|_cf%rjupjWnTFdkc|l@srCd9bvcG*~(+R@^D*8FBE{6**UtWMBeV z?sFSiT&pyjihrPx$vXWS)Mc<`*v!C7x1=R)KCL89w^F}P?-c%SXbI#+EN;mcG@6hW zC}`Lu?Q^sQ=yEUE|F|LRR ze^SfjXHSyr&DsCbC|u&J#u|UG)yC>EEMbv-7wo_f%TkDzeznZ~-twnF%;=Ek#$4)F zxd+4fTuVVK6{SR;RIfeesHaDs zZKPsXO1Of@E$tHNpY?)m^zRwp>11^GG5?)jdP?&0YQQ#r+zvi`;K<9Ok#B?ogfFQf zIWR7!+8VrULDPK30a@rGN>VDPk7mJRyj`dNG@#u+^}&x&mtcckpN37rBAX~iS|-(4<+Mq>{8`ogZ}yp$DxJGQ+K_Siqf8}iO9MxE+$a0S8s%slAKN z)0t}DRKpeHPkq~z3As{}_*hJXwdI?m^P&tM>b0GHLb;$Szv5YlaNH2Ghy4 z{oSc3ho0y?L^LtAg(sh+)<~$;ST7LCqe`6~K|FAgPSs%3OiOfuE=cMd!{9qT|MZnB z#Loua9&71g9xYJQ)U@A3e#9NdJKc;D$EK5@l_x!G)&7b=B$FyQB`OC|(gW?)R*yMg z-%sC)#aKR}qvVB_Y89McwViz6NCg*na#<22-hQw4NLgM$TUsa5R7FE4!}#Ar$U7(e z^QndZWBFer$^WGM)u8TQ3;>`O&V1*EKT&=k+x;i%uWmGdpyclEIe&%v?SAv0z`r`Q z{te8BR}=pk_+S3~_Yw9x^7mZ-8(9lqA^98RKU4lY@AoYJ!SlQSH@rV3^LOs=Vf=$T d`FGrZMN?B1csG6p0I=>}&|RW~{$U)z{{e6bWa PrimusConfig: + """ + Load and parse the experiment YAML configuration. + + The config file (e.g., gpt_oss_20B-pretrain.yaml) defines: + - Model architecture (hidden size, num layers, attention heads, etc.) + - Training hyperparameters (batch size, learning rate, etc.) + - Data paths and tokenizer settings + - Parallelism settings (TP, PP, EP) + """ + # Create args namespace for Primus config loader + parser = argparse.ArgumentParser() + add_pretrain_parser(parser) + + args = parser.parse_args([ + '--config', config_path, + '--data_path', os.getenv('DATA_PATH', '/data'), + ]) + + primus_cfg, unknown_overrides = load_primus_config(args, overrides or []) + + print(f"[MLPerf Train] Loaded config from: {config_path}") + print(f"[MLPerf Train] Framework: {primus_cfg.get_module_config('pre_trainer').framework}") + + return primus_cfg, unknown_overrides + + +def create_trainer(primus_cfg: PrimusConfig, extra_args: list = None) -> MLPerfMegatronPretrainTrainer: + """ + Create the MLPerf-enabled Megatron trainer. + + The trainer handles: + - Model creation (GPT architecture with MoE) + - Optimizer setup (Adam with configurable betas) + - Learning rate scheduling (warmup + cosine decay) + - Distributed training coordination + - MLPerf logging and metrics + """ + # Get distributed training configuration from environment + # These are set by torchrun when launching distributed training + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") + master_port = int(os.getenv("MASTER_PORT", "29500")) + + trainer = MLPerfMegatronPretrainTrainer( + module_name="pre_trainer", + primus_config=primus_cfg, + module_rank=rank, + module_world_size=world_size, + module_master_addr=master_addr, + module_master_port=master_port, + extra_args=extra_args, + ) + return trainer + +def main(): + config_path = os.environ.get("EXP", "/workspace/code/conf/gpt_oss_20B-pretrain.yaml") + + if not Path(config_path).exists(): + raise FileNotFoundError(f"Config not found: {config_path}") + + setup_environment(data_path=os.getenv('DATA_PATH', '/data')) + primus_cfg, extra_args = load_config(config_path) + + trainer = create_trainer(primus_cfg, extra_args) + trainer.init() + trainer.run() + +if __name__ == "__main__": + main() From b1832049fdf8e121a785444e8329233555e6bc70 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 19 Jan 2026 13:33:55 -0600 Subject: [PATCH 09/14] update README --- gpt-oss-20b/primus/README.md | 34 ++++++---------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/gpt-oss-20b/primus/README.md b/gpt-oss-20b/primus/README.md index c63c2447a..d1c866e85 100644 --- a/gpt-oss-20b/primus/README.md +++ b/gpt-oss-20b/primus/README.md @@ -4,23 +4,10 @@ GPT-OSS 20B (Mixture of Experts) ## Overview -This benchmark trains a 20B parameter GPT model with Mixture of Experts (MoE) architecture using the Primus framework on AMD GPUs. - -**Key Features:** -- 20B parameter MoE model -- Expert Parallelism (EP=8) -- FP8 hybrid precision training -- Primus Turbo optimizations (DeepEP, sync-free MoE) +This benchmark trains a 20B parameter GPT model with Mixture of Experts (MoE) architecture using the Primus framework on AMD and NVIDIA GPUs. # 1. Setup Docker Image -## Option 1: Pull Docker Image - -```bash -docker pull rocm/amd-mlperf:gpt_oss_20b_training_5.1 -``` - -## Option 2: Build Docker Image Run the following build command from this directory. The build process will take a while to complete. @@ -29,12 +16,6 @@ Run the following build command from this directory. The build process will take docker build -t rocm/amd-mlperf:gpt_oss_20b_training_5.1 . ``` -Or use the build script: - -```bash -bash dev/build_docker.sh -``` - # 2. Prepare Dataset The current codebase uses the c4/en/3.0.1 dataset from [HuggingFace/AllenAI](https://huggingface.co/datasets/allenai/c4) for training and evaluation. @@ -44,7 +25,7 @@ The current codebase uses the c4/en/3.0.1 dataset from [HuggingFace/AllenAI](htt The pre-tokenized dataset is available for download. Navigate to your desired download directory and run the following commands: ```bash -# Desired download directory +# Create desired download directory with the right permission cd /data/gpt_oss_20b # Download training and validation data @@ -130,7 +111,6 @@ We evaluate using **1024 sequences** from the validation dataset. | Model Size | 20B parameters | | Architecture | GPT with Mixture of Experts | | Sequence Length | 8192 | -| Precision | BF16 / FP8 hybrid | | Expert Parallelism | 8 | # 6. Training Configuration @@ -151,14 +131,12 @@ We evaluate using **1024 sequences** from the validation dataset. gpt-oss-20b/primus/ ├── conf/ # Configuration files │ └── gpt_oss_20B-pretrain.yaml -├── dev/ # Development scripts -│ ├── build_docker.sh -│ ├── run_docker.sh -│ └── run_and_time.sh ├── src/ # Training source code │ └── train.py -├── config_MI355X_1x8x1.sh # System configuration -├── Dockerfile +├── config_MI355X_1x8x1.sh # System configuration (MI355 - AMD) +├── config_B200_1x8x1.sh # System configuration (B200 - NVIDIA) +├── Dockerfile # Dockerfile (MI355 - AMD) +├── Dockerfile.nvidia # Dockerfile (B200 - NVIDIA) └── requirements.txt # Python dependencies (includes primus-mllog) ``` From a1a6356c80f704f77f7d9c0ee7e0d349aff912a4 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 19 Jan 2026 14:07:52 -0600 Subject: [PATCH 10/14] update NV dockerfile and readme --- gpt-oss-20b/primus/Dockerfile.nvidia | 2 +- gpt-oss-20b/primus/README.md | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/gpt-oss-20b/primus/Dockerfile.nvidia b/gpt-oss-20b/primus/Dockerfile.nvidia index 9107a3da0..e3bc50613 100644 --- a/gpt-oss-20b/primus/Dockerfile.nvidia +++ b/gpt-oss-20b/primus/Dockerfile.nvidia @@ -13,7 +13,7 @@ RUN pip install --no-cache-dir \ WORKDIR /workspace/deps RUN git clone --recursive https://github.com/AMD-AIG-AIMA/Primus.git && \ cd Primus && \ - git checkout 6014259525111fa70e531631cf03d53a656855b9 && \ + git checkout 85c51c0da12f7d9b819f944eba9ffeb313795b9a && \ git submodule update --init --recursive && \ pip install -r requirements.txt diff --git a/gpt-oss-20b/primus/README.md b/gpt-oss-20b/primus/README.md index d1c866e85..1fb6cef64 100644 --- a/gpt-oss-20b/primus/README.md +++ b/gpt-oss-20b/primus/README.md @@ -58,11 +58,12 @@ sudo chmod -R 777 $LOGDIR ## Set Configuration -Set appropriate configuration and system-specific hyperparameters: +Set appropriate configuration and system-specific hyperparameters based on hardware type: | Config File | System | GPUs | |-------------|--------|------| | `config_MI355X_1x8x1.sh` | MI355X | 1 node × 8 GPUs | +| `config_B200_1x8x1.sh` | B200 | 1 node × 8 GPUs | ```bash source config_MI355X_1x8x1.sh @@ -138,6 +139,4 @@ gpt-oss-20b/primus/ ├── Dockerfile # Dockerfile (MI355 - AMD) ├── Dockerfile.nvidia # Dockerfile (B200 - NVIDIA) └── requirements.txt # Python dependencies (includes primus-mllog) -``` - - +``` \ No newline at end of file From 77eaf43e56e79eac24d869e4544923b8c5b3f664 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 19 Jan 2026 19:37:54 -0600 Subject: [PATCH 11/14] set random seed --- gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml | 1 + gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml | 1 + gpt-oss-20b/primus/config_B200_1x8x1.sh | 1 - gpt-oss-20b/primus/config_MI355X_1x8x1.sh | 1 - 4 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml index f16cbeab5..a938da5f6 100644 --- a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml +++ b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml @@ -39,6 +39,7 @@ modules: global_batch_size: ${PRIMUS_GLOBAL_BATCH_SIZE:16} seq_length: ${PRIMUS_SEQ_LENGTH:8192} max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:8192} + seed: ${SEED:1234} lr: ${PRIMUS_LR:8.0e-4} min_lr: 8.0e-5 # Set to 10% of max LR lr_warmup_iters: ${PRIMUS_LR_WARMUP_ITERS:128} diff --git a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml index d9371892e..97e8e7fbf 100644 --- a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml +++ b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml @@ -38,6 +38,7 @@ modules: global_batch_size: ${PRIMUS_GLOBAL_BATCH_SIZE:640} seq_length: ${PRIMUS_SEQ_LENGTH:8192} max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:8192} + seed: ${SEED:1234} lr: ${PRIMUS_LR:1.0e-5} min_lr: 1.0e-5 # Set to 10% of max LR like Llama 8B (8e-5 is 10% of 8e-4) lr_warmup_iters: ${PRIMUS_LR_WARMUP_ITERS:128} # Match Llama 8B warmup diff --git a/gpt-oss-20b/primus/config_B200_1x8x1.sh b/gpt-oss-20b/primus/config_B200_1x8x1.sh index 12f4ad2d1..67cb0248b 100755 --- a/gpt-oss-20b/primus/config_B200_1x8x1.sh +++ b/gpt-oss-20b/primus/config_B200_1x8x1.sh @@ -16,7 +16,6 @@ export PRIMUS_MICRO_BATCH_SIZE=2 export PRIMUS_GLOBAL_BATCH_SIZE=16 export PRIMUS_LR=8e-4 export PRIMUS_TRAIN_ITERS=20000 -export SEED=30279 # Evaluation frequency (sample-based, adjusts automatically with GBS) export EVAL_SAMPLES_INTERVAL=12288 # Evaluate every 12,288 samples diff --git a/gpt-oss-20b/primus/config_MI355X_1x8x1.sh b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh index ee0706594..82f7d0deb 100755 --- a/gpt-oss-20b/primus/config_MI355X_1x8x1.sh +++ b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh @@ -28,7 +28,6 @@ export PRIMUS_MICRO_BATCH_SIZE=2 export PRIMUS_GLOBAL_BATCH_SIZE=16 export PRIMUS_LR=8e-4 export PRIMUS_TRAIN_ITERS=20000 # 20K iters × 16 GBS = 320K samples -export SEED=30279 # Evaluation frequency (sample-based, adjusts automatically with GBS) export EVAL_SAMPLES_INTERVAL=12288 # Evaluate every 12,288 samples From aa7cbd721e90ac9c6054d87c84b7db6e1d0e0ca4 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 19 Jan 2026 21:47:30 -0600 Subject: [PATCH 12/14] update license for amd + clean up rocm dockerfile --- gpt-oss-20b/primus/Dockerfile | 32 ++++++++++++++++++----- gpt-oss-20b/primus/config_MI355X_1x8x1.sh | 23 ++++++++++++++++ gpt-oss-20b/primus/run_and_time.sh | 24 ++++++++++++++++- gpt-oss-20b/primus/run_with_docker.sh | 31 ++++++++++++++-------- gpt-oss-20b/primus/src/train.py | 24 +++++++++++++++++ 5 files changed, 115 insertions(+), 19 deletions(-) diff --git a/gpt-oss-20b/primus/Dockerfile b/gpt-oss-20b/primus/Dockerfile index acf900d75..7e7c24971 100644 --- a/gpt-oss-20b/primus/Dockerfile +++ b/gpt-oss-20b/primus/Dockerfile @@ -1,26 +1,44 @@ -# ARG BASE_IMAGE=rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0 +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + ARG BASE_IMAGE=docker.io/rocm/primus:v25.11 FROM ${BASE_IMAGE} -# Set working directory WORKDIR /workspace/deps -# Clone Primus repository (remove existing if present in base image) RUN rm -rf Primus && \ git clone --recursive https://github.com/AMD-AIG-AIMA/Primus.git && \ cd Primus && \ - # main branch with evaluation bug fix git checkout 85c51c0da12f7d9b819f944eba9ffeb313795b9a && \ git submodule update --init --recursive && \ pip install -r requirements.txt -# Build Megatron C++ extensions (required for dataset helpers) RUN cd /workspace/deps/Primus/third_party/Megatron-LM && \ pip install -e . --no-deps WORKDIR /workspace/code -# Copy the current state of the code inside the image + COPY . . -# Install primus-mllog from local wheel RUN pip install primus_mllog-0.1.0-py3-none-any.whl \ No newline at end of file diff --git a/gpt-oss-20b/primus/config_MI355X_1x8x1.sh b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh index 82f7d0deb..260446a23 100755 --- a/gpt-oss-20b/primus/config_MI355X_1x8x1.sh +++ b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh @@ -1,4 +1,27 @@ #!/bin/bash +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# # ============================================================================= # MLPerf GPT-OSS-20B Configuration for MI355X (1 node, 8 GPUs) # ============================================================================= diff --git a/gpt-oss-20b/primus/run_and_time.sh b/gpt-oss-20b/primus/run_and_time.sh index e9ab83a06..e67606dfe 100755 --- a/gpt-oss-20b/primus/run_and_time.sh +++ b/gpt-oss-20b/primus/run_and_time.sh @@ -1,5 +1,27 @@ #!/bin/bash - +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# set -e # Create results directory diff --git a/gpt-oss-20b/primus/run_with_docker.sh b/gpt-oss-20b/primus/run_with_docker.sh index bb192abad..35bc4104a 100755 --- a/gpt-oss-20b/primus/run_with_docker.sh +++ b/gpt-oss-20b/primus/run_with_docker.sh @@ -1,18 +1,27 @@ #!/bin/bash - -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. # -# http://www.apache.org/licenses/LICENSE-2.0 +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. set -euxo pipefail diff --git a/gpt-oss-20b/primus/src/train.py b/gpt-oss-20b/primus/src/train.py index 84819c1a5..e0905cb5b 100644 --- a/gpt-oss-20b/primus/src/train.py +++ b/gpt-oss-20b/primus/src/train.py @@ -1,3 +1,27 @@ +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + import os import sys from pathlib import Path From c110d5dbad4f14ee2759e19277054008a88b549f Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 19 Jan 2026 21:50:27 -0600 Subject: [PATCH 13/14] revisit the target log perplexity after establishing rcp --- gpt-oss-20b/primus/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt-oss-20b/primus/README.md b/gpt-oss-20b/primus/README.md index 1fb6cef64..849f2b712 100644 --- a/gpt-oss-20b/primus/README.md +++ b/gpt-oss-20b/primus/README.md @@ -95,7 +95,7 @@ Validation loss (log perplexity) ## Quality Target -Validation log perplexity = **3.3** +Validation log perplexity = **3.3** [TO BE REVISITED] ## Evaluation Frequency From f856c7f57a81602ecc3feba9b81d40f571693fc7 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 19 Jan 2026 21:51:45 -0600 Subject: [PATCH 14/14] remove target metric --- gpt-oss-20b/primus/README.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gpt-oss-20b/primus/README.md b/gpt-oss-20b/primus/README.md index 849f2b712..d51af4ca3 100644 --- a/gpt-oss-20b/primus/README.md +++ b/gpt-oss-20b/primus/README.md @@ -93,10 +93,6 @@ After completion, logs will be available under `$LOGDIR`. Validation loss (log perplexity) -## Quality Target - -Validation log perplexity = **3.3** [TO BE REVISITED] - ## Evaluation Frequency Evaluation every **768 iterations** (12,288 samples with GBS=16)