diff --git a/.gitignore b/.gitignore index dc314947c..240309f05 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ __pycache__/ *.py[cod] *$py.class single_stage_detector/mlcube/workspace/* + +# Dev folder +dev/ +output/ \ No newline at end of file diff --git a/gpt-oss-20b/primus/Dockerfile b/gpt-oss-20b/primus/Dockerfile new file mode 100644 index 000000000..7e7c24971 --- /dev/null +++ b/gpt-oss-20b/primus/Dockerfile @@ -0,0 +1,44 @@ +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +ARG BASE_IMAGE=docker.io/rocm/primus:v25.11 +FROM ${BASE_IMAGE} + +WORKDIR /workspace/deps + +RUN rm -rf Primus && \ + git clone --recursive https://github.com/AMD-AIG-AIMA/Primus.git && \ + cd Primus && \ + git checkout 85c51c0da12f7d9b819f944eba9ffeb313795b9a && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt + +RUN cd /workspace/deps/Primus/third_party/Megatron-LM && \ + pip install -e . --no-deps + +WORKDIR /workspace/code + +COPY . . + +RUN pip install primus_mllog-0.1.0-py3-none-any.whl \ No newline at end of file diff --git a/gpt-oss-20b/primus/Dockerfile.nvidia b/gpt-oss-20b/primus/Dockerfile.nvidia new file mode 100644 index 000000000..e3bc50613 --- /dev/null +++ b/gpt-oss-20b/primus/Dockerfile.nvidia @@ -0,0 +1,31 @@ +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:25.12-py3 +FROM ${BASE_IMAGE} + +WORKDIR /workspace + +RUN pip install --no-cache-dir \ + pyyaml \ + pybind11 \ + ninja \ + packaging \ + transformers + +WORKDIR /workspace/deps +RUN git clone --recursive https://github.com/AMD-AIG-AIMA/Primus.git && \ + cd Primus && \ + git checkout 85c51c0da12f7d9b819f944eba9ffeb313795b9a && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt + +RUN cd /workspace/deps/Primus/third_party/Megatron-LM && \ + pip install -e . --no-deps + +ENV PYTHONPATH="/workspace/deps/Primus:/workspace/deps/Primus/third_party/Megatron-LM" + +WORKDIR /workspace/code +COPY . . 
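+# COPY . . above brings the benchmark sources (src/, conf/, the config_*.sh launch scripts)
+# and the bundled primus_mllog wheel into /workspace/code for the install steps below.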
+ +# Install primus-mllog from local wheel +RUN pip install primus_mllog-0.1.0-py3-none-any.whl + +RUN pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 \ No newline at end of file diff --git a/gpt-oss-20b/primus/README.md b/gpt-oss-20b/primus/README.md new file mode 100644 index 000000000..d51af4ca3 --- /dev/null +++ b/gpt-oss-20b/primus/README.md @@ -0,0 +1,138 @@ +# GPT-OSS-20B Pretraining Benchmark + +GPT-OSS 20B (Mixture of Experts) + +## Overview + +This benchmark trains a 20B parameter GPT model with Mixture of Experts (MoE) architecture using the Primus framework on AMD and NVIDIA GPUs. + +# 1. Setup Docker Image + + +Run the following build command from this directory. The build process will take a while to complete. + +```bash +# From gpt-oss-20b/primus directory +docker build -t rocm/amd-mlperf:gpt_oss_20b_training_5.1 . +``` + +# 2. Prepare Dataset + +The current codebase uses the c4/en/3.0.1 dataset from [HuggingFace/AllenAI](https://huggingface.co/datasets/allenai/c4) for training and evaluation. + +## Download Preprocessed Data + +The pre-tokenized dataset is available for download. Navigate to your desired download directory and run the following commands: + +```bash +# Create desired download directory with the right permission +cd /data/gpt_oss_20b + +# Download training and validation data +bash <(curl -s https://raw.githubusercontent.com/mlcommons/r2-downloader/refs/heads/main/mlc-r2-downloader.sh) \ + -d data https://training.mlcommons-storage.org/metadata/llama-3-1-8b-preprocessed-c4-dataset.uri +``` + +After download, you should see files with the following naming conventions: +- Training: `c4-train.en_6_text_document.bin` and `.idx` +- Validation: `c4-validation-91205-samples.en_text_document.bin` and `.idx` + +The data directory is approximately **80 GB**. + +# 3. Run Training + +## Set Environment Variables + +Set the directory for data and results. Ensure `$LOGDIR` has write access. + +```bash +export DATADIR=/data/gpt_oss_20b/data +export LOGDIR=/data/gpt_oss_20b/results +export CONT=rocm/amd-mlperf:gpt_oss_20b_training_5.1 +export HF_TOKEN= + +# Create results directory +mkdir -p $LOGDIR +sudo chmod -R 777 $LOGDIR +``` + +## Set Configuration + +Set appropriate configuration and system-specific hyperparameters based on hardware type: + +| Config File | System | GPUs | +|-------------|--------|------| +| `config_MI355X_1x8x1.sh` | MI355X | 1 node × 8 GPUs | +| `config_B200_1x8x1.sh` | B200 | 1 node × 8 GPUs | + +```bash +source config_MI355X_1x8x1.sh +``` + +## Launch Training + +### Single Run + +```bash +export NEXP=1 +bash run_with_docker.sh +``` + +### Multiple Runs (for submission) + +```bash +export NEXP=10 +bash run_with_docker.sh +``` + +After completion, logs will be available under `$LOGDIR`. + +# 4. Quality Metrics + +## Quality Metric + +Validation loss (log perplexity) + +## Evaluation Frequency + +Evaluation every **768 iterations** (12,288 samples with GBS=16) + +## Evaluation Thoroughness + +We evaluate using **1024 sequences** from the validation dataset. + +# 5. Model Architecture + +| Parameter | Value | +|-----------|-------| +| Model Size | 20B parameters | +| Architecture | GPT with Mixture of Experts | +| Sequence Length | 8192 | +| Expert Parallelism | 8 | + +# 6. 
Training Configuration + +| Hyperparameter | Value | +|----------------|-------| +| Micro Batch Size | 2 | +| Global Batch Size | 16 | +| Learning Rate | 8e-4 | +| LR Schedule | Cosine decay with warmup | +| Weight Decay | 0.1 | +| Adam β1, β2 | 0.9, 0.95 | +| Training Iterations | 20,000 | + +# 7. Directory Structure + +``` +gpt-oss-20b/primus/ +├── conf/ # Configuration files +│ └── gpt_oss_20B-pretrain.yaml +├── src/ # Training source code +│ └── train.py +├── config_MI355X_1x8x1.sh # System configuration (MI355 - AMD) +├── config_B200_1x8x1.sh # System configuration (B200 - NVIDIA) +├── Dockerfile # Dockerfile (MI355 - AMD) +├── Dockerfile.nvidia # Dockerfile (B200 - NVIDIA) +└── requirements.txt # Python dependencies (includes primus-mllog) +``` \ No newline at end of file diff --git a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml new file mode 100644 index 000000000..a938da5f6 --- /dev/null +++ b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain-nvidia.yaml @@ -0,0 +1,107 @@ +work_group: ${TEAM:nvidia} +user_name: ${USER:root} +exp_name: ${EXP_NAME:gpt_oss_20b_nvidia} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_20B}.yaml + overrides: + # log + wandb_project: "Primus_GPT_OSS_20B_NVIDIA" + stderr_sink_level: DEBUG + log_interval: 99999999 # Suppress console logs + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # profile + profile: false + use_pytorch_profiler: false + profile_step_end: 7 + profile_step_start: 6 + + # precision (mixed precision training) + # Using bf16 for B200 + bf16: true + fp16: false + fp8: null # Disabled - using bf16 instead + + # hyper parameters + train_iters: ${PRIMUS_TRAIN_ITERS:20000} + micro_batch_size: ${PRIMUS_MICRO_BATCH_SIZE:2} + global_batch_size: ${PRIMUS_GLOBAL_BATCH_SIZE:16} + seq_length: ${PRIMUS_SEQ_LENGTH:8192} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:8192} + seed: ${SEED:1234} + lr: ${PRIMUS_LR:8.0e-4} + min_lr: 8.0e-5 # Set to 10% of max LR + lr_warmup_iters: ${PRIMUS_LR_WARMUP_ITERS:128} + lr_decay_iters: 1199872 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:1} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: false + train_data_path: "10 /data/c4-train.en_6_text_document" + valid_data_path: "/data/c4-validation-91205-samples.en_text_document" + test_data_path: "/data/c4-validation-91205-samples.en_text_document" + + # fusion (standard Megatron optimizations) + moe_permute_fusion: false + gradient_accumulation_fusion: false + moe_grouped_gemm: false # Disable grouped_gemm requirement + moe_use_legacy_grouped_gemm: false + moe_use_fused_router_with_aux_score: false + multi_latent_attention: false + apply_rope_fusion: false + + # MoE router configuration + moe_shared_expert_overlap: false + moe_router_dtype: fp32 + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + exit_on_missing_checkpoint: false + ckpt_format: torch + eval_iters: 64 # 64 iters × 2 
MBS × 8 GPUs = 1024 eval samples + eval_interval: ${PRIMUS_EVAL_INTERVAL:768} + + # Turbo features disabled for NVIDIA + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + + use_turbo_deepep: false + turbo_deepep_num_cu: 0 + turbo_deepep_use_comm_stream: false + + turbo_sync_free_moe_stage: 0 + diff --git a/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml new file mode 100644 index 000000000..97e8e7fbf --- /dev/null +++ b/gpt-oss-20b/primus/conf/gpt_oss_20B-pretrain.yaml @@ -0,0 +1,122 @@ +work_group: ${TEAM:amd} +user_name: ${USER:root} +exp_name: ${EXP_NAME:gpt_oss_20b} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_20B}.yaml + overrides: + # log + wandb_project: "Primus_DeepSeek_Pretrain" + stderr_sink_level: DEBUG + log_interval: 99999999 # Suppress console logs + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # profile + profile: false + use_pytorch_profiler: false + profile_step_end: 7 + profile_step_start: 6 + + # precision (mixed precision training) + bf16: true + fp16: false + fp8: null # Set to "e4m3" or "hybrid" for FP8 training + + # hyper parameters + train_iters: ${PRIMUS_TRAIN_ITERS:20} + micro_batch_size: ${PRIMUS_MICRO_BATCH_SIZE:4} + global_batch_size: ${PRIMUS_GLOBAL_BATCH_SIZE:640} + seq_length: ${PRIMUS_SEQ_LENGTH:8192} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:8192} + seed: ${SEED:1234} + lr: ${PRIMUS_LR:1.0e-5} + min_lr: 1.0e-5 # Set to 10% of max LR like Llama 8B (8e-5 is 10% of 8e-4) + lr_warmup_iters: ${PRIMUS_LR_WARMUP_ITERS:128} # Match Llama 8B warmup + lr_decay_iters: 1199872 # Match Llama 8B decay schedule (75M samples / 640 GBS = 117187 iters, but using their value) + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:1} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: false + train_data_path: "10 /data/c4-train.en_6_text_document" + valid_data_path: "/data/c4-validation-91205-samples.en_text_document" + test_data_path: "/data/c4-validation-91205-samples.en_text_document" + + # fusion + # 20250321: need latest megatron docker image + moe_permute_fusion: false + # fused wgrad gemm and accumulation + gradient_accumulation_fusion: false + # recommend set `false` in fp8 + moe_use_legacy_grouped_gemm: true + # fused topk router with aux score + moe_use_fused_router_with_aux_score: false + # MLA + multi_latent_attention: false + # rope fusion + apply_rope_fusion: false + + # DeepEP does not support moe_shared_expert_overlap + moe_shared_expert_overlap: false + # DeepEP only supports float32 probs + moe_router_dtype: fp32 + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + exit_on_missing_checkpoint: false + ckpt_format: torch + eval_iters: 64 # 64 iters × 2 MBS × 8 GPUs = 1024 eval samples (matching Llama 8B) + eval_interval: ${PRIMUS_EVAL_INTERVAL:5} + + # Turbo + enable_primus_turbo: true + use_turbo_attention: true + use_turbo_grouped_mlp: true + + # 
deepep + use_turbo_deepep: true + + # 64 or 80 for ep8, 32 for ep16-64 is best practice + turbo_deepep_num_cu: 64 + turbo_deepep_use_comm_stream: false + + # sync-free moe support stage 0-3, 0 means not use sync-free moe + # stage 3 is completely no gpu-cpu sync in MoE, but cost more memory + # stage 2 is recommended for better performance + turbo_sync_free_moe_stage: 2 + + # Cross entropy flags + # cross_entropy_fusion_impl: "te" + # cross_entropy_loss_fusion: true + diff --git a/gpt-oss-20b/primus/config_B200_1x8x1.sh b/gpt-oss-20b/primus/config_B200_1x8x1.sh new file mode 100755 index 000000000..67cb0248b --- /dev/null +++ b/gpt-oss-20b/primus/config_B200_1x8x1.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +export DGXSYSTEM=B200_1x8x1 +export GPUS_PER_NODE=8 +export NNODES=1 +export NODE_RANK=0 +export MASTER_ADDR=localhost +export MASTER_PORT=29501 + +export PRIMUS_PATH=/workspace/deps/Primus +export PYTHONPATH="${PRIMUS_PATH}:${PRIMUS_PATH}/third_party/Megatron-LM:${PYTHONPATH}" +export EXP=/workspace/code/conf/gpt_oss_20B-pretrain-nvidia.yaml +export DATA_PATH=/data + +export PRIMUS_MICRO_BATCH_SIZE=2 +export PRIMUS_GLOBAL_BATCH_SIZE=16 +export PRIMUS_LR=8e-4 +export PRIMUS_TRAIN_ITERS=20000 + +# Evaluation frequency (sample-based, adjusts automatically with GBS) +export EVAL_SAMPLES_INTERVAL=12288 # Evaluate every 12,288 samples +export PRIMUS_EVAL_INTERVAL=$((EVAL_SAMPLES_INTERVAL / PRIMUS_GLOBAL_BATCH_SIZE)) # Auto-computed + +export PRIMUS_BF16=true +export PRIMUS_FP16=false +export PRIMUS_FP8=null + +export PRIMUS_TURBO_ENABLED=false +export USE_TURBO_ATTENTION=false +export USE_TURBO_GROUPED_MLP=false +export USE_TURBO_DEEPEP=false +export TURBO_DEEPEP_NUM_CU=0 +export TURBO_SYNC_FREE_MOE_STAGE=0 + +export PRIMUS_APPLY_ROPE_FUSION=false +export USE_ROCM_MEM_INFO=false + +export OVERLAP_GRAD_REDUCE=true +export OVERLAP_PARAM_GATHER=true +export GRADIENT_ACCUMULATION_FUSION=false + +export PRIMUS_TP=1 +export PRIMUS_PP=1 +export PRIMUS_EP=8 + +export ENABLE_MLLOG=1 +export MLLOG_OUTPUT_FILE=/results/mlperf_output.log +export MLLOG_TRAIN_LOSS_LOG_FREQ=1 +export MLLOG_TARGET_EVAL_LOSS=3.3 +export MLLOG_SUBMISSION_BENCHMARK=gpt-oss-20b +export MLLOG_SUBMISSION_DIVISION=closed +export MLLOG_SUBMISSION_ORG=NVIDIA +export MLLOG_SUBMISSION_PLATFORM=B200 + +export HF_TOKEN="${HF_TOKEN:-}" + diff --git a/gpt-oss-20b/primus/config_MI355X_1x8x1.sh b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh new file mode 100755 index 000000000..260446a23 --- /dev/null +++ b/gpt-oss-20b/primus/config_MI355X_1x8x1.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# ============================================================================= +# MLPerf GPT-OSS-20B Configuration for MI355X (1 node, 8 GPUs) +# ============================================================================= + +# ----------------------------------------------------------------------------- +# System Configuration +# ----------------------------------------------------------------------------- +export DGXSYSTEM=MI355X_1x8x1 +export GPUS_PER_NODE=8 +export NNODES=1 +export NODE_RANK=0 +export MASTER_ADDR=localhost +export MASTER_PORT=29501 + +# ----------------------------------------------------------------------------- +# Paths +# ----------------------------------------------------------------------------- +export PRIMUS_PATH=/workspace/deps/Primus +export PYTHONPATH="${PRIMUS_PATH}:${PRIMUS_PATH}/third_party/Megatron-LM:${PYTHONPATH}" +export EXP=/workspace/code/conf/gpt_oss_20B-pretrain.yaml +export DATA_PATH=/data + +# ----------------------------------------------------------------------------- +# Training Hyperparameters +# ----------------------------------------------------------------------------- +export PRIMUS_MICRO_BATCH_SIZE=2 +export PRIMUS_GLOBAL_BATCH_SIZE=16 +export PRIMUS_LR=8e-4 +export PRIMUS_TRAIN_ITERS=20000 # 20K iters × 16 GBS = 320K samples + +# Evaluation frequency (sample-based, adjusts automatically with GBS) +export EVAL_SAMPLES_INTERVAL=12288 # Evaluate every 12,288 samples +export PRIMUS_EVAL_INTERVAL=$((EVAL_SAMPLES_INTERVAL / PRIMUS_GLOBAL_BATCH_SIZE)) # Auto-computed + +# ----------------------------------------------------------------------------- +# Optimizations +# ----------------------------------------------------------------------------- +export PRIMUS_APPLY_ROPE_FUSION=True +export PRIMUS_FP8_RECIPE=hybrid + +# ----------------------------------------------------------------------------- +# MLPerf Logging +# ----------------------------------------------------------------------------- +export ENABLE_MLLOG=1 +export MLLOG_OUTPUT_FILE=/results/mlperf_output.log +export MLLOG_TRAIN_LOSS_LOG_FREQ=100 +export MLLOG_TARGET_EVAL_LOSS=3.3 +export MLLOG_SUBMISSION_BENCHMARK=gpt-oss-20b +export MLLOG_SUBMISSION_DIVISION=closed +export MLLOG_SUBMISSION_ORG=AMD +export MLLOG_SUBMISSION_PLATFORM=MI355X + +# ----------------------------------------------------------------------------- +# TE Configuration +# ----------------------------------------------------------------------------- +export NVTE_ROCM_ENABLE_MXFP8=0 diff --git a/gpt-oss-20b/primus/primus_mllog-0.1.0-py3-none-any.whl b/gpt-oss-20b/primus/primus_mllog-0.1.0-py3-none-any.whl new file mode 100644 index 000000000..75f0a035d Binary files /dev/null and b/gpt-oss-20b/primus/primus_mllog-0.1.0-py3-none-any.whl differ diff --git a/gpt-oss-20b/primus/requirements.txt b/gpt-oss-20b/primus/requirements.txt new file mode 100644 index 000000000..74fc25c1b --- /dev/null +++ b/gpt-oss-20b/primus/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 \ No newline at end of file diff --git a/gpt-oss-20b/primus/run_and_time.sh b/gpt-oss-20b/primus/run_and_time.sh new file mode 100755 index 000000000..e67606dfe --- /dev/null +++ b/gpt-oss-20b/primus/run_and_time.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# 
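+# run_and_time.sh: launches torchrun for a single training run inside the container and
+# reports the elapsed wall-clock time; invoked once per experiment by run_with_docker.sh.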
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +set -e + +# Create results directory +mkdir -p /results + +cd /workspace/code + +echo "============================================" +echo "MLPerf GPT-OSS-20B Training" +echo "============================================" +echo "Config: ${EXP}" +echo "Data: ${DATA_PATH}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "Nodes: ${NNODES}" +echo "============================================" + +# Start timing +start=$(date +%s) +start_fmt=$(date +%Y-%m-%d\ %r) +echo "STARTING TIMING RUN AT $start_fmt" + +# Launch distributed training +torchrun \ + --nproc_per_node=${GPUS_PER_NODE} \ + --nnodes=${NNODES} \ + --node_rank=${NODE_RANK} \ + --master_addr=${MASTER_ADDR} \ + --master_port=${MASTER_PORT} \ + src/train.py + +ret_code=$? + +# End timing +end=$(date +%s) +end_fmt=$(date +%Y-%m-%d\ %r) +echo "ENDING TIMING RUN AT $end_fmt" + +# Report result +result=$(( end - start )) +result_name="GPT_OSS_20B" +echo "RESULT,$result_name,,$result,AMD,$start_fmt" + +if [[ $ret_code != 0 ]]; then + echo "Training failed with exit code: $ret_code" + exit $ret_code +fi + +exit 0 diff --git a/gpt-oss-20b/primus/run_with_docker.sh b/gpt-oss-20b/primus/run_with_docker.sh new file mode 100755 index 000000000..35bc4104a --- /dev/null +++ b/gpt-oss-20b/primus/run_with_docker.sh @@ -0,0 +1,143 @@ +#!/bin/bash +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +set -euxo pipefail + +# Change directory to the primus directory +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +cd $SCRIPT_DIR + +# Vars without defaults +: "${DGXSYSTEM:?DGXSYSTEM not set}" +: "${CONT:?CONT not set}" +: "${DATADIR:?DATADIR not set}" +: "${LOGDIR:?LOGDIR not set}" + +# Vars with defaults +: "${NEXP:=1}" +: "${DATESTAMP:=$(date +'%y%m%d%H%M%S%N')}" +: "${CLEAR_CACHES:=1}" +: "${CHECK_COMPLIANCE:=0}" +: "${MLPERF_RULESET:=5.1.0}" +: "${UTILITIES:="$(pwd)/../../utilities"}" + +: "${CONT_NAME:=dev}" +: "${NGPU:=1}" +: "${LOG_FREQ:=0}" +: "${HF_TOKEN:=""}" + +# Other vars +readonly _config_file="./config_${DGXSYSTEM}.sh" +readonly _logfile_base="${LOGDIR}/${DATESTAMP}" +readonly _cont_name="${CONT_NAME}" +_cont_mounts=("--volume=${DATADIR}:/data" "--volume=$(pwd):/workspace/code" "--volume=$(pwd)/../../AMD:/workspace/AMD" "--volume=${UTILITIES}:/workspace/utilities" "--volume=${LOGDIR}:/results") + + +# Setup directories +mkdir -p "${LOGDIR}" +mkdir -p "${LOGDIR}/artifacts/" + +# Get list of envvars to pass to docker +mapfile -t _config_env < <(env -i bash -c ". ${_config_file} && compgen -e" | grep -E -v '^(PWD|SHLVL)') +_config_env+=(DATADIR) +_config_env+=(DGXSYSTEM) +_config_env+=(PROFILER) +_config_env+=(LOGDIR) +_config_env+=(HIPBLASLT_LOG) +_config_env+=(GEMM_OFFLINE_TUNING) +_config_env+=(GEMM_USE_TUNING_RESULTS) +_config_env+=(HF_TOKEN) +_config_env+=(SEED) + +echo ${_config_env[@]} +mapfile -t _config_env < <(for v in "${_config_env[@]}"; do echo "--env=$v"; done) + +# Cleanup container +cleanup_docker() { + if docker ps -a --format '{{.Names}}' | grep -q "^${_cont_name}$"; then + docker container rm -f "${_cont_name}" || true + else + echo "Container ${_cont_name} does not exist. Skipping removal." 
+ fi +} +cleanup_docker +trap 'set -eux; cleanup_docker' EXIT + +# Setup container +# Use DGXSYSTEM to determine hardware type (MI* = AMD/ROCm, otherwise NVIDIA) +if [[ "${DGXSYSTEM}" == MI* ]]; then + echo "Using AMD/ROCm container flags" + docker run --rm --init --detach \ + --net=host --uts=host --ipc=host \ + --device /dev/dri --device /dev/kfd --device=/dev/infiniband \ + --cap-add=SYS_PTRACE --cap-add=CAP_SYS_ADMIN \ + --security-opt=seccomp=unconfined \ + --group-add video \ + --privileged \ + --name="${_cont_name}" "${_cont_mounts[@]}" \ + -e IMAGE_NAME="${CONT}" \ + "${CONT}" sleep infinity +else + echo "Using NVIDIA container flags" + docker run --rm --init --detach \ + --net=host --uts=host \ + --ipc=host --gpus all \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --device=/dev/infiniband \ + --security-opt=seccomp=unconfined \ + --name="${_cont_name}" "${_cont_mounts[@]}" \ + -e IMAGE_NAME="${CONT}" \ + "${CONT}" sleep infinity +fi + + +# Make sure container has time to finish initialization +sleep 5 +docker exec "${_cont_name}" true + +# Run experiments +for _experiment_index in $(seq 1 "${NEXP}"); do + ( + echo "Beginning trial ${_experiment_index} of ${NEXP}" + if [[ $CLEAR_CACHES == 1 ]]; then + bash -c "echo -n 'Clearing cache on ' && hostname && sync && sudo /sbin/sysctl vm.drop_caches=3" + fi + # Use existing SEED if set; otherwise use a new RANDOM value + _config_env+=(--env=SEED="${SEED:-$RANDOM}") + echo 'launching experiment using:' ${_config_env[@]} ${_cont_name} /workspace/code/run_and_time.sh + docker exec ${_config_env[@]} ${_cont_name} bash /workspace/code/run_and_time.sh + ) | grep --line-buffered -v "connected peer ranks" | tee "${_logfile_base}_${_experiment_index}.log" + + if [ "${CHECK_COMPLIANCE}" -eq 1 ]; then + docker exec "${_config_env[@]}" "${_cont_name}" \ + python3 -m mlperf_logging.compliance_checker --usage training \ + --ruleset "${MLPERF_RULESET}" \ + --log_output "/results/compliance_${DATESTAMP}.out" \ + "/results/${DATESTAMP}_${_experiment_index}.log" + fi + +done + diff --git a/gpt-oss-20b/primus/src/train.py b/gpt-oss-20b/primus/src/train.py new file mode 100644 index 000000000..e0905cb5b --- /dev/null +++ b/gpt-oss-20b/primus/src/train.py @@ -0,0 +1,122 @@ +# +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# + +import os +import sys +from pathlib import Path + +# Primus and Megatron paths (set by run_and_time.sh or here as fallback) +PRIMUS_PATH = os.getenv("PRIMUS_PATH", "/workspace/deps/Primus") +MEGATRON_PATH = os.path.join(PRIMUS_PATH, "third_party/Megatron-LM") + +if PRIMUS_PATH not in sys.path: + sys.path.insert(0, PRIMUS_PATH) +if MEGATRON_PATH not in sys.path: + sys.path.insert(0, MEGATRON_PATH) + +from primus.core.launcher.config import PrimusConfig +from primus.core.launcher.parser import load_primus_config, add_pretrain_parser +from primus_mllog import MLPerfMegatronPretrainTrainer + +import argparse + +def setup_environment(data_path: str = None): + """Setup HuggingFace home and other environment variables.""" + if data_path and "HF_HOME" not in os.environ: + hf_home = os.path.join(data_path, "huggingface") + os.environ["HF_HOME"] = hf_home + print(f"[MLPerf Train] HF_HOME={hf_home}") + + +def load_config(config_path: str, overrides: list = None) -> PrimusConfig: + """ + Load and parse the experiment YAML configuration. + + The config file (e.g., gpt_oss_20B-pretrain.yaml) defines: + - Model architecture (hidden size, num layers, attention heads, etc.) + - Training hyperparameters (batch size, learning rate, etc.) + - Data paths and tokenizer settings + - Parallelism settings (TP, PP, EP) + """ + # Create args namespace for Primus config loader + parser = argparse.ArgumentParser() + add_pretrain_parser(parser) + + args = parser.parse_args([ + '--config', config_path, + '--data_path', os.getenv('DATA_PATH', '/data'), + ]) + + primus_cfg, unknown_overrides = load_primus_config(args, overrides or []) + + print(f"[MLPerf Train] Loaded config from: {config_path}") + print(f"[MLPerf Train] Framework: {primus_cfg.get_module_config('pre_trainer').framework}") + + return primus_cfg, unknown_overrides + + +def create_trainer(primus_cfg: PrimusConfig, extra_args: list = None) -> MLPerfMegatronPretrainTrainer: + """ + Create the MLPerf-enabled Megatron trainer. + + The trainer handles: + - Model creation (GPT architecture with MoE) + - Optimizer setup (Adam with configurable betas) + - Learning rate scheduling (warmup + cosine decay) + - Distributed training coordination + - MLPerf logging and metrics + """ + # Get distributed training configuration from environment + # These are set by torchrun when launching distributed training + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") + master_port = int(os.getenv("MASTER_PORT", "29500")) + + trainer = MLPerfMegatronPretrainTrainer( + module_name="pre_trainer", + primus_config=primus_cfg, + module_rank=rank, + module_world_size=world_size, + module_master_addr=master_addr, + module_master_port=master_port, + extra_args=extra_args, + ) + return trainer + +def main(): + config_path = os.environ.get("EXP", "/workspace/code/conf/gpt_oss_20B-pretrain.yaml") + + if not Path(config_path).exists(): + raise FileNotFoundError(f"Config not found: {config_path}") + + setup_environment(data_path=os.getenv('DATA_PATH', '/data')) + primus_cfg, extra_args = load_config(config_path) + + trainer = create_trainer(primus_cfg, extra_args) + trainer.init() + trainer.run() + +if __name__ == "__main__": + main()
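
Both `gpt_oss_20B-pretrain*.yaml` files express every tunable as a `${NAME:default}` placeholder (for example `train_iters: ${PRIMUS_TRAIN_ITERS:20000}`), and the `config_*.sh` scripts drive them purely through exported environment variables, e.g. `PRIMUS_EVAL_INTERVAL=$((12288 / 16))`, which resolves to 768. The sketch below illustrates that substitution convention only; it is not Primus's actual config parser (that is handled by `load_primus_config` in `train.py`), and the values it resolves are simply the defaults and exports that appear in this diff.

```python
import os
import re

# Rough illustration of the ${NAME:default} convention used throughout
# conf/gpt_oss_20B-pretrain*.yaml; Primus's own loader may resolve these
# differently, so treat this purely as a reading aid.
_PLACEHOLDER = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")


def expand_env_placeholders(value: str) -> str:
    """Replace each ${NAME:default} with $NAME from the environment, else the default."""
    return _PLACEHOLDER.sub(
        lambda m: os.environ.get(m.group(1), m.group(2)), value
    )


if __name__ == "__main__":
    # Mirror what config_B200_1x8x1.sh exports before launching training.
    os.environ["PRIMUS_TRAIN_ITERS"] = "20000"
    os.environ["PRIMUS_GLOBAL_BATCH_SIZE"] = "16"
    os.environ["PRIMUS_EVAL_INTERVAL"] = str(12288 // 16)  # 768 steps between evals

    for raw in (
        "${PRIMUS_TRAIN_ITERS:20000}",
        "${PRIMUS_GLOBAL_BATCH_SIZE:16}",
        "${PRIMUS_EVAL_INTERVAL:768}",
        "${PRIMUS_LR:8.0e-4}",  # not exported here, so the YAML default wins
    ):
        print(f"{raw} -> {expand_env_placeholders(raw)}")
```

Any placeholder left unexported falls back to the default baked into the YAML, which is why `gpt_oss_20B-pretrain.yaml` can ship small smoke-test defaults (`train_iters: ${PRIMUS_TRAIN_ITERS:20}`) while `config_MI355X_1x8x1.sh` overrides them to the full 20,000-iteration run.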