diff --git a/mixture_of_experts_pretraining/README.md b/mixture_of_experts_pretraining/README.md new file mode 100644 index 000000000..1326e8a98 --- /dev/null +++ b/mixture_of_experts_pretraining/README.md @@ -0,0 +1,518 @@
+# 1. MLPerf Training: MoE Benchmark
+This benchmark focuses on training Mixtral 8x22B with a 32,768 token sequence length, with the following key features:
+* sparse mixture-of-experts architecture: we specifically use the [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1) architecture and checkpoint. This allows for greater computational efficiency compared to dense models like GPT-3 or LLaMA2-3, since only a subset of experts is activated during training and inference.
+* extended sequence length: handles sequences up to 32,768 tokens long, enabling larger context windows.
+* dropped implementation: tokens assigned to experts that are already at capacity are dropped, which gives more consistent performance and effectively addresses the load-balancing issue. Inspired by [Switch Transformer](https://arxiv.org/pdf/2101.03961), we set `capacity_factor=1.25` to determine the maximum token load for each expert.
+
+# 2. Dataset
+This benchmark uses the
+[C4](https://www.tensorflow.org/datasets/catalog/c4)/en/3.0.1
+dataset from TensorFlow Datasets. A version is
+[available](https://huggingface.co/datasets/allenai/c4/tree/main/en)
+from Hugging Face.
+
+The benchmark uses a different split than the original C4/en/3.0.1:
+
+| Split | What's in it | What it is used for | number of samples |
+| - | - | - | - |
+| train1 | first 768 of 1024 files of C4/en/3.0.1:train | training the initial checkpoint | 274,651,678 |
+| train2 | last 256 of 1024 files of C4/en/3.0.1:train | training dataset of the benchmark | 91,217,223 |
+| validation\_24567exp | 1/20th of C4/en/3.0.1:validation | validation dataset of the benchmark | 24,567 |
+
+The dataset is available as S3 artifacts. See [the guide](#9-s3-artifacts-download) for downloading.
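The dropped implementation from section 1 can be made concrete with a small sketch. Following the Switch Transformer capacity formula, each expert accepts at most `ceil(tokens_per_batch / num_experts * capacity_factor)` tokens per batch, and tokens routed to an already-full expert are dropped (they bypass the expert FFN). The helpers below are illustrative only, not part of the benchmark code:

```python
import math

def expert_capacity(tokens_per_batch: int, num_experts: int, capacity_factor: float) -> int:
    # Switch-Transformer-style capacity: max tokens one expert accepts per batch.
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)

def route_with_drops(expert_ids, num_experts, capacity):
    # expert_ids[i] is the expert chosen for token i; tokens that arrive after
    # an expert is at capacity are dropped.
    kept, dropped, load = [], [], [0] * num_experts
    for tok, e in enumerate(expert_ids):
        if load[e] < capacity:
            load[e] += 1
            kept.append(tok)
        else:
            dropped.append(tok)
    return kept, dropped

cap = expert_capacity(tokens_per_batch=16, num_experts=8, capacity_factor=1.25)
# 16 / 8 * 1.25 = 2.5, so each expert accepts at most 3 tokens
kept, dropped = route_with_drops([0] * 5 + [1] * 3 + [2] * 8, num_experts=8, capacity=cap)
```

With `capacity_factor=1.25` and perfectly balanced routing no tokens are dropped; the 25% headroom absorbs moderate routing imbalance at a fixed, predictable compute cost.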
+
+Note: this benchmark uses the same dataset as the gpt3-175b benchmark; see [the dataset section of the gpt3-175b benchmark](https://github.com/mlcommons/training/blob/master/large_language_model/paxml/README.md#2-dataset) for reference.
+
+# 3. Model, Checkpoint, Optimizer, & Tokenizer
+We specifically use the [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1) architecture and checkpoint.
+
+| Config | Value |
+| - | - |
+| num_hidden_layers | 56 |
+| num_attention_heads | 48 |
+| num_experts_per_tok | 2 |
+| num_key_value_heads | 8 |
+| num_local_experts | 8 |
+| vocab_size | 32000 |
+| hidden_size | 6144 |
+| intermediate_size | 16384 |
+
+For the optimizer, we use [AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
+
+We use the SentencePiece tokenizer shipped with [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1):
+```
+from transformers import AutoTokenizer
+tokenizer = AutoTokenizer.from_pretrained(
+    "mistralai/Mixtral-8x22B-v0.1",
+    add_eos_token=False,
+    add_bos_token=False,
+    use_fast=False,
+)
+```
+Note that `use_fast=False` is required to avoid the fast tokenizer implemented in Rust. We have found token mismatches between the two tokenizers, which subtly affected the loss curve in previous experiments.
+
+# 4. Evaluation
+## Evaluation loss metric
+Negative log likelihood loss for next token prediction
+
+## [TBD] Target Evaluation Loss
+1.8
+
+## [TBD] Evaluation frequency
+24 * 1024 * 2048 tokens
+
+# 5. Training with [torch_xla](https://github.com/pytorch/xla/tree/master) on TPU Device
+## Environment Setup
+
+A Docker image is used in this repo for environment setup.
+
+The following command creates an environment with the necessary libraries, mainly:
+* ML Framework: [torch_xla](https://github.com/pytorch/xla.git)
+* Models: [transformers](https://github.com/huggingface/transformers.git)
+* config tool: [hydra-core](https://hydra.cc/)
+```bash
+# This command will create, tag, and push an image, defaulting to gcr.io/${PROJECT_ID}/${USER}-pytorch-xla-moe-${DATE}
+bash docker/tpu/build_and_push_image.sh
+```
+
+```bash
+# Alternatively, create, tag, and push an image with a different name
+IMAGE= bash docker/tpu/build_and_push_image.sh
+```
+
+### Prebuilt Docker Images
+
+For now, we have uploaded a Docker image tarball to the S3 bucket. See [the guide](#9-s3-artifacts-download) for downloading.
+
+Once downloaded, use the following command to load it:
+```
+docker load -i pytorch-xla-moe-20250101.tar
+```
+
+## Checkpoint
+We provide two pre-converted checkpoints, for full FSDP and 2D FSDP+TP sharding respectively:
+* Mixtral-8x22B-v0.1-fsdp: use for `tensor_parallelism=1`
+* Mixtral-8x22B-v0.1-2d-fsdp-tp: use for `tensor_parallelism` > 1
+
+See [the guide](#9-s3-artifacts-download) for downloading.
+
+The checkpoint conversion above is done using [distributed_checkpoint_saving.py](scripts/tpu/distributed_checkpoint_saving.py).
+
+See the following detailed guides:
+* [Checkpoint Conversion in GKE](#checkpoint-conversion-in-gke)
+* [Checkpoint Conversion in GCE](#checkpoint-conversion-in-gce)
+
+The script performs the following steps, which usually take less than an hour to complete:
+1) download [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1) from Hugging Face,
+2) shard the model weights according to the assigned strategy,
+3) save the distributed state_dict to disk.
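The weight-sharding step (2) can be illustrated with a minimal, framework-free sketch of FSDP-style sharding: every parameter is split evenly, so each device ends up holding 1/num_devices of every tensor. The `shard_fsdp` helper below is hypothetical — the real script shards torch tensors via torch_xla SPMD rather than Python lists:

```python
def shard_fsdp(state_dict, num_devices):
    # FSDP-style sharding: each flattened parameter is split evenly across
    # devices, so every device holds 1/num_devices of every tensor.
    shards = [{} for _ in range(num_devices)]
    for name, flat_weights in state_dict.items():
        n = len(flat_weights)
        assert n % num_devices == 0, "parameters are padded so they shard evenly"
        step = n // num_devices
        for rank in range(num_devices):
            shards[rank][name] = flat_weights[rank * step:(rank + 1) * step]
    return shards

# Toy model: two "parameters" stored as flat lists.
state_dict = {"w1": list(range(8)), "w2": list(range(100, 116))}
shards = shard_fsdp(state_dict, num_devices=4)
# each of the 4 shards holds 2 elements of w1 and 4 elements of w2
```

This even split is why the feasibility conditions below depend on the *total* HBM across devices rather than the HBM of any single device.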
+
+Model conversion is feasible when:
+* the total accelerator HBM across all devices can accommodate the model's size: this is because the model's shards are distributed equally among the available devices in either FSDP or 2D FSDP & TP sharding
+* each host has enough CPU memory to accommodate the full model, since the full model weights are loaded to CPU as the first step
+
+## Capacity Needed
+To train the Mixtral 8x22B model with a 32,768 token sequence length:
+
+* Minimum Requirement: 64 TPU v5p chips (v5p-128).
+* Convergence Test: 256 TPU v5p chips (v5p-512) were used.
+
+## [recommended] Run Experiments in GKE
+
+### Install XPK and create GKE cluster.
+```
+pip install xpk
+python ~/xpk/xpk.py cluster create --cluster --tpu-type= --num-slices=
+```
+
+### Checkpoint Conversion in GKE
+
+```bash
+# login token required since
+# the mixtral model is a restricted model
+# that requires a user's e-signed agreement in place before accessing it
+export HF_TOKEN=
+
+cat << EOF > script.sh
+# Setup envs
+export HF_HOME=/tmp
+export HYDRA_FULL_ERROR=1
+export WANDB_MODE=offline
+
+export PJRT_DEVICE=TPU
+export XLA_USE_SPMD=1
+
+# Debug info
+export XLA_IR_DEBUG=1
+export XLA_HLO_DEBUG=1
+
+# Avoid circular import
+export USE_JAX=false
+
+cd /app
+git pull
+huggingface-cli login --token ${HF_TOKEN}
+
+# conversion script for Mixtral-8x22B-v0.1-fsdp
+python -m mixture_of_experts_pretraining.scripts.tpu.distributed_checkpoint_saving model.name_or_path=mistralai/Mixtral-8x22B-v0.1 checkpoint_manager_path=/tmp/checkpoints
+
+gsutil -m cp -r /tmp/checkpoints gs://bucket/path/to/checkpoints
+EOF
+
+python ~/xpk/xpk.py workload create \
+--cluster \
+--base-docker-image ${IMAGE} \
+--workload ${USER}-run \
+--tpu-type= \
+--num-slices= \
+--command="bash script.sh"
+```
+
+Add a valid `tensor_parallelism` greater than 1, e.g. `tensor_parallelism=2`, to the conversion command for the `Mixtral-8x22B-v0.1-2d-fsdp-tp` conversion.
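As a rough cross-check of these capacity numbers, the parameter count implied by the config table in section 3 — and the resulting per-chip weight shard under FSDP over 64 chips — can be estimated with a few lines of arithmetic. The head dimension of 128 (`hidden_size / num_attention_heads`) is derived from the table; everything else below is straightforward bookkeeping, and the exact totals are estimates, not values stated in this README:

```python
# Model config from section 3.
hidden, inter, layers = 6144, 16384, 56
heads, kv_heads, experts, vocab = 48, 8, 8, 32000
head_dim = hidden // heads  # 128

# Attention: Q and O are hidden x hidden; K and V are hidden x (kv_heads * head_dim).
attn = 2 * hidden * hidden + 2 * hidden * kv_heads * head_dim
# MoE FFN: 8 experts, each with 3 projection matrices, plus a tiny router.
moe = experts * 3 * hidden * inter + hidden * experts
per_layer = attn + moe
# Input embedding plus output head.
embed = 2 * vocab * hidden

total_params = layers * per_layer + embed   # roughly 141B parameters
weights_gb = total_params * 2 / 1e9         # bf16 = 2 bytes per parameter
per_chip_gb = weights_gb / 64               # FSDP shard over 64 v5p chips
print(f"{total_params / 1e9:.0f}B params, {weights_gb:.0f} GB in bf16, "
      f"{per_chip_gb:.1f} GB per chip over 64 chips")
```

The bf16 weights alone shard down to a few GB per chip; gradients, optimizer states, and 32,768-token activations dominate the actual footprint, which is consistent with 64 chips being the stated minimum rather than a generous allocation.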
+
+### Run workload in GKE
+```bash
+# login token required since
+# the mixtral model is a restricted model
+# that requires a user's e-signed agreement in place before accessing it
+export HF_TOKEN=
+
+cat << EOF > script.sh
+# Setup envs
+export HF_HOME=/tmp
+export HYDRA_FULL_ERROR=1
+export WANDB_MODE=offline
+
+export PJRT_DEVICE=TPU
+export XLA_USE_SPMD=1
+
+# Debug info
+export XLA_IR_DEBUG=1
+export XLA_HLO_DEBUG=1
+
+# Avoid circular import
+export USE_JAX=false
+
+cd /app/mixture_of_experts_pretraining
+git pull
+huggingface-cli login --token ${HF_TOKEN}
+# workload script
+# equivalent of
+# python run_clm.py +experiment=gbs256_tpu
+python run_clm.py model.config_path=mixtral822.json per_device_train_batch_size=1 optimizer=ADAMW_TORCH_XLA checkpoint_manager_path=gs://lizhiyu-multipods-eu-west/moe/checkpoints-20240803/mixtral822/ model.name_or_path=mistralai/Mixtral-8x22B-v0.1 dataset.dataset_name=c4_mlperf max_steps=250 max_grad_norm=1.0 seed=4321 model.dtype=bfloat16 output_dir=/app/output max_length=32768 dataset.streaming=True tensor_parallelism=1 exp_name=convergence_exp model.capacity_factor=1.25 lr=2e-5 sched=WarmupHoldPolicy
+EOF
+
+python ~/xpk/xpk.py workload create \
+--cluster \
+--base-docker-image ${IMAGE} \
+--workload ${USER}-run \
+--tpu-type= \
+--num-slices= \
+--command="bash script.sh"
+```
+Note that the dataset paths default as follows in [`dataset/c4_mlperf.yaml`](config/dataset/c4_mlperf.yaml):
+```
+train_dataset_path: gs://mlperf-llm-public2/c4/en_json/3.0.1
+eval_dataset_path: gs://mlperf-llm-public2/c4/en_val_subset_json
+```
+You can override them by adding
+`dataset.train_dataset_path=/path/to/train/dir dataset.eval_dataset_path=/path/to/eval/dir` to the workload command; both local directories and GCS buckets are supported.
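The `key=value` arguments in these commands are Hydra-style dotted overrides of the YAML config: each one walks the config tree along the dotted path and replaces a leaf. Conceptually it behaves like this pure-Python sketch (`apply_override` is an illustrative stand-in, not the actual Hydra machinery, which also handles type coercion, `+key` additions, and interpolation):

```python
def apply_override(config: dict, override: str) -> None:
    # Minimal illustration of a Hydra-style "a.b.c=value" override:
    # walk the dotted path, then replace the leaf value in place.
    path, value = override.split("=", 1)
    *parents, leaf = path.split(".")
    node = config
    for key in parents:
        node = node[key]
    node[leaf] = value

config = {"dataset": {"train_dataset_path": "gs://mlperf-llm-public2/c4/en_json/3.0.1"}}
apply_override(config, "dataset.train_dataset_path=/data/c4/en_json/3.0.1")
```

This is why `dataset.train_dataset_path=...` on the command line reaches into the `dataset` group of [`config/config.yaml`](config/config.yaml) without editing any file.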
+ +## Run Experiments in GCE + +### set project and zone + +```bash +# change to a valid PROJECT_ID and ZONE +export PROJECT_ID=cloud-tpu-multipod-dev +export ZONE=us-central2-b + +gcloud config set project ${PROJECT_ID} +gcloud config set compute/zone ${ZONE} +``` + +### Create TPU VMs +```bash +# create tpu vm say v4-8 as an example +export RUNTIME_VERSION=tpu-ubuntu2204-base +export TPU_NAME=${USER}-mlperf +gcloud compute tpus tpu-vm create ${TPU_NAME} --zone=${ZONE} --accelerator-type='v4-8' --version=${RUNTIME_VERSION} +``` + +### Docker Authorization +Pull docker image, say a pre-built image `gcr.io/cloud-tpu-multipod-dev/lizhiyu-pytorch-xla-moe-20241031` +```bash +# change to a valid docker image +export IMAGE=gcr.io/cloud-tpu-multipod-dev/lizhiyu-pytorch-xla-moe-20241031 + +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--worker=all \ +--command=" +yes Y | sudo gcloud auth configure-docker +sudo docker pull ${IMAGE} +" +``` + +### Checkpoint Conversion in GCE + +```bash +# login token required since +# the mixtral model is a restricted model +# that requires users e-signed agreement in place before accessing it +export HF_TOKEN= + +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--worker=all \ +--command=" +sudo docker run --privileged --net host --shm-size=16G --interactive -v /tmp:/tmp ${IMAGE} bash -s < + +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--worker=all \ +--command=" +sudo docker run --privileged --net host --shm-size=16G --interactive -v /tmp:/tmp ${IMAGE} bash -s < +``` + +## Dry-run Tests +To perform dry-run tests with different models/parameters on a smaller scale like `v4-8`, use the following python commands as workload: +### Test with a smaller mixtral model +``` +python run_clm.py model.config_path=mixtral80.json eval_frequency=3 n_eval_examples=100 per_device_train_batch_size=4 max_steps=30 +``` +### Test with a gpt2 model +``` +python run_clm.py model.name_or_path=gpt2 eval_frequency=3 n_eval_examples=100 per_device_train_batch_size=4 
gradient_accumulation_steps=2 sched.warmup_ratio=0. max_steps=30
+```
+
+# 6. Training Mixtral 8x22B with NeMo on GPU Device
+
+**IMPORTANT**: the GPU implementation is a supplementary reference and is not used for RCP
+generation. There is a convergence gap between the GPU and TPU references, and the GPU code cannot
+be used as a drop-in substitute for the TPU code.
+
+## Docker Image
+
+Build and push the Docker image:
+
+```shell
+docker build -t : -f docker/gpu/Dockerfile .
+docker push :
+```
+
+## Kubernetes workflow
+### Run workflow
+
+For this workflow to function, a **_select-configuration.yaml_** file must exist in the ```helm-context``` directory.
+
+Package and schedule the job. An example job name could be "nemo-gpt3-175b-nemo-16gpus"; use a name that is convenient to search for later.
+
+```shell
+helm install helm-context/
+```
+
+### Monitor workflow
+
+Check pod status (use this to find the name of the pod you want logs from):
+
+```shell
+kubectl get pods | grep ""
+```
+
+Check job status:
+
+```shell
+kubectl get jobs | grep ""
+```
+
+Get logs (using the pod name from earlier):
+
+```shell
+kubectl logs ""
+```
+
+## Slurm/Pyxis workflow
+
+### Preprocessing
+
+For the GPU implementation, both the dataset and the checkpoint have to be preprocessed. This can be
+done once, before experimentation, and saved. **IMPORTANT**: the saved checkpoint and dataset have to
+be accessible by all nodes in the system.
+
+To get the preprocessed checkpoint, run the checkpoint_download.py script:
+```shell
+python scripts/gpu/checkpoint_download.py --checkpoint_id mistralai/Mixtral-8x22B-v0.1 \
+    --output_dir --hf_token
+```
+This script downloads the specified checkpoint directly from the Hugging Face repository, preprocesses
+it, and saves it into the specified directory.
+
+The preprocessed dataset can be downloaded from the mlcommons S3 bucket; follow the S3 artifacts download
+section for more information. The preprocessed dataset is located in the `preprocessed_c4` directory.
+This option is highly recommended.
+
+To manually download and preprocess the dataset, the dataset_preprocessing.py script can be used:
+```shell
+python scripts/gpu/dataset_preprocessing.py --input-tokenizer \
+    --workdir
+```
+After preprocessing, the dataset will be saved into /output.
+
+### Running
+
+The Slurm workflow loads the config /config/config.yaml by default. Make sure the correct config is
+specified, or modify the script to mount the correct config into /app/training/config/config.yaml.
+
+To run the job, specify the required input environment variables:
+
+```shell
+export CONT=:
+export DATA=
+export CKPT=
+export NODES=
+export OUTPUT=
+```
+
+Then run the sbatch command using scripts/gpu/run.sub:
+```shell
+sbatch -N${NODES} scripts/gpu/run.sub
+```
+
+# 7. Reference
+* [MLPerf Training: MoE Benchmark Proposal from Nvidia](https://docs.google.com/document/d/1NOJ_vt-o2WHFXmisLRk6Mn7Ki2CeB5UNeTkFrYHoE1I/edit?usp=sharing)
+* [Mixtral of Experts](https://arxiv.org/pdf/2401.04088)
+
+# 8. Lint
+
+```
+black mixture_of_experts_pretraining/
+```
+
+# 9. S3 artifacts download
+The dataset, Docker image, and checkpoints are available for download from an S3 bucket. You can download this data from the bucket using Rclone as follows:
+
+To run Rclone on Windows, you can download the executable [here](https://rclone.org/install/#windows).
+To install Rclone on Linux/macOS/BSD systems, run:
+```
+sudo -v ; curl https://rclone.org/install.sh | sudo bash
+```
+Once Rclone is installed, run the following command to authenticate with the bucket:
+```
+rclone config create mlc-training s3 provider=Cloudflare access_key_id=76ea42eadb867e854061a1806220ee1e secret_access_key=a53625c4d45e3ca8ac0df8a353ea3a41ffc3292aa25259addd8b7dc5a6ce2936 endpoint=https://c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com
+```
+You can then navigate in the terminal to your desired download directory and run the following commands to download the dataset and checkpoints:
+
+## Text Datasets
+**Dataset**
+* Train Dataset `c4/en_json/3.0.1`
+* Eval Dataset `c4/en_val_subset_json`
+* Preprocessed GPU dataset `preprocessed_c4`
+```
+mkdir -p datasets
+rclone copy mlc-training:mlcommons-training-wg-public/mixtral_8x22b/datasets ./datasets -P
+```
+## Checkpoints
+* Mixtral-8x22B-v0.1-fsdp: use for `tensor_parallelism=1`
+```
+mkdir -p checkpoints/Mixtral-8x22B-v0.1-fsdp
+rclone copy mlc-training:mlcommons-training-wg-public/mixtral_8x22b/checkpoints/Mixtral-8x22B-v0.1-fsdp ./checkpoints/Mixtral-8x22B-v0.1-fsdp -P
+```
+* Mixtral-8x22B-v0.1-2d-fsdp-tp: use for `tensor_parallelism` > 1
+```
+mkdir -p checkpoints/Mixtral-8x22B-v0.1-2d-fsdp-tp
+rclone copy mlc-training:mlcommons-training-wg-public/mixtral_8x22b/checkpoints/Mixtral-8x22B-v0.1-2d-fsdp-tp ./checkpoints/Mixtral-8x22B-v0.1-2d-fsdp-tp -P
+```
+
+## Docker Images
+```
+mkdir -p docker-images
+rclone copy mlc-training:mlcommons-training-wg-public/mixtral_8x22b/docker-images ./docker-images -P
+```
diff --git a/mixture_of_experts_pretraining/clm_datasets.py b/mixture_of_experts_pretraining/clm_datasets.py new file mode 100644 index 000000000..c28dc7343 --- /dev/null +++ b/mixture_of_experts_pretraining/clm_datasets.py @@ -0,0 +1,270 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance
with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +from itertools import chain + +import pytorch_lightning as pl +import transformers +from datasets import Features, Value, concatenate_datasets, load_dataset +from transformers import logging +from transformers.testing_utils import CaptureLogger +from omegaconf import DictConfig +from transformers import AutoTokenizer +import hydra + +logger = logging.get_logger(__name__) + + +def get_datasets(config): + # Downloading and loading a dataset from the hub. + if config.dataset.dataset_name == "c4_mlperf": + train_data_files = { + "train": [ + f'{os.path.join(config.dataset.train_dataset_path, f"c4-train.{i:05d}-of-01024.json")}' + for i in range(768, 1024) + ], + } + eval_data_files = { + "validation": [ + f'{os.path.join(config.dataset.eval_dataset_path, "c4-validation_24567exp.json")}' + ], + } + features = Features( + { + "text": Value(dtype="string", id=None), + "timestamp": Value(dtype="string", id=None), + "url": Value(dtype="string", id=None), + } + ) + raw_datasets = { + "train": load_dataset( + "json", + data_files=train_data_files, + features=features, + cache_dir=config.cache_local_dir, + streaming=config.dataset.streaming, + split="train", + ), + "validation": load_dataset( + "json", + data_files=eval_data_files, + features=features, + cache_dir=config.cache_local_dir, + split="validation", + ), + } + if config.n_eval_examples: + raw_datasets["validation"] = raw_datasets["validation"].select( + range(config.n_eval_examples) + ) + else: + raw_datasets = load_dataset( + config.dataset.dataset_name, + 
config.dataset.dataset_config_name, + cache_dir=config.cache_local_dir, + streaming=config.dataset.streaming, + ) + return raw_datasets + + +def get_dataset_cuda(config): + import os + + from nemo.collections import llm + from nemo.collections.common.tokenizers.sentencepiece_tokenizer import ( + SentencePieceTokenizer, + ) + + class PreTrainingDataModule(llm.PreTrainingDataModule): + @property + def gpt_dataset_config(self): + config = super().gpt_dataset_config + config.drop_last_partial_validation_sequence = False + return config + + INDEX_MAPPING_DIR = "/cache/dataset" + os.makedirs(INDEX_MAPPING_DIR, exist_ok=True) + tokenizer = SentencePieceTokenizer( + model_path=os.path.join( + config.checkpoint_manager_path, + "context", + "nemo_tokenizer", + "tokenizer.model", + ) + ) + + dataset_train = [ + os.path.join(config.dataset.train_dataset_path, "c4-train.en_6_text_document"), + os.path.join(config.dataset.train_dataset_path, "c4-train.en_7_text_document"), + ] + + dataset_valid = [ + os.path.join( + config.dataset.eval_dataset_path, "c4-validation-small.en_text_document" + ) + ] + + return PreTrainingDataModule( + paths={ + "train": dataset_train, + "validation": dataset_valid, + "test": dataset_valid, + }, + seq_length=config.max_length, + global_batch_size=config.global_train_batch_size, + micro_batch_size=config.per_device_train_batch_size, + tokenizer=tokenizer, + index_mapping_dir=INDEX_MAPPING_DIR, + num_workers=2, + persistent_workers=True, + ) + + +def process_datasets(raw_datasets, tokenizer, config, use_cuda: bool = True): + # First we tokenize all the texts. 
+ column_names = list(raw_datasets["train"].features) + text_column_name = "text" if "text" in column_names else column_names[0] + + # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function + tok_logger = transformers.utils.logging.get_logger( + "transformers.tokenization_utils_base" + ) + + def process_datasets_function(src_datasets, function, desc): + tgt_datasets = {} + for key in src_datasets.keys(): + # use validation batch_size to avoid dropping remainders in group_text + # 2x max_sequence_length is a good batch_size to avoid too many paddings + batch_size = 24567 if key == "validation" else 65536 + # only apply streaming in train dataset + if key == "train" and config.dataset.streaming: + tgt_datasets[key] = src_datasets[key].map( + function, + batched=True, + batch_size=batch_size, + ) + else: + tgt_datasets[key] = src_datasets[key].map( + function, + batched=True, + batch_size=batch_size, + num_proc=config.dataset.num_proc, + load_from_cache_file=config.dataset.load_from_cache_file, + desc=desc, + ) + return tgt_datasets + + def tokenize_function(examples): + with CaptureLogger(tok_logger) as cl: + output = tokenizer(examples[text_column_name]) + # clm input could be much much longer than block_size + if "Token indices sequence length is longer than the" in cl.out: + tok_logger.warning( + "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" + " before being passed to the model." + ) + return output + + tokenized_datasets = process_datasets_function( + raw_datasets, tokenize_function, desc="Running tokenizer on dataset" + ) + tokenized_datasets = { + key: dataset.remove_columns(column_names) + for key, dataset in tokenized_datasets.items() + } + block_size = config.max_length + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. 
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + + if total_length % block_size != 0: + pad_length = (total_length // block_size + 1) * block_size - total_length + for k in concatenated_examples.keys(): + concatenated_examples[k].extend([config.pad_token_id] * pad_length) + total_length += pad_length + else: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + + result["labels"] = result["input_ids"].copy() + + return result + + lm_datasets = process_datasets_function( + tokenized_datasets, + group_texts, + desc=f"Grouping texts in chunks of {block_size}", + ) + if config.shuffle: + lm_datasets["train"] = lm_datasets["train"].shuffle( + seed=config.seed, buffer_size=config.dataset.shuffle_buffer_size + ) + + # pad to multiple of batch size in eval/validation dataset + if len(lm_datasets["validation"]) % config.global_eval_batch_size > 0: + num_eval_batches = ( + len(lm_datasets["validation"]) // config.global_eval_batch_size + 1 + ) + pad_number = num_eval_batches * config.global_eval_batch_size - len( + lm_datasets["validation"] + ) + logger.info( + f"Eval data has {len(lm_datasets['validation'])} entries, padding now with " + f"{pad_number} extra entries to get {num_eval_batches * config.global_eval_batch_size} batches." 
+ ) + + def mask_pad(examples): + examples["labels"] = [config.pad_token_id] * len(examples["labels"]) + return examples + + pad_validation_dataset = ( + lm_datasets["validation"].select(range(pad_number)).map(mask_pad) + ) + lm_datasets["validation"] = concatenate_datasets( + [lm_datasets["validation"], pad_validation_dataset] + ) + + return lm_datasets + + +# need to run in cpu with single process +# to walk around undefined `OmegaConf.register_new_resolver` need to overwrite `run_dir` `global_train_batch_size` `global_eval_batch_size` +# python clm_datasets.py model.name_or_path=mistralai/Mixtral-8x22B-v0.1 run_dir=/tmp global_train_batch_size=1 global_eval_batch_size=1 max_length=32768 +@hydra.main(version_base=None, config_path="config", config_name="config") +def main(config: DictConfig): + tokenizer = AutoTokenizer.from_pretrained( + config.model.name_or_path, + add_eos_token=False, + add_bos_token=False, + use_fast=False, + ) + raw_datasets = get_datasets(config) + lm_datasets = process_datasets(raw_datasets, tokenizer, config) + + for i, batch in enumerate(lm_datasets["validation"]): + print(f"{i=}: {batch=}") + + +if __name__ == "__main__": + main() diff --git a/mixture_of_experts_pretraining/config/config.yaml b/mixture_of_experts_pretraining/config/config.yaml new file mode 100644 index 000000000..725c5e2d0 --- /dev/null +++ b/mixture_of_experts_pretraining/config/config.yaml @@ -0,0 +1,87 @@ +defaults: +- _self_ +- model: blank_model +- sched: WarmupHoldPolicy +- dataset: c4_mlperf + +# name for this experiment in the local run directory +exp_name: moe_trial + +# random seed for batch sampling +seed: 0 + +# the batch size for for each accelerator/device +# global_train_batch_size = per_device_train_batch_size * num_devices +per_device_train_batch_size: 1 +global_train_batch_size: ${get_global_batch_size:${per_device_train_batch_size}} + +# the batch size during evaluation and sampling, if enabled +per_device_eval_batch_size: 
${per_device_train_batch_size}
+global_eval_batch_size: ${get_global_batch_size:${per_device_eval_batch_size}}
+
+max_grad_norm: 1.
+
+max_steps: 10
+
+pad_token_id: -100
+
+output_dir: /tmp
+
+# early stop once reaching target eval_loss
+target_eval_loss: 0
+
+# whether to eval at the very beginning of training
+do_first_eval: false
+
+# an OmegaConf resolver that returns the local run directory, calling a function in utils.py
+run_dir: ${path_join:${output_dir},${exp_name}}
+
+# the learning rate
+lr: 2e-5
+
+# number of steps to accumulate over for each batch
+# (e.g. if global_train_batch_size=4 and gradient_accumulation_steps=2, then we will
+# accumulate gradients over an equivalent batch size of 8, i.e. 2 microbatches of size 4)
+gradient_accumulation_steps: 1
+
+# the maximum allowed length for an input
+max_length: 512
+
+# the max number of examples to evaluate on
+n_eval_examples: null
+
+# the optimizer to use
+optimizer: ADAMW_TORCH_XLA
+weight_decay: 0.1
+
+# evaluate and save model every eval_frequency steps
+eval_frequency: -1
+
+# path to load checkpoint
+checkpoint_manager_path: null
+
+# shuffle train data set
+shuffle: True
+
+# use float32 in matmul in torch xla
+full_precision: False
+
+# path to save compile cache for torch xla
+local_compile_cache_dir: ${run_dir}
+
+# tensor parallelism degree; FSDP parallelism is num_devices / tensor_parallelism
+tensor_parallelism: 1
+context_parallelism: 1
+pipeline_parallelism: 1
+virtual_pipeline_parallelism: 1
+
+# cache of models
+cache_local_dir: null
+
+xla_profile_step: -1
+
+log_frequency: 1
+
+hydra:
+  run:
+    dir: ${run_dir}
diff --git a/mixture_of_experts_pretraining/config/dataset/c4_mlperf.yaml b/mixture_of_experts_pretraining/config/dataset/c4_mlperf.yaml new file mode 100644 index 000000000..783d3fc69 --- /dev/null +++ b/mixture_of_experts_pretraining/config/dataset/c4_mlperf.yaml @@ -0,0 +1,12 @@ +dataset_name:
c4_mlperf +train_dataset_path: gs://mlperf-llm-public2/c4/en_json/3.0.1 +eval_dataset_path: gs://mlperf-llm-public2/c4/en_val_subset_json +streaming: True + +# num of process in data processing +num_proc: 1 + +# whether to load dataset from cache +load_from_cache_file: True + +shuffle_buffer_size: 256 \ No newline at end of file diff --git a/mixture_of_experts_pretraining/config/dataset/wikitext.yaml b/mixture_of_experts_pretraining/config/dataset/wikitext.yaml new file mode 100644 index 000000000..cff4634ae --- /dev/null +++ b/mixture_of_experts_pretraining/config/dataset/wikitext.yaml @@ -0,0 +1,11 @@ +dataset_name: wikitext +dataset_config_name: wikitext-2-raw-v1 +streaming: False + +# num of process in data processing +num_proc: 1 + +# whether to load dataset from cache +load_from_cache_file: True + +shuffle_buffer_size: 256 \ No newline at end of file diff --git a/mixture_of_experts_pretraining/config/experiment/convergence_template.yaml b/mixture_of_experts_pretraining/config/experiment/convergence_template.yaml new file mode 100644 index 000000000..8185e14d0 --- /dev/null +++ b/mixture_of_experts_pretraining/config/experiment/convergence_template.yaml @@ -0,0 +1,52 @@ +# @package _global_ +exp_name: convergence_exp +config_path: null +seed: 4321 +per_device_train_batch_size: 1 +global_train_batch_size: ${get_global_batch_size:${per_device_train_batch_size}} +per_device_eval_batch_size: 1 +global_eval_batch_size: ${get_global_batch_size:${per_device_eval_batch_size}} +max_grad_norm: 1.0 +max_steps: 250 +pad_token_id: -100 +output_dir: /tmp +do_first_eval: true +run_dir: ${path_join:${output_dir},${exp_name}} +lr: 2.0e-05 +gradient_accumulation_steps: 1 +max_length: 32768 +n_eval_examples: null +optimizer: ADAMW_TORCH_XLA +weight_decay: 0.1 +eval_frequency: -1 +checkpoint_manager_path: null +dry_run: false +shuffle: true +full_precision: false +local_compile_cache_dir: ${run_dir} +tensor_parallelism: 1 +cache_local_dir: null +model: + config_path: 
mixtral822.json + name_or_path: mistralai/Mixtral-8x22B-v0.1 + dtype: bfloat16 + flash_attention: true + capacity_factor: 1.25 + fsdp_config: + fsdp_transformer_layer_cls_to_wrap: + - MixtralDecoderLayer + min_num_params: 0 + xla_fsdp_grad_ckpt: true +sched: + name: WarmupHoldPolicy + warmup_ratio: 0.25 + hold_steps: 10000000000000 + max_steps: 250 +dataset: + dataset_name: c4_mlperf + train_dataset_path: gs://mlperf-llm-public2/c4/en_json/3.0.1 + eval_dataset_path: gs://mlperf-llm-public2/c4/en_val_subset_json + streaming: true + num_proc: 1 + load_from_cache_file: true + shuffle_buffer_size: 256 diff --git a/mixture_of_experts_pretraining/config/experiment/gbs256_tpu.yaml b/mixture_of_experts_pretraining/config/experiment/gbs256_tpu.yaml new file mode 100644 index 000000000..105be5f97 --- /dev/null +++ b/mixture_of_experts_pretraining/config/experiment/gbs256_tpu.yaml @@ -0,0 +1,52 @@ +# @package _global_ +exp_name: mixtral8x22-dropped-241125-1659 +config_path: null +seed: 4321 +per_device_train_batch_size: 1 +global_train_batch_size: 256 +per_device_eval_batch_size: 1 +global_eval_batch_size: 256 +max_grad_norm: 1.0 +max_steps: 250 +pad_token_id: -100 +output_dir: /app/output +do_first_eval: false +run_dir: /app/output/mixtral8x22-dropped-241125-1659 +lr: 2.0e-05 +gradient_accumulation_steps: 1 +max_length: 32768 +n_eval_examples: null +optimizer: ADAMW_TORCH_XLA +weight_decay: 0.1 +eval_frequency: 6 +checkpoint_manager_path: gs://lizhiyu-multipods-eu-west/moe/checkpoints-20240803/mixtral822/ +dry_run: false +shuffle: true +full_precision: false +local_compile_cache_dir: /app/output/mixtral8x22-dropped-241125-1659 +tensor_parallelism: 1 +cache_local_dir: null +model: + config_path: mixtral822.json + name_or_path: mistralai/Mixtral-8x22B-v0.1 + dtype: bfloat16 + flash_attention: true + capacity_factor: 1.25 + fsdp_config: + fsdp_transformer_layer_cls_to_wrap: + - MixtralDecoderLayer + min_num_params: 0 + xla_fsdp_grad_ckpt: true +sched: + name: WarmupHoldPolicy + 
warmup_ratio: 0.25 + hold_steps: 10000000000000 + max_steps: 250 +dataset: + dataset_name: c4_mlperf + train_dataset_path: gs://mlperf-llm-public2/c4/en_json/3.0.1 + eval_dataset_path: gs://mlperf-llm-public2/c4/en_val_subset_json + streaming: true + num_proc: 1 + load_from_cache_file: true + shuffle_buffer_size: 256 diff --git a/mixture_of_experts_pretraining/config/model/blank_model.yaml b/mixture_of_experts_pretraining/config/model/blank_model.yaml new file mode 100644 index 000000000..503e63040 --- /dev/null +++ b/mixture_of_experts_pretraining/config/model/blank_model.yaml @@ -0,0 +1,11 @@ +config_path: null +name_or_path: mistralai/Mixtral-8x7B-v0.1 +dtype: bfloat16 +flash_attention: True +capacity_factor: 0 # dropped implementation with a positive number +max_sequence_length: ${max_length} + +fsdp_config: + fsdp_transformer_layer_cls_to_wrap: ["MixtralDecoderLayer"] + min_num_params: 0 + xla_fsdp_grad_ckpt: true diff --git a/mixture_of_experts_pretraining/config/sched/CosineAnnealing.yaml b/mixture_of_experts_pretraining/config/sched/CosineAnnealing.yaml new file mode 100644 index 000000000..c44d9614a --- /dev/null +++ b/mixture_of_experts_pretraining/config/sched/CosineAnnealing.yaml @@ -0,0 +1,5 @@ +name: CosineAnnealing +warmup_ratio: 0.25 +# warmup_steps: 150 +min_lr: ${multiply:0.1,${lr}} +max_steps: ${max_steps} diff --git a/mixture_of_experts_pretraining/config/sched/WarmupHoldPolicy.yaml b/mixture_of_experts_pretraining/config/sched/WarmupHoldPolicy.yaml new file mode 100644 index 000000000..2b2111464 --- /dev/null +++ b/mixture_of_experts_pretraining/config/sched/WarmupHoldPolicy.yaml @@ -0,0 +1,5 @@ +name: WarmupHoldPolicy +warmup_ratio: 0.25 +# warmup_steps: 150 +hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant +max_steps: ${max_steps} diff --git a/mixture_of_experts_pretraining/docker/gpu/Dockerfile b/mixture_of_experts_pretraining/docker/gpu/Dockerfile new file mode 100644 index 000000000..7eff25010 --- /dev/null 
+++ b/mixture_of_experts_pretraining/docker/gpu/Dockerfile
@@ -0,0 +1,49 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM nvcr.io/nvidia/pytorch:24.09-py3
+WORKDIR /app
+
+ARG NEMO_REVISION=7f3da35e1
+RUN git clone https://github.com/NVIDIA/NeMo.git && \
+    rm -rf /opt/NeMo && \
+    cd NeMo && \
+    git checkout ${NEMO_REVISION} && \
+    sed -i "/mamba-ssm/d" requirements/requirements_nlp.txt && \
+    sed -i 's/tensorstore<0.1.46/tensorstore/g' requirements/requirements_nlp.txt && \
+    sed -i "/triton>=3.1.0/d" requirements/requirements.txt && \
+    pip install --no-build-isolation -e ".[nlp]"
+
+## Megatron-core
+ARG MCORE_REVISION=a616d459039ae103257f6a20922261ac11ccbdf6
+RUN pip uninstall -y megatron-core && \
+    rm -rf /opt/megatron-lm && \
+    git clone https://github.com/NVIDIA/Megatron-LM.git && \
+    cd Megatron-LM && \
+    git checkout ${MCORE_REVISION} && \
+    echo MCORE_COMMIT_HASH=$(git rev-parse HEAD) && \
+    pip install . && \
+    cd megatron/core/datasets && \
+    make
+ENV PYTHONPATH "${PYTHONPATH}:/app/Megatron-LM"
+
+RUN pip install git+https://github.com/NVIDIA/dllogger#egg=dllogger
+RUN pip install datasets==2.20.0 hydra-core sentencepiece
+RUN pip install "git+https://github.com/mlperf/logging.git"
+RUN pip install git+https://github.com/NVIDIA/NeMo-Run.git
+
+WORKDIR /app/training
+ADD .
/app/training +RUN patch --directory=/app/Megatron-LM -p1 < docker/gpu/megatron_core.patch + diff --git a/mixture_of_experts_pretraining/docker/gpu/Dockerfile.GCP b/mixture_of_experts_pretraining/docker/gpu/Dockerfile.GCP new file mode 100644 index 000000000..972773cf7 --- /dev/null +++ b/mixture_of_experts_pretraining/docker/gpu/Dockerfile.GCP @@ -0,0 +1,40 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +ARG FROM_BASE_IMAGE + +FROM ${FROM_BASE_IMAGE} +WORKDIR /app + +# GCSfuse components (used to provide shared storage, not intended for high performance) +RUN apt-get update && apt-get install --yes --no-install-recommends \ + ca-certificates \ + curl \ + gnupg \ + && echo "deb https://packages.cloud.google.com/apt gcsfuse-buster main" \ + | tee /etc/apt/sources.list.d/gcsfuse.list \ + && echo "deb https://packages.cloud.google.com/apt cloud-sdk main" \ + | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ + && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - \ + && apt-get update \ + && apt-get install --yes gcsfuse \ + && apt-get install --yes google-cloud-cli \ + && apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \ + && mkdir /gcs + +# Install the Google Cloud SDK +RUN apt-get update && apt-get install -y google-cloud-sdk vim + +# checkpoint loading in gcs +RUN pip install gcsfs + diff --git a/mixture_of_experts_pretraining/docker/gpu/build_and_push_image.sh 
b/mixture_of_experts_pretraining/docker/gpu/build_and_push_image.sh new file mode 100644 index 000000000..5d1eadc85 --- /dev/null +++ b/mixture_of_experts_pretraining/docker/gpu/build_and_push_image.sh @@ -0,0 +1,24 @@ +set -euox pipefail +SCRIPTS_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" && pwd )" +DATE=$(date +%Y%m%d) +: ${PROJECT_ID:=cloud-tpu-multipod-dev} +: ${IMAGE:=gcr.io/${PROJECT_ID}/${USER}-pytorch-nemo-moe-${DATE}} +: ${DOCKER_BUILD_ARGS:=""} + +pushd ${SCRIPTS_DIR} + +docker build --network host \ + --file Dockerfile \ + --tag ${IMAGE}-base \ + ${DOCKER_BUILD_ARGS} \ + . + +docker build --network host \ + --file Dockerfile.GCP \ + --tag ${IMAGE} \ + --build-arg FROM_BASE_IMAGE=${IMAGE}-base \ + . + +popd + +docker push ${IMAGE} diff --git a/mixture_of_experts_pretraining/docker/gpu/megatron_core.patch b/mixture_of_experts_pretraining/docker/gpu/megatron_core.patch new file mode 100644 index 000000000..19b975f4f --- /dev/null +++ b/mixture_of_experts_pretraining/docker/gpu/megatron_core.patch @@ -0,0 +1,89 @@ +diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py +index 2eb7702b..d1f0b9a9 100644 +--- a/megatron/core/datasets/gpt_dataset.py ++++ b/megatron/core/datasets/gpt_dataset.py +@@ -407,9 +407,10 @@ class GPTDataset(MegatronDataset): + + numpy_random_state = numpy.random.RandomState(self.config.random_seed) + ++ shuffle = self.index_split == Split.train + # Build the document index + document_index = _build_document_index( +- self.indices, num_epochs, numpy_random_state, separate_final_epoch ++ self.indices, num_epochs, numpy_random_state, separate_final_epoch, shuffle + ) + + drop_last_partial_sequence = True +@@ -450,11 +451,11 @@ class GPTDataset(MegatronDataset): + # Build the shuffle index + if separate_final_epoch: + shuffle_index = _build_shuffle_index( +- num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state ++ num_samples_sans_final_epoch, sample_index.shape[0] - 1, 
numpy_random_state, shuffle + ) + else: + shuffle_index = _build_shuffle_index( +- sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state ++ sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state, shuffle + ) + + if path_to_cache: +@@ -558,6 +559,7 @@ def _build_document_index( + num_epochs: int, + numpy_random_state: numpy.random.RandomState, + separate_final_epoch: bool, ++ shuffle: bool = True, + ) -> numpy.ndarray: + """Build an array with length = num epochs * num documents + +@@ -578,7 +580,8 @@ def _build_document_index( + document_index[:] = documents + document_index = document_index.reshape(-1) + document_index = document_index.astype(numpy.int32) +- numpy_random_state.shuffle(document_index) ++ if shuffle: ++ numpy_random_state.shuffle(document_index) + return document_index + + doc_idx_first = _build_document_index(documents, num_epochs - 1, numpy_random_state, False) +@@ -587,7 +590,8 @@ def _build_document_index( + + + def _build_shuffle_index( +- num_samples: int, total_size: int, numpy_random_state: numpy.random.RandomState ++ num_samples: int, total_size: int, numpy_random_state: numpy.random.RandomState, ++ shuffle: bool = True + ) -> numpy.ndarray: + """Build the range [0, size) and shuffle + +@@ -607,12 +611,16 @@ def _build_shuffle_index( + dtype_ = numpy.int64 + + shuffle_idx_first = numpy.arange(start=0, stop=num_samples, step=1, dtype=dtype_) +- numpy_random_state.shuffle(shuffle_idx_first) ++ ++ if shuffle: ++ numpy_random_state.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = numpy.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_) +- numpy_random_state.shuffle(shuffle_idx_last) ++ ++ if shuffle: ++ numpy_random_state.shuffle(shuffle_idx_last) + + return numpy.concatenate((shuffle_idx_first, shuffle_idx_last)) + +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 
0c1504d4..71d29629 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -264,6 +264,7 @@ def topk_softmax_with_capacity( + # Pre softmax + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + probs, top_indices = torch.topk(scores, k=topk, dim=1) ++ probs /= probs.sum(dim=-1, keepdim=True) + else: + # Post softmax + if topk == 1: diff --git a/mixture_of_experts_pretraining/docker/tpu/Dockerfile b/mixture_of_experts_pretraining/docker/tpu/Dockerfile new file mode 100644 index 000000000..5a102923e --- /dev/null +++ b/mixture_of_experts_pretraining/docker/tpu/Dockerfile @@ -0,0 +1,53 @@ +FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm + +WORKDIR /app + +RUN pip install datasets==3.2.0 accelerate==1.2.1 evaluate==0.4.3 scikit-learn==1.6.0 + +# Install jax, jaxlib, libtpu nightly +RUN pip install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +RUN pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +RUN pip install libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -U --pre + +# Add --no-deps to avoid version dependency conflicts between all above libraries like jax, pallas, torch or torch_xla +RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html --no-deps + +# custom transformers with static mixtral implementation +# TODO: import changes in the current repo +# branch: lizhiyu/dpo_static_default +ARG TRANSFORMERS_REVISION=6172624929ce75c0f0ececa776d70415b9829c75 +RUN git clone https://github.com/pytorch-tpu/transformers && \ + cd transformers && \ + echo TRANSFORMERS_REVISION=${TRANSFORMERS_REVISION} && \ + git checkout ${TRANSFORMERS_REVISION} && \ + echo TRANSFORMERS_REVISION=$(git rev-parse HEAD) && \ + 
pip install -e . + +# Add the Google Cloud SDK package repository +RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list +RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - + +# Install the Google Cloud SDK +RUN apt-get update && apt-get install -y google-cloud-sdk vim + +RUN pip install hydra-core==1.3.2 +RUN pip install tensorboard==2.18.0 tensorboardX==2.6.2.2 +RUN pip install sentencepiece==0.2.0 + +# checkpoint loading in gcs +RUN pip install gcsfs==2024.12.0 + +# mlperf log +RUN pip install git+https://github.com/mlperf/logging.git@eb9e1a39bc313d964e9c1955d76384a6f3a731d3 + +# import schedulers from nemo +RUN pip install nemo_toolkit==1.23.0 pytorch-lightning==2.5.0.post0 huggingface-hub==0.23.2 + +WORKDIR /app + +# TODO change to mlcommon git +RUN git clone -b lizhiyu/moe --filter=blob:none --sparse https://github.com/ZhiyuLi-goog/training.git training + +WORKDIR /app/training + +RUN git sparse-checkout set mixture_of_experts_pretraining \ No newline at end of file diff --git a/mixture_of_experts_pretraining/docker/tpu/build_and_push_image.sh b/mixture_of_experts_pretraining/docker/tpu/build_and_push_image.sh new file mode 100644 index 000000000..5ad94f2d0 --- /dev/null +++ b/mixture_of_experts_pretraining/docker/tpu/build_and_push_image.sh @@ -0,0 +1,17 @@ +set -euox pipefail +SCRIPTS_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" && pwd )" +DATE=$(date +%Y%m%d) +: ${PROJECT_ID:=cloud-tpu-multipod-dev} +: ${IMAGE:=gcr.io/${PROJECT_ID}/${USER}-pytorch-xla-moe-${DATE}} +: ${DOCKER_BUILD_ARGS:=""} + +pushd ${SCRIPTS_DIR} + +docker build --network host \ + --file Dockerfile \ + --tag ${IMAGE} \ + ${DOCKER_BUILD_ARGS} \ + . 
+popd + +docker push ${IMAGE} diff --git a/mixture_of_experts_pretraining/download_dataset.py b/mixture_of_experts_pretraining/download_dataset.py new file mode 100644 index 000000000..476b165fb --- /dev/null +++ b/mixture_of_experts_pretraining/download_dataset.py @@ -0,0 +1,20 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import sys +from huggingface_hub import snapshot_download + +snapshot_download(repo_id=sys.argv[1], local_dir=sys.argv[2]) diff --git a/mixture_of_experts_pretraining/file_utils.py b/mixture_of_experts_pretraining/file_utils.py new file mode 100644 index 000000000..ef8628ffb --- /dev/null +++ b/mixture_of_experts_pretraining/file_utils.py @@ -0,0 +1,51 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os + +try: + from google.cloud import storage + + HAS_IMPORT_GOOGLE_CLOUD_SDK_EXCEPTION = None +except ImportError as e: + HAS_IMPORT_GOOGLE_CLOUD_SDK_EXCEPTION = e + + +def parse_gcs_bucket_and_blob_name(gcs_path): + splits = gcs_path.replace("gs://", "").split("/", 1) + bucket = splits[0] + blob_name = "" if len(splits) == 1 else splits[1] + return bucket, blob_name + + +def get_blob(gcs_path): + bucket, blob_name = parse_gcs_bucket_and_blob_name(gcs_path) + assert blob_name, f"{blob_name=} should be a valid name" + storage_client = storage.Client() + bucket = storage_client.bucket(bucket) + blob = bucket.blob(blob_name) + return blob + + +def get_file(path, mode): + if path.startswith("gs://"): + if HAS_IMPORT_GOOGLE_CLOUD_SDK_EXCEPTION: + raise HAS_IMPORT_GOOGLE_CLOUD_SDK_EXCEPTION + return get_blob(path).open(mode) + else: + file_dir = os.path.dirname(path) + os.makedirs(file_dir, exist_ok=True) + return open(path, mode) diff --git a/mixture_of_experts_pretraining/helm_context/Chart.yaml b/mixture_of_experts_pretraining/helm_context/Chart.yaml new file mode 100644 index 000000000..ff5eafe72 --- /dev/null +++ b/mixture_of_experts_pretraining/helm_context/Chart.yaml @@ -0,0 +1,6 @@ +apiVersion: v2 +name: megatron_moe_benchmark +description: megatron_moe_benchmark +type: application +version: 0.1.0 +appVersion: "1.16.0" \ No newline at end of file diff --git a/mixture_of_experts_pretraining/helm_context/selected-configuration.yaml b/mixture_of_experts_pretraining/helm_context/selected-configuration.yaml new file mode 100644 index 000000000..62ff01cda --- /dev/null +++ b/mixture_of_experts_pretraining/helm_context/selected-configuration.yaml @@ -0,0 +1,265 @@ +name: megatron_mixtral_8x7b_sft +run: + name: megatron_mixtral_8x7b_sft + results_dir: /app + time_limit: 01:00:00 + dependency: singleton +trainer: + devices: 8 + accelerator: gpu + precision: bf16 + + sft: + max_epochs: 1 + max_steps: 50 + + val_check_interval: 100 + save_interval: 
${.val_check_interval}
+  limit_train_batches: 1.0
+
+  limit_val_batches: 1.0
+  gradient_clip_val: 1.0
+
+  # can be used to register any custom metrics that require token-by-token generation
+  # inference_metrics:
+  #   my_metric_name1:
+  #     _target_:
+  #   my_metric_name2:
+  #     _target_:
+  #
+
+  # do not change these
+  logger: False # logger provided by exp_manager
+  enable_checkpointing: False
+  use_distributed_sampler: False
+  max_time: null
+  max_epochs: ${.sft.max_epochs}
+  max_steps: ${.sft.max_steps}
+
+exp_manager:
+  explicit_log_dir: null
+  exp_dir: null
+  name: ${name}
+  create_wandb_logger: False
+  wandb_logger_kwargs:
+    project: null
+    name: null
+  resume_if_exists: True
+  resume_ignore_no_checkpoint: True
+  create_checkpoint_callback: True
+  checkpoint_callback_params:
+    monitor: val_loss
+    save_top_k: 5
+    mode: min
+    save_nemo_on_train_end: False
+    filename: 'megatron_gpt_sft--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}'
+    model_parallel_size: ${model.tensor_model_parallel_size}
+    save_best_model: False # keep this False; otherwise restore resets the previous best model and multiple last.ckpt files get created
+
+pretrained_checkpoint:
+  restore_from_path: ./checkpoints/mixtral-8x7b.nemo
+
+model:
+  name_or_path: mistralai/Mixtral-8x7B-v0.1
+  seed: 1234
+  tensor_model_parallel_size: 4 # intra-layer model parallelism
+  pipeline_model_parallel_size: 4 # inter-layer model parallelism
+  expert_model_parallel_size: 1
+  virtual_pipeline_model_parallel_size: 1
+  restore_from_path: ./checkpoints/mixtral-8x7b.nemo # Path to an existing p-tuned/prompt-tuned .nemo model you wish to add new tasks to or run inference with
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training; restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  save_nemo_on_validation_end: True # Saves an inference-ready .nemo file every time a checkpoint is saved during training.
+  sync_batch_comm: False
+  megatron_amp_O2: False
+  encoder_seq_length: 32768 # the sequence length of the encoder model; it will be overwritten by the loaded GPT model
+
+  ## Sequence Parallelism
+  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
+  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
+  sequence_parallel: True
+
+  ## Activation Checkpoint
+  activations_checkpoint_granularity: null # 'selective' or 'full'
+  activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
+  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
+  # of each chunk at the specified granularity
+  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
+  activations_checkpoint_num_layers: null # not used with 'selective'
+  activations_checkpoint_layers_per_pipeline: null
+  # This feature is valid only when used with pipeline-model-parallelism. More details in megatron_gpt_config.yaml.
+ answer_only_loss: False # not used right now + gradient_as_bucket_view: False + seq_len_interpolation_factor: null # if not None, seq_len_interpolation_factor will match the base model's value + use_flash_attention: null # if not None, will match the base model's value + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + steerlm2: + forward_micro_batch_size: 1 # the micro batch size for the forward pass, used to compute the weights + micro_batch_size: 1 # the steerlm2 training micro batch size + + # can be used to customize behavior of model.generate for inference metrics + # note that you have to specify all parameters explicitly even if they match defaults + # as long as you change at least one parameter + # + # inference: + # sampling_params: + # use_greedy: False + # temperature: 0.7 + # top_k: 0 + # top_p: 0.95 + # repetition_penalty: 1.0 + # add_BOS: True + # all_probs: False + # compute_logprob: False + # end_strings: ["<|endoftext|>", ""] + # length_params: + # min_length: 0 + # max_length: 512 + # strategy: + # _target_: + # + + + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + + lora_tuning: + target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. 
null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + + data: + data_impl: "custom" + dataloader_type: "single" + chat: False # whether use chatbot data or not + chat_prompt_tokens: # special tokens for the chat prompts, a dictionary of {token_type: token}. note that some tokenizer may combine the characters at the junction between {end_of_turn}{turn_start}. e.g. '', the '><' sometimes is merged to be a single token. This is not supported, try to avoid + system_turn_start: "\x00" + turn_start: "\x11" + label_start: "\x12" + end_of_turn: "\x0A" # \0x0A is '\n' + end_of_name: "\x0A" # \0x0A is '\n' + sample: False # create the index mapping files for the sample data, so max_steps * global_batch_size can be larger than the dataset size + num_workers: 0 + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_path: "wikitext" # Path to a JSONL file corresponding to the source data. Data format is identical to validation_ds. + global_batch_size: 128 + micro_batch_size: 1 + shuffle: True + memmap_workers: null + max_seq_length: ${model.encoder_seq_length} + min_seq_length: 1 + drop_last: True # note that `False` is not currently supported + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + label_key: 'output' + add_eos: True + add_sep: False + add_bos: False + truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. 
+ prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + + validation_ds: + file_path: "wikitext" # Path to a JSONL file corresponding to the source data. Data format is identical to validation_ds. + global_batch_size: ${model.data.train_ds.global_batch_size} + micro_batch_size: ${model.data.train_ds.micro_batch_size} + shuffle: False + memmap_workers: ${model.data.train_ds.memmap_workers} + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: True # note that `False` is not currently supported + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + output_original_text: True # needed for the proper metrics support + + optim: + name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work. 
+ lr: 3e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-7 + +# an OmegaConf resolver that returns the local run directory, calling a function in utils.py +run_dir: "" + +# the batch size for training; for FSDP, the batch size per GPU is batch_size / (grad_accumulation_steps * num_gpus) +per_device_train_batch_size: 1 + +# the batch size during evaluation and sampling, if enabled +per_device_eval_batch_size: ${per_device_train_batch_size} + +# number of steps to accumulate over for each batch +gradient_accumulation_steps: 1 + +precision: ${trainer.precision} +seed: ${model.seed} +vocab_size: 32000 +max_steps: ${trainer.max_steps} + +global_train_batch_size: ${model.data.train_ds.global_batch_size} + +max_grad_norm: 0. + +# whether to eval at the very beginning of training +do_first_eval: false + +# evaluate and save model every eval_every steps +eval_frequency: 3 + +# combine forward +concatenated_forward: True + +# report frequency of train step +report_metrics_freq: 1 + +# the maximum allowed length for an input +max_length: 512 + +# cache of models +cache_local_dir: null + +dataset: + dataset_name: wikitext + dataset_config_name: wikitext-2-raw-v1 + streaming: False + + # num of process in data processing + num_proc: 1 + + # whether to load dataset from cache + load_from_cache_file: True \ No newline at end of file diff --git a/mixture_of_experts_pretraining/helm_context/templates/nemo-example.yaml b/mixture_of_experts_pretraining/helm_context/templates/nemo-example.yaml new file mode 100644 index 000000000..593ded531 --- /dev/null +++ b/mixture_of_experts_pretraining/helm_context/templates/nemo-example.yaml @@ -0,0 +1,525 @@ +{{ $timestamp := now | unixEpoch }} +{{ $jobSuffix := randAlphaNum 4 | lower }} +{{ $jobuuid := uuidv4 }} + +{{ $nodes := div .Values.workload.gpus 8 | max 1 }} +{{ $gpusPerNode := min .Values.workload.gpus 8 }} + +--- +apiVersion: v1 +kind: ConfigMap 
+metadata: + name: "{{ .Release.Name }}" +data: + nemo-configuration.yaml: |- +{{ .Files.Get "selected-configuration.yaml" | nindent 4 }} +--- +apiVersion: v1 +kind: Service +metadata: + name: "{{ .Release.Name }}" +spec: + clusterIP: None + selector: + job-name: "{{ .Release.Name }}" +--- +{{- $root := . -}} +apiVersion: batch/v1 +kind: Job +metadata: + name: "{{ .Release.Name }}" + namespace: default + labels: + {{- if $root.Values.queue }} + kueue.x-k8s.io/queue-name: "{{ $root.Values.queue }}" + {{- end }} +spec: + {{- if $root.Values.queue }} + suspend: true + {{- end }} + parallelism: {{ $nodes }} + completions: {{ $nodes }} + completionMode: Indexed + ttlSecondsAfterFinished: 43200 + template: + metadata: + annotations: + kubectl.kubernetes.io/default-container: megatron + {{- if $root.Values.volumes.gcsMounts }} + gke-gcsfuse/volumes: "true" + {{- end}} + + spec: + schedulingGates: + - name: "gke.io/topology-aware-auto-scheduling" + hostNetwork: true + dnsPolicy: ClusterFirstWithHostNet + subdomain: "{{.Release.Name}}" + restartPolicy: Never + + {{ if $root.Values.targetNodes }} + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: kubernetes.io/hostname + operator: In + values: + {{- range $hostname := $root.Values.targetNodes }} + - {{ $hostname }} + {{- end }} + {{ end }} + + tolerations: + - operator: "Exists" + key: nvidia.com/gpu + - operator: "Exists" + key: cloud.google.com/impending-node-termination + + volumes: + {{ if eq $root.Values.targetPlatform "gke" }} + - name: nvidia-install-dir-host + hostPath: + path: /home/kubernetes/bin/nvidia + {{ else }} + - name: dmabuf + hostPath: + path: /dev/dmabuf_import_helper + type: CharDevice + - name: cuda-lib + hostPath: + path: /usr/lib/x86_64-linux-gnu/libcuda.so + - name: cuda-lib1 + hostPath: + path: /usr/lib/x86_64-linux-gnu/libcuda.so.1 + - name: cuda-lib535 + hostPath: + path: /usr/lib/x86_64-linux-gnu/libcuda.so.535.104.12 
+ {{ end }} + + - name: nccl-plugin-volume + emptyDir: {} + {{ if ne $root.Values.network.stack "tcp" }} + - name: tcpx-daemon-socket + hostPath: + path: /run/tcpx + {{ end }} + - name: workload-configuration + configMap: + name: "{{.Release.Name}}" + - name: workload-terminated-volume + emptyDir: {} + - name: local-ssd + hostPath: + path: /mnt/stateful_partition/kube-ephemeral-ssd + - name: shared-memory + emptyDir: + medium: "Memory" + sizeLimit: 250Gi + + {{- range $pvc := $root.Values.volumes.pvcMounts }} + - name: "{{ $pvc.name }}" + persistentVolumeClaim: + claimName: "{{ $pvc.name }}" + {{- end }} + + {{- range $gcs := $root.Values.volumes.gcsMounts }} + - name: "{{ $gcs.bucketName }}" + csi: + driver: gcsfuse.csi.storage.gke.io + volumeAttributes: + bucketName: "{{ $gcs.bucketName }}" + {{- end}} + + initContainers: + + {{ if ne $root.Values.network.stack "tcp" }} + - name: nccl-plugin-installer + image: "{{ $root.Values.network.pluginVersion }}" + imagePullPolicy: Always + volumeMounts: + - name: nccl-plugin-volume + mountPath: /usr/local/nccl-plugin + command: + - /bin/sh + - -c + - | + mkdir -p /var/lib/tcpxo + ln -s /var/lib/tcpxo /var/lib/tcpx + /scripts/container_entry.sh install --install-nccl + # cp -r /var/lib/tcpxo/lib64/. 
/usr/local/nccl-plugin/lib64 + cp -r /var/lib/tcpxo/* /usr/local/nccl-plugin/ + echo "Installed NCCL plugin to pod-wide, shared NCCL plug-in volume" + echo "Contents (mounted at /usr/local/nccl-plugin/lib64):" + ls /usr/local/nccl-plugin/lib64 | sed 's/^/ /' + echo "Contents (mounted at /usr/local/nccl-plugin/):" + ls /usr/local/nccl-plugin/ | sed 's/^/ /' + + {{ end }} + + containers: + + # Either the tcpx or tcpxo receive daemon + {{ if ne $root.Values.network.stack "tcp" }} + - name: network-rx-daemon + image: "{{ $root.Values.network.daemonVersion }}" + imagePullPolicy: Always + securityContext: + privileged: true + volumeMounts: + - name: tcpx-daemon-socket + mountPath: /tmp + - name: workload-terminated-volume + mountPath: /semaphore + {{ if eq $root.Values.targetPlatform "gke" }} + - name: nvidia-install-dir-host + mountPath: "/usr/local/nvidia" + {{ else }} + - name: dmabuf + mountPath: /dev/dmabuf_import_helper + - name: cuda-lib + mountPath: /usr/lib/x86_64-linux-gnu/libcuda.so + - name: cuda-lib1 + mountPath: /usr/lib/x86_64-linux-gnu/libcuda.so.1 + - name: cuda-lib535 + mountPath: /usr/lib/x86_64-linux-gnu/libcuda.so.535.104.12 + {{ end }} + env: + - name: LD_LIBRARY_PATH + {{ if eq $root.Values.targetPlatform "gke" }} + value: /usr/local/nvidia/lib64 + {{ else }} + value: /usr/local/cuda-12.2/lib64 + {{ end }} + + {{ if eq $root.Values.network.stack "tcpx" }} + command: + - bash + - -c + - | + /tcpgpudmarxd/build/app/tcpgpudmarxd --gpu_nic_preset a3vm --gpu_shmem_type fd --setup_param "--verbose 128 2 0" & + while [ ! -e "/semaphore/workload_terminated" ]; do sleep 10; done + pkill -e "^"tcpgpudmarxd || true + sleep 15 + {{ end }} + + {{ if eq $root.Values.network.stack "tcpxo" }} + command: + - bash + - -c + - | + /fts/entrypoint_rxdm_container.sh --num_hops 2 --num_nics 8 --uid= --alsologtostderr & + while [ ! 
-e "/semaphore/workload_terminated" ]; do sleep 10; done + pkill -e "^"entrypoint_rxdm_container.sh || true + sleep 15 + {{ end }} + + {{ end }} + + - name: megatron + image: "{{ $root.Values.workload.image }}" + imagePullPolicy: Always + securityContext: + privileged: true + env: + - name: JOB_IDENTIFIER + value: "{{ .Release.Name }}-{{ $timestamp }}-{{ $jobSuffix }}" + - name: JOB_TIMESTAMP + value: "{{ $timestamp }}" + - name: JOB_UUID + value: "{{ $jobuuid }}" + - name: JOB_ORCHESTRATOR + value: "gke" + + - name: SSD_MOUNT_PATH + value: "{{ $root.Values.volumes.ssdMountPath }}" + + # The following settings are specific to the Torch distributed launcher: + - name: GCS_FUSE_BUCKET + value: "{{ $root.Values.workload.gcsBucketForDataCataPath }}" + - name: TORCH_DISTRIBUTED_TARGET + value: "{{ $root.Values.workload.torchDistributedTarget }}" + - name: TORCH_DISTRIBUTED_TRACING + value: "ALL" + - name: HF_TOKEN + value: "{{ $root.Values.workload.hfToken }}" + + + - name: MASTER_ADDR + value: "{{.Release.Name}}-0.{{.Release.Name}}.default.svc.cluster.local" + - name: MASTER_PORT + value: "6002" + - name: WORLD_SIZE + value: "{{ $root.Values.workload.gpus }}" + - name: NNODES + value: "{{ $nodes }}" + - name: GPUS_PER_NODE + value: "{{ $gpusPerNode }}" + - name: GLOO_SOCKET_IFNAME + {{ if eq $root.Values.targetPlatform "gke" }} + value: "eth0" + {{ else }} + value: "enp0s12" + {{ end }} + + # The leader node can launch an embedded Tensorboard server (if needed) + {{- if $root.Values.workload.embeddedTensorboardTarget }} + - name: EMBEDDED_TENSORBOARD_TARGET + value: "{{ $root.Values.workload.embeddedTensorboardTarget}}" + {{- end }} + + # The following arguments are passed to the Workload: + {{- range $environment_variable := $root.Values.workload.arguments }} + - name: "WORKLOAD_{{ $environment_variable.name }}" + value: "{{ $environment_variable.value }}" + {{- end }} + + # The following is needed to prevent send-receive stalling execution + - name: 
NVTE_FWD_LAYERNORM_SM_MARGIN + value: "8" + - name: NVTE_BWD_LAYERNORM_SM_MARGIN + value: "8" + + {{ if ne $root.Values.network.stack "tcp" }} + + # The following TCPxo settings should likely not be adjusted: + {{ if eq $root.Values.network.stack "tcpxo" }} + - name: NCCL_BUFFSIZE + value: "8388608" + - name: NCCL_FASTRAK_CTRL_DEV + {{ if eq $root.Values.targetPlatform "gke" }} + value: "eth0" + {{ else }} + value: "enp0s12" + {{ end }} + - name: NCCL_FASTRAK_IFNAME + {{ if eq $root.Values.targetPlatform "gke" }} + value: "eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8" + {{ else }} + value: "enp6s0f0,enp7s0f0,enp13s0f0,enp14s0f0,enp134s0f0,enp135s0f0,enp141s0f0,enp142s0f0" + {{ end }} + - name: NCCL_FASTRAK_NUM_FLOWS + value: "2" + - name: NCCL_FASTRAK_NUM_FLOWS_PER_GROUP + value: "1" + - name: NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL + value: "0" + - name: NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING + value: "0" + - name: NCCL_FASTRAK_USE_SNAP + value: "1" + - name: NCCL_FASTRAK_USE_LLCM + value: "1" + + # The following NCCL tuner settings should likely not be adjusted: + - name: NCCL_TUNER_PLUGIN + value: "libnccl-tuner.so" + - name: NCCL_TUNER_CONFIG_PATH + value: "/usr/local/nccl-plugin/lib64/a3plus_tuner_config.textproto" + - name: NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE + value: "/usr/local/nccl-plugin/lib64/a3plus_guest_config.textproto" + + {{ end }} + + {{ if eq $root.Values.network.stack "tcpx" }} + - name: NCCL_GPUDIRECTTCPX_CTRL_DEV + value: "eth0" + - name: NCCL_GPUDIRECTTCPX_SOCKET_IFNAME + value: "eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8" + - name: NCCL_GPUDIRECTTCPX_TX_BINDINGS + value: "eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177" + - name: NCCL_GPUDIRECTTCPX_RX_BINDINGS + value: "eth1:22-35,126-139;eth2:22-35,126-139;eth3:74-87,178-191;eth4:74-87,178-191" + - name: NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS + value: "500000" + {{ end }} + + # The following NCCL settings should likely not be adjusted: + - name: 
NCCL_SOCKET_IFNAME + {{ if eq $root.Values.targetPlatform "gke" }} + value: "eth0" + {{ else }} + value: "enp0s12" + {{ end }} + - name: NCCL_DYNAMIC_CHUNK_SIZE + value: "524288" + - name: NCCL_P2P_NET_CHUNKSIZE + value: "524288" + - name: NCCL_P2P_PCI_CHUNKSIZE + value: "524288" + - name: NCCL_P2P_NVL_CHUNKSIZE + value: "1048576" + - name: NCCL_CROSS_NIC + value: "0" + - name: NCCL_PROTO + value: "Simple" + - name: NCCL_NET_GDR_LEVEL + value: "PIX" + - name: NCCL_P2P_PXN_LEVEL + value: "0" + - name: NCCL_NVLS_ENABLE + value: "0" + + {{- range $environment_variable := $root.Values.network.ncclSettings }} + - name: {{ $environment_variable.name }} + value: "{{ $environment_variable.value }}" + {{- end }} + + {{ end }} + + command: + - bash + - -c + - | + function on_script_completion { + # Note: This semaphore is used to terminate the TCPx side-car + touch /semaphore/workload_terminated + } + trap on_script_completion EXIT + echo "Pod on $(hostname --fqdn) is running" + echo "Pod is assigned job index of $JOB_COMPLETION_INDEX" + echo "Job ID is $JOB_IDENTIFIER" + + echo "Running nvidia-smi" + nvidia-smi + + mkdir -p /gcs + gcsfuse --client-protocol http2 $GCS_FUSE_BUCKET /gcs + + mkdir -p /gcs/index_mapping_dir + + # export LD_LIBRARY_PATH="/usr/local/nccl-plugin/lib64:/usr/local/cuda-12.3/lib64:/usr/local/nvidia/lib64/:${LD_LIBRARY_PATH}" + export LD_LIBRARY_PATH="/usr/local/nccl-plugin/lib64:/usr/local/nvidia/lib64/:${LD_LIBRARY_PATH}" + echo "Warning: Set LD_LIBRARY_PATH=$LD_LIBRARY_PATH to override the NCCL library" + + ldconfig /usr/local/nvidia/lib64/ + echo "Added /usr/local/nvidia/lib64/ to ldconfig:" + ldconfig -p | grep libcuda | sed 's/^/ /' + + echo "Contents of /usr/local/nccl-plugin/lib64:" + ls /usr/local/nccl-plugin/lib64 | sed 's/^/ /' + + touch $SSD_MOUNT_PATH/hello-from-$HOSTNAME.txt + echo "Local SSD contents (path $SSD_MOUNT_PATH):"; ls $SSD_MOUNT_PATH | sed 's/^/ /' + + echo "Downloading GPT vocabulary files" + wget 
https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json &&\ + wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt + + echo "NeMo configuration file:" + cat /etc/workload-configuration/nemo-configuration.yaml | sed 's/^/| /' + echo "" + readarray -d "" workload_arguments < <(env | grep -e "^WORKLOAD_" | sed 's/^WORKLOAD_/+/' | tr '\n' '\0') + echo "Detected the following additional workload arguments:" + for workload_argument in "${workload_arguments[@]}"; do + echo " $workload_argument" + done + + sleep 10 # <- Hack to allow some time for service to boot + + mount /tmp -o remount,exec + chmod -R a+rwx /tmp + + echo "Checking for presence of nsys:" + which nsys + + echo "Nsight profiling will go to /gcs/nemo-experiments/$JOB_IDENTIFIER/." + mkdir -p /gcs/nemo-experiments/$JOB_IDENTIFIER/ + + apt -y update && apt -y install gdb python3.10-dbg + + mkdir -p /app/tmp/ && export TMPDIR=/app/tmp/ + + # Conversion of HF Mixtral checkpoints into NeMo checkpoint + export HF_TOKEN=$HF_TOKEN + python download_dataset.py mistralai/Mixtral-8x7B-v0.1 /app/checkpoints/mixtral-8x7b + python /opt/NeMo/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py \ + --input_name_or_path /app/checkpoints/mixtral-8x7b \ + --output_path /app/checkpoints/mixtral-8x7b.nemo \ + --precision=bf16 + rm -rf /app/checkpoints/mixtral-8x7b + + export NODE_RANK=$JOB_COMPLETION_INDEX + export WORLD_SIZE=$WORLD_SIZE + export TOKENIZERS_PARALLELISM=false + echo "Launching Torch distributed as node rank $NODE_RANK out of $NNODES nodes" + for ((LOCAL_RANK=0; LOCAL_RANK <= $((GPUS_PER_NODE - 1)); LOCAL_RANK++)); do + RANK=$((8*$NODE_RANK + $LOCAL_RANK)) + + OMP_NUM_THREADS=12 RANK=$RANK LOCAL_RANK=$LOCAL_RANK \ + nsys profile -s none -t nvtx,cuda --capture-range=cudaProfilerApi --capture-range-end=stop \ + -o /gcs/nemo-experiments/$JOB_IDENTIFIER/rank-$RANK \ + --session-new "nemo-rank$RANK" \ + python $TORCH_DISTRIBUTED_TARGET \ + 
--config-path="/etc/workload-configuration" \ + --config-name="nemo-configuration.yaml" \ + +trainer.num_nodes="$NNODES" \ + +exp_manager.version="$JOB_IDENTIFIER" \ + ${workload_arguments[@]} & + + echo "Launched rank $RANK with PID $!" + TORCH_PIDS[$LOCAL_RANK]=$! + done + + if [ "$NODE_RANK" -eq "1" ]; then + echo "Launching nvidia-smi in daemon mode with (20 sec delay)" + nvidia-smi dmon -d 20 -s pum & + fi + + if [ "$NODE_RANK" -eq "0" ] && { ! [ -z ${EMBEDDED_TENSORBOARD_TARGET} ]; }; then + echo "Launching an embedded Tensorboard against log directory $EMBEDDED_TENSORBOARD_TARGET" + tensorboard --logdir $EMBEDDED_TENSORBOARD_TARGET & + wait # <-- This will indefinitely stall node rank 0 + fi + + # Wait for Torch processes (might be problematic if only one fails) + for PID in ${TORCH_PIDS[*]}; do + echo "Waiting on Torch PID $PID" + wait $PID + done + + echo "Pod on $(hostname --fqdn) is exiting" + volumeMounts: + {{ if eq $root.Values.targetPlatform "gke" }} + - name: nvidia-install-dir-host + mountPath: /usr/local/nvidia + {{ else }} + - name: dmabuf + mountPath: /dev/dmabuf_import_helper + - name: cuda-lib + mountPath: /usr/lib/x86_64-linux-gnu/libcuda.so + - name: cuda-lib1 + mountPath: /usr/lib/x86_64-linux-gnu/libcuda.so.1 + - name: cuda-lib535 + mountPath: /usr/lib/x86_64-linux-gnu/libcuda.so.535.104.12 + {{ end }} + - name: nccl-plugin-volume + mountPath: /usr/local/nccl-plugin + {{ if ne $root.Values.network.stack "tcp" }} + - name: tcpx-daemon-socket + mountPath: /tmp + {{ end }} + - name: workload-terminated-volume + mountPath: /semaphore + - name: workload-configuration + mountPath: /etc/workload-configuration + - name: shared-memory + mountPath: /dev/shm + - name: local-ssd + mountPath: "{{ $root.Values.volumes.ssdMountPath }}" + + {{- range $pvc := $root.Values.volumes.pvcMounts }} + - name: "{{ $pvc.name }}" + mountPath: "{{ $pvc.mountPath }}" + {{- end }} + + {{- range $gcs := $root.Values.volumes.gcsMounts }} + - name: "{{ $gcs.bucketName }}" 
+ mountPath: "{{ $gcs.mountPath }}" + {{- end }} + + resources: + limits: + nvidia.com/gpu: {{ $gpusPerNode }} +--- \ No newline at end of file diff --git a/mixture_of_experts_pretraining/helm_context/values.yaml b/mixture_of_experts_pretraining/helm_context/values.yaml new file mode 100644 index 000000000..dde420bae --- /dev/null +++ b/mixture_of_experts_pretraining/helm_context/values.yaml @@ -0,0 +1,52 @@ +targetPlatform: "gke" + +volumes: + # The VM host path for SSDs is assumed at /mnt/stateful_partition/kube-ephemeral-ssd + ssdMountPath: "/ssd" + +workload: + # This should be the image built and pushed to the registry (EDIT THIS) + image: "/address/to/image/built/and/uploaded/:" + + torchDistributedTarget: "run_clm.py" + + # HuggingFace tokens (EDIT THIS) + hfToken: "" + + # It will be mounted to /gcs on container startup using GCS fuse (EDIT THIS) + gcsBucketForDataCataPath: + + gpus: 16 # This should be one of: {<= 8, multiple of 8} + arguments: + # The argument name will be prefixed with '+' (see https://hydra.cc/docs/advanced/override_grammar/basic/) + - name: "exp_manager.explicit_log_dir" + value: "/nemo-experiments/results" + - name: "exp_manager.exp_dir" + value: "/nemo-experiments/" + + # Llama 2 tokenizer (not used) + #- name: "model.data.data_prefix" + # value: "[1.0,/ssd/.cache/wikipedia-tokenized-for-llama2]" + #- name: "model.tokenizer.model" + # value: "/ssd/.cache/llama-2-7b-megatron-checkpoint/tokenizer.model" + + # If not 'null', launches a Tensorboard server on first node. By design, the job will then not exit on first node. + # This is primarily intended for debugging purposes, when a shared file-system or external Tensorboard is unavailable.
+ embeddedTensorboardTarget: null + +network: + stack: "tcpxo" # one of {"tcp", "tcpx", "tcpxo"} + + daemonVersion: "us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.9" + pluginVersion: "us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/nccl-plugin-gpudirecttcpx-dev:v1.0.3" + + + ncclSettings: + - name: NCCL_DEBUG + value: "VERSION" + - name: NCCL_ALGO + value: "Ring,Tree" + + # The following NCCL settings are recommended for TCPxo only (but tunable): + - name: NCCL_MIN_NCHANNELS + value: "4" diff --git a/mixture_of_experts_pretraining/mixtral80.json b/mixture_of_experts_pretraining/mixtral80.json new file mode 100644 index 000000000..7af26cb3f --- /dev/null +++ b/mixture_of_experts_pretraining/mixtral80.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mixtral", + "num_attention_heads": 32, + "num_experts_per_tok": 2, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "num_local_experts": 8, + "output_router_logits": false, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.02, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.36.0.dev0", + "use_cache": true, + "vocab_size": 32000 +} + diff --git a/mixture_of_experts_pretraining/mixtral822-instruct.json b/mixture_of_experts_pretraining/mixtral822-instruct.json new file mode 100644 index 000000000..3f8f113b0 --- /dev/null +++ b/mixture_of_experts_pretraining/mixtral822-instruct.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 
65536, + "model_type": "mixtral", + "num_attention_heads": 48, + "num_experts_per_tok": 2, + "num_hidden_layers": 56, + "num_key_value_heads": 8, + "num_local_experts": 8, + "output_router_logits": false, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.0", + "use_cache": true, + "vocab_size": 32768 +} + diff --git a/mixture_of_experts_pretraining/mixtral822.json b/mixture_of_experts_pretraining/mixtral822.json new file mode 100644 index 000000000..5616b5066 --- /dev/null +++ b/mixture_of_experts_pretraining/mixtral822.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 65536, + "model_type": "mixtral", + "num_attention_heads": 48, + "num_experts_per_tok": 2, + "num_hidden_layers": 56, + "num_key_value_heads": 8, + "num_local_experts": 8, + "output_router_logits": false, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.0", + "use_cache": true, + "vocab_size": 32000 +} + diff --git a/mixture_of_experts_pretraining/mixtral87.json b/mixture_of_experts_pretraining/mixtral87.json new file mode 100644 index 000000000..de132a80b --- /dev/null +++ b/mixture_of_experts_pretraining/mixtral87.json @@ -0,0 +1,29 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mixtral", + "num_attention_heads": 32, + "num_experts_per_tok": 
2, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "num_local_experts": 8, + "output_router_logits": false, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.02, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.36.0.dev0", + "use_cache": true, + "vocab_size": 32000 +} diff --git a/mixture_of_experts_pretraining/mlperf_logging_utils.py b/mixture_of_experts_pretraining/mlperf_logging_utils.py new file mode 100644 index 000000000..ab2c13801 --- /dev/null +++ b/mixture_of_experts_pretraining/mlperf_logging_utils.py @@ -0,0 +1,424 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os + +import torch +import torch.distributed as dist +from mlperf_logging import mllog +from mlperf_logging.mllog import constants +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger +from pytorch_lightning.utilities import rank_zero_only +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, + is_torch_xla_available, +) + +if is_torch_xla_available(): + import torch_xla.runtime as xr + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if is_torch_xla_available(): + return xr.global_ordinal() + else: + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def barrier(): + if not is_dist_avail_and_initialized(): + return + torch.distributed.barrier() + + +class ClmLogger: + def __init__(self, config, filename=None, default_stack_offset=2): + self.mllogger = mllog.get_mllogger() + mllog.config( + default_stack_offset=default_stack_offset, + filename=( + filename + or os.getenv("COMPLIANCE_FILE") + or os.path.join(config.run_dir, "mlperf_compliance.log") + ), + ) + self.target_eval_loss = config.target_eval_loss + + def event(self, key, value=None, metadata=None, sync=False, log_rank=None): + if get_rank() == 0: + self.mllogger.event(key=key, value=value, metadata=metadata) + + def start(self, key, value=None, metadata=None, sync=False, log_rank=None): + if get_rank() == 0: + self.mllogger.start(key=key, value=value, metadata=metadata) + + def end(self, key, value=None, metadata=None, sync=False, log_rank=None): + if get_rank() == 0: + self.mllogger.end(key=key, value=value, metadata=metadata) + + +class MLPerfCallback(TrainerCallback): + "A callback that prints a message at the beginning of training" + + def __init__(self, config): + super().__init__() + self.mllogger = ClmLogger(config) + self.submission_info = { + 
"submission_benchmark": "mixture-of-expert", # TODO change task name + "submission_division": "closed", + "submission_org": "Google", + "submission_platform": "reference", + "submission_status": "reference", + } + self.mllogger.event( + key=constants.CACHE_CLEAR, + value="True", + ) + self.mllogger.start(key=constants.INIT_START, value="") + self.config = config + self.global_batch_tokens = config.global_train_batch_size * config.max_length + + def on_train_begin(self, args, state, control, **kwargs): + if torch.cuda.is_available(): + num_devices = int(os.getenv("WORLD_SIZE", 1)) + elif is_torch_xla_available(): + num_devices = xr.global_runtime_device_count() + else: + raise ValueError("The pipeline should be either cuda or xla backend.") + + self.global_batch_size = int( + args.per_device_train_batch_size + * args.gradient_accumulation_steps + * num_devices + ) + + self.mllogger.event( + key=constants.SUBMISSION_BENCHMARK, + value=self.submission_info["submission_benchmark"], + ) + self.mllogger.event( + key=constants.SUBMISSION_DIVISION, + value=self.submission_info["submission_division"], + ) + self.mllogger.event( + key=constants.SUBMISSION_ORG, value=self.submission_info["submission_org"] + ) + self.mllogger.event( + key=constants.SUBMISSION_PLATFORM, + value=self.submission_info["submission_platform"], + ) + self.mllogger.event( + key=constants.SUBMISSION_STATUS, + value=self.submission_info["submission_status"], + ) + self.mllogger.event( + key=constants.GLOBAL_BATCH_SIZE, + value=self.config.global_train_batch_size, + ) + self.mllogger.event( + key=constants.EVAL_SAMPLES, + value=12694503, + ) + self.mllogger.event(key=constants.SEED, value=args.seed) + self.mllogger.event( + key=constants.OPT_LR_WARMUP_FACTOR, value=args.sched.warmup_ratio + ) + self.mllogger.event(key=constants.OPT_LR_TRAINING_STEPS, value=args.max_steps) + self.mllogger.event( + key=constants.OPT_ADAMW_WEIGHT_DECAY, value=args.weight_decay + ) + self.mllogger.event( + 
key=constants.OPT_GRADIENT_CLIP_NORM, value=args.max_grad_norm + ) + self.mllogger.event(key=constants.OPT_BASE_LR, value=args.lr) + self.mllogger.event( + key=constants.GRADIENT_ACCUMULATION_STEPS, + value=args.gradient_accumulation_steps, + ) + # device warmup should be done here + self.mllogger.end(key=constants.INIT_STOP, value="") + + # run on all ranks to allow sync + barrier() + self.mllogger.start(constants.RUN_START, value="") + self.mllogger.start( + constants.BLOCK_START, + value="", + metadata={ + "samples_count": 0, + }, + ) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """ + Event called at the end of a training step. + """ + if state.global_step % state.eval_steps == 0 and state.global_step > 0: + self.mllogger.event( + "train_loss", + value=state.log_history[-1]["train/loss"] if state.log_history else -1, + metadata={ + "samples_count": ( + state.global_step * self.global_batch_tokens + if state.log_history + else -1 + ) + }, + ) + control.should_log = True + + if state.global_step % state.eval_steps == 0: + self.mllogger.end( + constants.BLOCK_STOP, + value="", + metadata={ + "samples_count": state.global_step * self.global_batch_tokens, + }, + ) + self.mllogger.event( + constants.EVAL_ACCURACY, + value=state.log_history[-1]["eval/loss"], + metadata={ + "samples_count": state.global_step * self.global_batch_tokens, + }, + ) + latest_eval_loss = float("nan") + if state.log_history and "eval/loss" in state.log_history[-1]: + latest_eval_loss = state.log_history[-1]["eval/loss"] + if latest_eval_loss <= self.mllogger.target_eval_loss: + control.should_training_stop = True + + # run on all ranks to allow sync + barrier() + self.mllogger.end( + constants.RUN_STOP, + value=latest_eval_loss, + metadata={ + "samples_count": state.global_step * self.global_batch_tokens, + "status": "success", + }, + ) + if state.global_step >= state.max_steps: + control.should_training_stop = 
True + self.mllogger.end( + constants.RUN_STOP, + value=latest_eval_loss, + metadata={ + "samples_count": state.global_step * self.global_batch_tokens, + "status": "fail", + }, + ) + + if not control.should_training_stop: + self.mllogger.start( + constants.BLOCK_START, + value="", + metadata={ + "samples_count": state.global_step * self.global_batch_tokens + }, + ) + + return control + + +class MLPerfLightningCallback(Callback): + def __init__(self, logger, global_batch_size: int, sequence_length: int): + super().__init__() + self.gbs = global_batch_size + self.seq = sequence_length + self.mllogger = logger + self.force_success = False + + def __deepcopy__(self, memo): + return MLPerfLightningCallback(self.mllogger, self.gbs, self.seq) + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + return super().on_train_batch_start(trainer, pl_module, batch, batch_idx) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) + + @rank_zero_only + def on_validation_start(self, trainer, pl_module): + self.mllogger.end( + constants.BLOCK_STOP, + metadata={"samples_count": trainer.global_step * self.gbs * self.seq}, + sync=False, + ) + self.mllogger.start( + key=constants.EVAL_START, + metadata={"samples_count": trainer.global_step * self.gbs * self.seq}, + sync=False, + ) + return super().on_validation_start(trainer, pl_module) + + @rank_zero_only + def on_validation_end(self, trainer, pl_module): + if not trainer.should_stop: + self.mllogger.start( + constants.BLOCK_START, + metadata={"samples_count": trainer.global_step * self.gbs * self.seq}, + sync=False, + ) + return super().on_validation_end(trainer, pl_module) + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + self.mllogger.start( + constants.BLOCK_START, metadata={"samples_count": 0}, sync=False + ) + + @rank_zero_only + def on_train_end(self, trainer, pl_module): + if 
hasattr(trainer, "run_stop_logged") and not trainer.run_stop_logged: + self.mllogger.end( + constants.RUN_STOP, + metadata={ + "samples_count": trainer.global_step * self.gbs * self.seq, + "status": "aborted" if not self.force_success else "success", + }, + ) + return super().on_train_end(trainer, pl_module) + + +class MetricsLogger(Logger): + def __init__( + self, + logger, + nodes: int, + global_batch_size: int, + learning_rate: float, + sequence_length: int, + ): + super().__init__() + self.nodes = nodes + self.gbs = global_batch_size + self.seq = sequence_length + self.lr = learning_rate + self.mllogger = logger + self.experiment = None + + def __deepcopy__(self, memo): + output = MetricsLogger(self.mllogger, self.nodes, self.gbs, self.lr, self.seq) + if hasattr(self, "trainer"): + output.trainer = self.trainer + return output + + def set_trainer(self, trainer): + self.trainer = trainer + trainer.run_stop_logged = False + + @rank_zero_only + def log_metrics(self, metrics, step): + if "reduced_train_loss" in metrics: + self.mllogger.event( + "train_loss_update", + value=metrics["reduced_train_loss"], + metadata={ + "samples_count": self.trainer.global_step * self.gbs * self.seq, + }, + ) + + if "val_loss" in metrics: + val_loss = metrics["val_loss"] + self.mllogger.event( + constants.EVAL_ACCURACY, + value=val_loss, + metadata={ + "samples_count": self.trainer.global_step * self.gbs * self.seq, + }, + ) + self.mllogger.end( + key=constants.EVAL_STOP, + metadata={ + "samples_count": self.trainer.global_step * self.gbs * self.seq + }, + sync=False, + ) + + @rank_zero_only + def log_hyperparams(self, params, *args, **kwargs): + self.mllogger.event(key=constants.CACHE_CLEAR, value=True) + self.mllogger.start(key=constants.INIT_START) + # self.mllogger.mlperf_submission_log( + # benchmark="mixtral_8x22b", + # num_nodes=self.nodes, + # ) + # self.mllogger.event( + # key=constants.SEED, + # value=self.cfg.model.seed, + # sync=False, + # unique=True, + # ) + 
self.mllogger.event( + key=constants.GLOBAL_BATCH_SIZE, + value=self.gbs, + sync=False, + ) + # self.mllogger.event( + # key=constants.TRAIN_SAMPLES, + # value=0, + # ) + # self.mllogger.event( + # key=constants.EVAL_SAMPLES, + # value=0, + # ) + # self.mllogger.event( + # key=constants.OPT_LR_WARMUP_FACTOR, + # value=self.cfg.model.optim.sched.warmup_ratio, + # ) + # self.mllogger.event( + # key=constants.OPT_ADAMW_WEIGHT_DECAY, + # value=self.cfg.model.optim.weight_decay, + # ) + # self.mllogger.event( + # key=constants.OPT_GRADIENT_CLIP_NORM, + # value=self.cfg.trainer.gradient_clip_val, + # ) + # ga = int(os.getenv("MINIBS", "1")) // self.cfg.model.micro_batch_size + # self.mllogger.event(key=constants.GRADIENT_ACCUMULATION_STEPS, value=ga) + # self.mllogger.event( + # key=constants.OPT_LR_TRAINING_STEPS, value=self.cfg.trainer.max_steps + # ) + self.mllogger.event(key=constants.OPT_BASE_LR, value=self.lr) + + @property + def name(self): + return "mlperf-metrics" + + @property + def version(self): + return 1 diff --git a/mixture_of_experts_pretraining/model_utils_gpu.py b/mixture_of_experts_pretraining/model_utils_gpu.py new file mode 100644 index 000000000..83bf2aecb --- /dev/null +++ b/mixture_of_experts_pretraining/model_utils_gpu.py @@ -0,0 +1,147 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + + +import torch +from megatron.core.optimizer import OptimizerConfig +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.common.tokenizers import AutoTokenizer +from nemo.utils import logging + + +def setup_distributed(config): + """Initialize torch.distributed.""" + torch.distributed.init_process_group( + backend="nccl", + ) + + +def setup_model_and_trainer( + model_name_or_path: str, + input_sequence_length: int, + global_batch_size: int, + nodes: int, + tp_size: int, + pp_size: int, + vpp_size: int, + cp_size: int, + learning_rate: float, + weight_decay: float, + optimizer_name: str, + tokenizer_name_or_path: str, + scheduler, + max_grad_norm: float, + eval_frequency: int, + log_frequency: int, + max_steps: int, + *, + logger, + callbacks: list, +): + logging.info("loading model") + + if "mixtral-8x7b" in model_name_or_path.lower(): + mixtral_config = llm.MixtralConfig8x7B() + elif "mixtral-8x22b" in model_name_or_path.lower(): + mixtral_config = llm.MixtralConfig8x22B( + moe_aux_loss_coeff=0.001, + ) + else: + raise ValueError(f"Unknown model specified: {model_name_or_path}") + + resume = nl.AutoResume(resume_from_path="/app/checkpoints/") + tokenizer = AutoTokenizer(pretrained_model_name=tokenizer_name_or_path) + model = llm.MixtralModel(mixtral_config, tokenizer=tokenizer) + + ## initialize the strategy + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + virtual_pipeline_model_parallel_size=vpp_size, + sequence_parallel=True, + context_parallel_size=cp_size, + pipeline_dtype=torch.bfloat16, + ckpt_load_optimizer=False, + ) + + precision = nl.MegatronMixedPrecision( + precision="bf16-mixed", + params_dtype=torch.bfloat16, + pipeline_dtype=torch.bfloat16, + autocast_enabled=False, + grad_reduce_in_fp32=True, + ) + + ## setup the optimizer + opt_config = OptimizerConfig( + optimizer=optimizer_name, + lr=learning_rate, + weight_decay=weight_decay, + 
bf16=True, + fp16=False, + params_dtype=torch.bfloat16, + clip_grad=max_grad_norm, + use_distributed_optimizer=True, + ) + + if scheduler.name == "CosineAnnealing": + opt_sched = nl.lr_scheduler.CosineAnnealingScheduler( + warmup_steps=scheduler.warmup_steps + if "warmup_steps" in scheduler + else None, + warmup_ratio=scheduler.warmup_ratio + if "warmup_steps" not in scheduler + else None, + max_steps=scheduler.max_steps, + min_lr=scheduler.min_lr, + ) + elif scheduler.name == "WarmupHoldPolicy": + opt_sched = nl.lr_scheduler.WarmupHoldPolicyScheduler( + warmup_steps=scheduler.warmup_steps + if "warmup_steps" in scheduler + else None, + warmup_ratio=scheduler.warmup_ratio + if "warmup_steps" not in scheduler + else None, + hold_steps=scheduler.hold_steps, + max_steps=scheduler.max_steps, + ) + + opt = nl.MegatronOptimizerModule(config=opt_config, lr_scheduler=opt_sched) + trainer = nl.Trainer( + devices=torch.cuda.device_count(), + num_nodes=nodes, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=precision, + callbacks=callbacks, + logger=logger, + enable_progress_bar=False, + val_check_interval=eval_frequency, + log_every_n_steps=log_frequency, + ) + + logger.set_trainer(trainer) + logger.log_hyperparams(None) + + return ( + model, + trainer, + opt, + resume, + ) diff --git a/mixture_of_experts_pretraining/model_utils_tpu.py b/mixture_of_experts_pretraining/model_utils_tpu.py new file mode 100644 index 000000000..26b5d82fc --- /dev/null +++ b/mixture_of_experts_pretraining/model_utils_tpu.py @@ -0,0 +1,376 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import functools +import gc +import os +from omegaconf import OmegaConf +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.distributed.spmd as xs +from transformers import logging +from torch_xla.experimental.distributed_checkpoint import ( + CheckpointManager, + prime_optimizer, +) +import numpy as np +from torch_xla.experimental.spmd_fully_sharded_data_parallel import ( + SpmdFullyShardedDataParallel as FSDPv2, +) + +from torch_xla.distributed.fsdp import checkpoint_module + +from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, +) + +from transformers.trainer_pt_utils import ( + get_module_class_from_name, +) +from psutil import Process +from transformers import AutoModelForCausalLM, AutoConfig, TrainerCallback +from nemo.core.optim.lr_scheduler import CosineAnnealing, WarmupHoldPolicy +from torch.utils.tensorboard import SummaryWriter +import json + + +logger = logging.get_logger(__name__) + + +def prepare_model(model, config): + if config.tensor_parallelism == 1: + + def shard_output(output, mesh): + real_output = None + if isinstance(output, torch.Tensor): + real_output = output + elif isinstance(output, tuple): + real_output = output[0] + elif hasattr(output, "logits"): + real_output = output.logits + + if real_output is None: + raise ValueError( + "Something went wrong, the output of the model shouldn't be `None`" + ) + xs.mark_sharding(real_output, mesh, ("fsdp", None, None)) + + auto_wrap_policy = None + auto_wrapper_callable = 
None + + default_transformer_cls_names_to_wrap = getattr( + model, "_no_split_modules", None + ) + fsdp_transformer_layer_cls_to_wrap = config.model.fsdp_config.get( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) + + if config.model.fsdp_config["min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, + min_num_params=config.model.fsdp_config["min_num_params"], + ) + elif fsdp_transformer_layer_cls_to_wrap is not None: + transformer_cls_to_wrap = set() + for layer_class in fsdp_transformer_layer_cls_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception( + "Could not find the transformer layer class to wrap in the model." + ) + else: + transformer_cls_to_wrap.add(transformer_cls) + + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + + if config.model.fsdp_config["xla_fsdp_grad_ckpt"]: + if model.config.use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + model.config.use_cache = False + + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + target_cls = FSDPv2 + return target_cls(checkpoint_module(m), *args, **kwargs) + + model = FSDPv2( + model, + shard_output=shard_output, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + ) + + return model + else: + model.to("xla") + mesh = xs.get_global_mesh() + for name, param in model.named_parameters(): + logger.debug(f"> [2D] Sharding tensor {name}, {param.shape}") + + # Here we intentionally skip layernorm and moe.gate weights given they are small. + if "embed_tokens" in name: + xs.mark_sharding( + param, mesh, ("fsdp", "tensor") + ) # needed to have activations fully sharded. 
+ elif "q_proj" in name or "k_proj" in name or "v_proj" in name: + xs.mark_sharding(param, mesh, ("tensor", "fsdp")) + elif "o_proj" in name: + xs.mark_sharding(param, mesh, ("fsdp", "tensor")) + elif "w1" in name or "w3" in name: + xs.mark_sharding(param, mesh, ("tensor", "fsdp")) + elif "w2" in name: + xs.mark_sharding(param, mesh, ("fsdp", "tensor")) + elif "lm_head" in name: + xs.mark_sharding(param, mesh, ("tensor", "fsdp")) + + logger.info(f"{name} {torch_xla._XLAC._get_xla_sharding_spec(param)}") + + for i, block in enumerate(model.model.layers): + xs.apply_backward_optimization_barrier(model.model.layers[i]) + logger.info("Applying gradient checkpointing") + if model.config.use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + model.config.use_cache = False + + for i, block in enumerate(model.model.layers): + model.model.layers[i] = checkpoint_module(block) + + return model + + +def print_param_sharding(model): + for name, param in model.named_parameters(): + logger.debug( + f"{name}: {param.shape} {param.dtype} {torch_xla._XLAC._get_xla_sharding_spec(param)}" + ) + + +def setup_xla(config): + if config.local_compile_cache_dir: + xr.initialize_cache(config.local_compile_cache_dir) + if config.full_precision: + import jax + + assert config.model.dtype == "float32", "model dtype need to be float32" + torch_xla._XLAC._xla_set_use_full_mat_mul_precision( + use_full_mat_mul_precision=True + ) + jax.config.update("jax_default_matmul_precision", "highest") + + num_devices = xr.global_runtime_device_count() + mesh_shape = (num_devices // config.tensor_parallelism, config.tensor_parallelism) + device_ids = np.array(range(num_devices)) + mesh = xs.Mesh(device_ids, mesh_shape, axis_names=("fsdp", "tensor")) + xs.set_global_mesh(mesh) + + +def fmt_size(num_bytes: int) -> str: + assert num_bytes > 0 + for unit in ["B", "KiB", "MiB", "GiB"]: + if num_bytes < 1024.0: + break + num_bytes /= 
1024.0 + return f"{num_bytes:.2f} {unit}" + + +def get_cpu_memory() -> str: + """print out cpu/tpu memory.""" + cpu_bytes = Process().memory_info().rss + return fmt_size(cpu_bytes) + + +def setup_model_optimizer(config): + dtype = getattr(torch, config.model.dtype) + + logger.debug(f"cpu memory usage: {get_cpu_memory()}") + + logger.info("loading model") + if config.model.config_path: + model_config = AutoConfig.from_pretrained(config.model.config_path) + model_config.static = True + model_config.flash_attention = config.model.flash_attention + model_config.gmm = False + model_config.gmm_stack = False + model_config.capacity_factor = config.model.capacity_factor + model_config.output_router_logits = True + with torch.device("meta"): + model = ( + AutoModelForCausalLM.from_config(model_config) + .to_empty(device=xm.xla_device()) + .to(torch.bfloat16) + ) + else: + model = AutoModelForCausalLM.from_pretrained( + config.model.name_or_path, + cache_dir=config.cache_local_dir, + low_cpu_mem_usage=True, + torch_dtype=dtype, + ) + + if model.config.architectures == ["MixtralForCausalLM"]: + for layer in model.model.layers: + layer.self_attn.rotary_emb._set_buffer(device=xm.xla_device()) + logger.info("model loaded") + model = prepare_model(model, config) + model = model.to(dtype) + logger.info("model prepared") + gc.collect() + xm.mark_step() + logger.debug(f"cpu memory usage: {get_cpu_memory()}") + + print_param_sharding(model) + + if config.checkpoint_manager_path: + torch.distributed.init_process_group("gloo", init_method="xla://") + logger.info(f"checkpoint found from {config.checkpoint_manager_path=}") + + ckpt_manager = CheckpointManager( + path=config.checkpoint_manager_path, + save_interval=float("inf"), + max_to_keep=0, + ) + + state_dict = { + "model": model.state_dict(), + } + ckpt_manager.restore(0, state_dict) + model.load_state_dict(state_dict["model"]) + del state_dict + xm.mark_step() + logger.info("checkpoint loaded") + else: + if 
config.model.config_path:
+            model.apply(model._init_weights)
+
+    # Exclude bias and layernorm weights from weight decay.
+    no_decay = ["bias", "layer_norm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if not any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": config.weight_decay,
+        },
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": 0.0,
+        },
+    ]
+    if config.optimizer == "ADAMW_TORCH_XLA":
+        from torch_xla.amp.syncfree import AdamW
+
+        optimizer = AdamW(
+            optimizer_grouped_parameters,
+            lr=config.lr,
+        )
+    else:
+        optimizer = getattr(torch.optim, config.optimizer)(
+            optimizer_grouped_parameters, lr=config.lr
+        )
+
+    # initialize optimizer states and scheduler
+    optimizer = prime_optimizer(optimizer)
+    sched_config = OmegaConf.to_container(config.sched, resolve=True)
+    scheduler_name = sched_config.pop("name")
+    if scheduler_name == "WarmupHoldPolicy":
+        scheduler = WarmupHoldPolicy(optimizer=optimizer, **sched_config)
+    elif scheduler_name == "CosineAnnealing":
+        assert (
+            config.lr >= sched_config["min_lr"]
+        ), f"{config.lr=} should be at least {config.sched.min_lr=}"
+        scheduler = CosineAnnealing(optimizer=optimizer, **sched_config)
+    else:
+        raise ValueError(
+            f"{config.sched.name=} should be one of the valid schedulers (WarmupHoldPolicy, CosineAnnealing)"
+        )
+
+    return model, optimizer, scheduler
+
+
+def get_global_batch_size(per_device_batch_size):
+    num_devices = xr.global_runtime_device_count()
+    global_batch_size = int(per_device_batch_size * num_devices)
+    return global_batch_size
+
+
+def flatten(dictionary, parent_key="", separator="_"):
+    items = []
+    for key, value in dictionary.items():
+        new_key = parent_key + separator + key if parent_key else key
+        if isinstance(value, dict):
+            items.extend(flatten(value, new_key, separator=separator).items())
+        else:
+            items.append((new_key, value))
+    return dict(items)
+
+
+class 
TensorBoardCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).
+
+    Args:
+        config:
+            The experiment configuration; a `SummaryWriter` is created under
+            `config.run_dir/tensorboard`.
+    """
+
+    def __init__(self, config):
+        if xr.process_index() == 0:
+            exp_config = {}
+            for k, v in flatten(OmegaConf.to_container(config)).items():
+                # Keep JSON-serializable values as-is; stringify everything
+                # else (including tensors) so json.dumps below cannot fail.
+                if isinstance(v, (str, int, float, bool)):
+                    exp_config[k] = v
+                else:
+                    exp_config[k] = str(v)
+            self.tb_writer = SummaryWriter(
+                log_dir=os.path.join(config.run_dir, "tensorboard")
+            )
+            self.tb_writer.add_text("model_config", json.dumps(exp_config, indent=2))
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if xr.process_index() == 0:
+            for k, v in logs.items():
+                if isinstance(v, (int, float)):
+                    self.tb_writer.add_scalar(k, v, state.global_step)
+                else:
+                    logger.warning(
+                        "Trainer is attempting to log a value of "
+                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
+                        "This invocation of Tensorboard's writer.add_scalar() "
+                        "is incorrect so we dropped this attribute."
+                    )
+            self.tb_writer.flush()
+
+    def on_train_end(self, args, state, control, **kwargs):
+        if xr.process_index() == 0 and self.tb_writer:
+            self.tb_writer.close()
+            self.tb_writer = None
diff --git a/mixture_of_experts_pretraining/run_clm.py b/mixture_of_experts_pretraining/run_clm.py
new file mode 100644
index 000000000..3920110a2
--- /dev/null
+++ b/mixture_of_experts_pretraining/run_clm.py
@@ -0,0 +1,239 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os + +import hydra +import numpy as np +import torch +from clm_datasets import get_dataset_cuda, get_datasets, process_datasets +from file_utils import get_file +from mlperf_logging_utils import ( + ClmLogger, + MetricsLogger, + MLPerfCallback, + MLPerfLightningCallback, +) +from omegaconf import DictConfig, OmegaConf +from transformers import AutoTokenizer, default_data_collator, logging, set_seed + +USE_CUDA = torch.cuda.is_available() # os.environ.get('USE_CUDA', False) + +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) + +if not USE_CUDA: + from model_utils_tpu import ( + TensorBoardCallback, + get_global_batch_size, + setup_model_optimizer, + setup_xla, + ) + from trainer_utils_tpu import Trainer + + OmegaConf.register_new_resolver( + "get_global_batch_size", + lambda per_device_batch_size: get_global_batch_size(per_device_batch_size), + ) + OmegaConf.register_new_resolver( + "path_join", lambda output_dir, exp_name: os.path.join(output_dir, exp_name) + ) +else: + import torch.multiprocessing as mp + from model_utils_gpu import setup_distributed, setup_model_and_trainer + from nemo import lightning as nl + from nemo.collections import llm + + OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True) + OmegaConf.register_new_resolver( + "path_join", lambda output_dir, exp_name: os.path.join(output_dir, exp_name) + ) + OmegaConf.register_new_resolver( + "get_global_batch_size", + lambda per_device_batch_size: per_device_batch_size, + ) + + mp.set_start_method("spawn", force=True) + + 
+@hydra.main(version_base=None, config_path="config", config_name="config") +def main(config: DictConfig): + logger = logging.get_logger(__name__) + + OmegaConf.resolve(config) + set_seed(config.seed) + logger.info("\n\n************** Experiment configuration ***********") + logger.info(OmegaConf.to_yaml(config)) + if USE_CUDA: + setup_distributed(config) + + if config.eval_frequency == -1: + config.eval_frequency = int( + np.ceil(24576 * 2048 / config.max_length / config.global_train_batch_size) + ) + logger.info(f"{config.eval_frequency=}") + + clmlogger = ClmLogger(config, filename="output.txt") + + if not USE_CUDA: + tokenizer = AutoTokenizer.from_pretrained( + config.model.name_or_path, + add_eos_token=False, + add_bos_token=False, + use_fast=False, + ) + + config_path = os.path.join(config.run_dir, "config.yaml") + with get_file(config_path, "w") as f: + OmegaConf.save(config, f) + + logger.info(f"log tensorboard to {os.path.join(config.run_dir, 'tensorboard')}") + setup_xla(config) + model, optimizer, scheduler = setup_model_optimizer(config) + + if tokenizer.vocab_size != model.config.vocab_size: + logger.warning( + f"Found mismatch between {tokenizer.vocab_size=} and {model.config.vocab_size}" + ) + raw_datasets = get_datasets(config) + datasets = process_datasets(raw_datasets, tokenizer, config) + train_dataset, eval_dataset = datasets["train"], datasets["validation"] + + # Initialize our Trainer + trainer = Trainer( + model=model, + config=config, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizer=optimizer, + scheduler=scheduler, + # Data collator will default to DataCollatorWithPadding, so we change it. 
+ data_collator=default_data_collator, + callbacks=[MLPerfCallback(config), TensorBoardCallback(config)], + ) + trainer.train() + + else: + if "adam" in config.optimizer.lower(): + optimizer_name = "adam" + else: + raise ValueError("Unsupported optimizer for GPU run") + + data_parallel_size = torch.distributed.get_world_size() // ( + config.tensor_parallelism + * config.pipeline_parallelism + * config.context_parallelism + ) + + config.global_train_batch_size = int( + config.per_device_train_batch_size + * config.gradient_accumulation_steps + * data_parallel_size + ) + + config.global_eval_batch_size = config.global_train_batch_size + number_of_nodes = ( + torch.distributed.get_world_size() // torch.cuda.device_count() + ) + + metrics_logger = MetricsLogger( + clmlogger, + number_of_nodes, + config.global_train_batch_size, + config.lr, + config.max_length, + ) + + callbacks = [ + MLPerfLightningCallback( + clmlogger, + config.global_train_batch_size, + config.max_length, + ) + ] + + if ( + "capacity_factor" in config.model + and config.model.capacity_factor is not None + and config.model.capacity_factor > 0 + ): + from nemo.lightning.pytorch.callbacks.moe_token_drop import ( + MegatronTokenDropCallback, + ) + + callbacks.append( + MegatronTokenDropCallback( + moe_expert_capacity_factor=config.model.capacity_factor + ) + ) + + number_of_nodes = max( + 1, torch.distributed.get_world_size() // torch.cuda.device_count() + ) + + model, trainer, optimizer, resume = setup_model_and_trainer( + model_name_or_path=config.model.name_or_path, + input_sequence_length=config.max_length, + global_batch_size=config.global_train_batch_size, + nodes=number_of_nodes, + tp_size=config.tensor_parallelism, + pp_size=config.pipeline_parallelism, + vpp_size=None, # config.virtual_pipeline_parallelism, + cp_size=config.context_parallelism, + learning_rate=config.lr, + weight_decay=config.weight_decay, + optimizer_name=optimizer_name, + tokenizer_name_or_path=config.model.name_or_path, + 
scheduler=config.sched, + max_grad_norm=config.max_grad_norm, + eval_frequency=config.eval_frequency, + log_frequency=config.log_frequency, + max_steps=config.max_steps, + logger=metrics_logger, + callbacks=callbacks, + ) + ckpt = nl.ModelCheckpoint( + save_last=False, + save_top_k=False, + every_n_train_steps=0, + always_save_context=False, + save_context_on_train_end=False, + ) + + nemo_logger = nl.NeMoLogger( + ckpt=ckpt, + name="mixtral-reference", + tensorboard=None, + wandb=None, + log_dir="/results", + ) + + dataset = get_dataset_cuda(config) + + llm.train( + model=model, + data=dataset, + trainer=trainer, + tokenizer="data", + optim=optimizer, + log=nemo_logger, + # log=None, + resume=resume, + ) + + +if __name__ == "__main__": + main() diff --git a/mixture_of_experts_pretraining/scripts/gpu/checkpoint_download.py b/mixture_of_experts_pretraining/scripts/gpu/checkpoint_download.py new file mode 100644 index 000000000..e20ac0ac2 --- /dev/null +++ b/mixture_of_experts_pretraining/scripts/gpu/checkpoint_download.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import argparse +import enum +import os +import pathlib + +from huggingface_hub import snapshot_download +from nemo.collections.llm.gpt.model.mixtral import HFMixtralImporter + + +class Model(enum.Enum): + MIXTRAL_8x7B_BASE = "mistralai/Mixtral-8x7B-v0.1" + MIXTRAL_8x22B_BASE = "mistralai/Mixtral-8x22B-v0.1" + + def __str__(self): + return self.value + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_id", + type=str, + default=Model.MIXTRAL_8x7B_BASE.value, + choices=list(x.value for x in Model) + ["path"], + ) + + parser.add_argument( + "--output_dir", + type=pathlib.Path, + required=True, + ) + + parser.add_argument( + "--hf_token", + type=str, + default=os.environ.get("HF_TOKEN", ""), + ) + + return parser.parse_args() + + +def main(args: argparse.Namespace) -> None: + if args.checkpoint_id in list(x.value for x in Model): + snapshot_download( + repo_id=str(args.checkpoint_id), + repo_type="model", + local_dir=args.output_dir / "hf", + token=args.hf_token, + ) + importer = HFMixtralImporter(args.output_dir / "hf") + else: + importer = HFMixtralImporter(args.checkpoint_id) + importer.apply(args.output_dir / "nemo") + + +if __name__ == "__main__": + args = parse_arguments() + main(args) diff --git a/mixture_of_experts_pretraining/scripts/gpu/dataset_preprocessing.py b/mixture_of_experts_pretraining/scripts/gpu/dataset_preprocessing.py new file mode 100644 index 000000000..da697df76 --- /dev/null +++ b/mixture_of_experts_pretraining/scripts/gpu/dataset_preprocessing.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os +import pathlib +import subprocess + +from huggingface_hub import snapshot_download + +training_files = [f"en/c4-train.{i:05d}-of-01024.json.gz" for i in range(1024)] + +file_mapping_train = [ + (f"c4-train.en_{i}.json.gz", f"c4_train.en_{i}") for i in range(6, 8) +] + + +def download_dataset( + output_path: pathlib.Path, + repo_id: str = "allenai/c4", +) -> None: + snapshot_download( + repo_id=repo_id, + repo_type="dataset", + local_dir=output_path, + allow_patterns="en/*.json.gz", + ) + + +def merge_into_consolidated( + source_directory: pathlib.Path, + output_directory: pathlib.Path, +): + def merge_files(output_path: pathlib.Path, input_paths: list[pathlib.Path]): + with open(output_path, "wb") as output_file: + for input_path in input_paths: + with open(input_path, "rb") as input_file: + file_content = input_file.read() + output_file.write(file_content) + + for i in range(6, 8): + file_chunks = [ + source_directory / training_files[j] for j in range(i * 128, (i + 1) * 128) + ] + merge_files(output_directory / f"c4-train.en_{i}.json.gz", file_chunks) + + +def run_conversion( + input_file: pathlib.Path, + output_file: pathlib.Path, + tokenizer_path: pathlib.Path, +): + print(f"Converting {input_file} into {output_file} using {tokenizer_path}") + + with subprocess.Popen( + [ + "python", + "/opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py", + "--input", + str(input_file), + "--output", + str(output_file), + "--tokenizer-library", + "huggingface", + "--tokenizer-type", + str(tokenizer_path), + 
"--dataset-impl",
+            "mmap",
+            "--workers",
+            "8",
+        ],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+    ) as process:
+        for line in process.stdout:
+            print(line.strip())
+
+    # returncode is only set once the process has exited (the context manager
+    # waits on exit), so report it after leaving the `with` block.
+    print(f"Exited with code={process.returncode}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input-tokenizer",
+        type=pathlib.Path,
+        required=True,
+        help="Path for the stored tokenizer",
+    )
+    parser.add_argument(
+        "--workdir",
+        type=pathlib.Path,
+        required=True,
+        help="Working directory for the script",
+    )
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    workdir_path = args.workdir
+    workdir_path.mkdir(exist_ok=True)
+
+    os.makedirs(workdir_path / "raw", exist_ok=True)
+    os.makedirs(workdir_path / "merged", exist_ok=True)
+    os.makedirs(workdir_path / "output", exist_ok=True)
+
+    download_dataset(workdir_path / "raw")
+    merge_into_consolidated(workdir_path / "raw", workdir_path / "merged")
+
+    for source, target in file_mapping_train:
+        run_conversion(
+            workdir_path / "merged" / source,
+            workdir_path / "output" / target,
+            args.input_tokenizer,
+        )
+    run_conversion(
+        workdir_path / "raw" / "en/c4-validation_24567exp.json",
+        workdir_path / "output" / "c4-validation-small.en",
+        args.input_tokenizer,
+    )
diff --git a/mixture_of_experts_pretraining/scripts/gpu/run.sub b/mixture_of_experts_pretraining/scripts/gpu/run.sub
new file mode 100644
index 000000000..d0bf58f50
--- /dev/null
+++ b/mixture_of_experts_pretraining/scripts/gpu/run.sub
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+: "${CONT:?Base Container image is not set, please specify CONT envvar}"
+: "${DATA:?Data directory is not set, please specify DATA envvar}"
+: "${CKPT:?Checkpoint directory is not set, please specify CKPT envvar}"
+: "${NODES:?Number of nodes is not set, please specify NODES envvar}"
+: "${OUTPUT:?Output directory is not set, please specify OUTPUT envvar}"
+
+CONT_MOUNTS="${DATA}:/app/dataset:ro,${CKPT}:/app/checkpoints:ro,${OUTPUT}:/results" + +: "${MASTER_PORT:=29500}" +export MASTER_PORT +export MASTER_ADDR="$(scontrol show hostnames "${SLURM_JOB_NODELIST-}" | head -n1)" + +srun -l --kill-on-bad-exit=0 --mpi="${SLURM_MPI_TYPE:-pmix}" \ + --ntasks="$(( NODES * ${GPUS:-8} ))" \ + --ntasks-per-node="${GPUS:-8}" \ + --container-image="${CONT}" \ + --container-mounts="${CONT_MOUNTS}" \ + --container-env=MASTER_PORT,MASTER_ADDR \ + slurm2pytorch python /app/training/run_clm.py output_dir=/results \ + dataset.train_dataset_path=/app/dataset dataset.eval_dataset_path=/app/dataset \ + diff --git a/mixture_of_experts_pretraining/scripts/tpu/distributed_checkpoint_saving.py b/mixture_of_experts_pretraining/scripts/tpu/distributed_checkpoint_saving.py new file mode 100644 index 000000000..25591c53d --- /dev/null +++ b/mixture_of_experts_pretraining/scripts/tpu/distributed_checkpoint_saving.py @@ -0,0 +1,106 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Load and save huggingface model into torch_xla distributed checkpoint + +Test cmd for gpt2: + python -m clm.scripts.tpu.distributed_checkpoint_saving model.name_or_path=gpt2 checkpoint_manager_path=/tmp/save/ + +True cmd for mixtral-8x22b: + export LOCAL_DIR=/tmp/save + python -m clm.scripts.tpu.distributed_checkpoint_saving model.name_or_path=mistralai/Mixtral-8x22B-v0.1 checkpoint_manager_path=$LOCAL_DIR + gsutil -m cp -r $LOCAL_DIR gs://some_bucket/path/to/dir +""" +import torch +import os +import hydra +from omegaconf import OmegaConf, DictConfig +from transformers import AutoTokenizer, AutoModelForCausalLM + +from transformers import logging + +from transformers import set_seed +import torch_xla.core.xla_model as xm + +from torch_xla.experimental.distributed_checkpoint import CheckpointManager +from ...model_utils_tpu import ( + setup_xla, + prepare_model, + get_global_batch_size, +) + +OmegaConf.register_new_resolver( + "path_join", lambda output_dir, exp_name: os.path.join(output_dir, exp_name) +) +OmegaConf.register_new_resolver( + "get_global_batch_size", + lambda per_device_batch_size: get_global_batch_size(per_device_batch_size), +) + +logger = logging.get_logger(__name__) + + +@hydra.main(version_base=None, config_path="../../config", config_name="config") +def main(config: DictConfig): + OmegaConf.resolve(config) + set_seed(config.seed) + + logger.info("\n\n************** Experiment configuration ***********") + logger.info(OmegaConf.to_yaml(config)) + + setup_xla(config) + + tokenizer = AutoTokenizer.from_pretrained( + config.model.name_or_path, + add_eos_token=False, + add_bos_token=False, + use_fast=False, + ) + logger.info("model loaded") + dtype = getattr(torch, config.model.dtype) + + model = AutoModelForCausalLM.from_pretrained( + config.model.name_or_path, + cache_dir=config.cache_local_dir, + torch_dtype=dtype, + ) + model = prepare_model(model, config) + model = model.to(dtype) + + torch.distributed.init_process_group("gloo", 
init_method="xla://")
+    if config.checkpoint_manager_path:
+        ckpt_manager = CheckpointManager(
+            path=config.checkpoint_manager_path,
+            save_interval=1,
+            max_to_keep=1,
+        )
+
+        state_dict = {
+            "model": model.state_dict(),
+        }
+        logger.info("saved model.state_dict:")
+        for k, v in state_dict["model"].items():
+            logger.info(f"{k}: {v.dtype} {v.mean()}")
+
+        ckpt_manager.save(0, state_dict)
+    else:
+        raise ValueError(f"need a valid {config.checkpoint_manager_path=}")
+
+    logger.info("checkpoint saving finished.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mixture_of_experts_pretraining/trainer_utils_tpu.py b/mixture_of_experts_pretraining/trainer_utils_tpu.py
new file mode 100644
index 000000000..a549a1256
--- /dev/null
+++ b/mixture_of_experts_pretraining/trainer_utils_tpu.py
@@ -0,0 +1,320 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import os
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.trainer_utils import EvalLoopOutput
+from transformers.trainer_pt_utils import find_batch_size
+from transformers import default_data_collator
+from torch.utils.data import DataLoader
+from typing import Dict, List
+from tqdm.auto import tqdm
+
+from transformers import logging, TrainerState, TrainerControl
+import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.spmd as xs
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla
+import numpy as np
+import datetime
+import torch_xla.debug.profiler as xp
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+
+logger = logging.get_logger(__name__)
+PROFILE_PORT = 9012
+server = xp.start_server(PROFILE_PORT)
+logger.info(f"Profiling server started: {server=}")
+
+
+def calculate_tflops_training_per_device(model, config):
+    n_params = sum({name: p.numel() for name, p in model.named_parameters()}.values())
+    logger.info(f"Total size={n_params/1e9:.3f}B params")
+    n_active_params = sum(
+        {
+            name: p.numel() for name, p in model.named_parameters() if p.requires_grad
+        }.values()
+    )
+    logger.info(f"Active size={n_active_params/1e9:.3f}B params")
+
+    # effective params: with MoE, only num_experts_per_tok of num_local_experts
+    # experts are active per token
+    if hasattr(model.config, "num_experts_per_tok") and hasattr(
+        model.config, "num_local_experts"
+    ):
+        effective_n_params = (
+            n_params * model.config.num_experts_per_tok / model.config.num_local_experts
+        )
+    else:
+        effective_n_params = n_params
+
+    # estimated tflops i.e. 
6 * B * P, where B means number of tokens in batch
+    tflops_training_per_device = (
+        6
+        * config.per_device_train_batch_size
+        * config.max_length
+        * effective_n_params
+        / 1e12
+    )
+
+    logger.info(
+        f"Estimated {tflops_training_per_device=} with dtype as {config.model.dtype}"
+    )
+    return tflops_training_per_device
+
+
+class Trainer:
+    def __init__(
+        self,
+        config,
+        model,
+        tokenizer,
+        train_dataset,
+        eval_dataset,
+        optimizer,
+        scheduler,
+        data_collator=default_data_collator,
+        callbacks: List = None,
+    ):
+        self.config = config
+        self.model = model
+        mesh = xs.get_global_mesh()
+
+        assert (
+            config.global_train_batch_size % config.gradient_accumulation_steps == 0
+        ), f"{config.global_train_batch_size=} is not divisible by {config.gradient_accumulation_steps=}"
+        self.global_train_micro_batch_size = (
+            config.global_train_batch_size // config.gradient_accumulation_steps
+        )
+        self.train_dataloader = pl.MpDeviceLoader(
+            DataLoader(
+                train_dataset,
+                collate_fn=data_collator,
+                batch_size=self.global_train_micro_batch_size,
+            ),
+            torch_xla.device(),
+            input_sharding=xs.ShardingSpec(mesh, ("fsdp", None)),
+        )
+
+        self.eval_dataloader = pl.MpDeviceLoader(
+            DataLoader(
+                eval_dataset,
+                collate_fn=data_collator,
+                batch_size=config.global_eval_batch_size,
+            ),
+            torch_xla.device(),
+            input_sharding=xs.ShardingSpec(mesh, ("fsdp", None)),
+        )
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.callbacks = callbacks
+
+        self.state = TrainerState()
+        self.state.global_step = 0
+        self.state.max_steps = config.max_steps
+        self.state.eval_steps = config.eval_frequency
+        self.control = TrainerControl()
+        self.per_device_tflops = calculate_tflops_training_per_device(model, config)
+
+    def compute_loss(self, batch, add_load_balancing_loss: bool = False):
+        labels = batch.pop("labels")
+        outputs = self.model(**batch)
+        logits = outputs.logits
+        # Flatten the tokens
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 
1:].contiguous() + loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id) + # flatten + shift_logits = shift_logits.view(-1, logits.shape[-1]) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + num_tokens = (labels != self.config.pad_token_id).sum() + loss_weight = (shift_labels != self.config.pad_token_id).sum() + metrics = { + "num_tokens": num_tokens, + "loss_weight": loss_weight, + } + if add_load_balancing_loss: + assert self.model.training + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.model.num_experts, + self.model.num_experts_per_tok, + attention_mask=batch["attention_mask"], + ) + loss += self.model.router_aux_loss_coef * aux_loss + return loss, metrics + + def eval_loop(self): + self.model.eval() + group_eval_loss_sum: List = [] + group_eval_loss_weight: List = [] + group_eval_num_tokens: List = [] + for eval_batch in self.eval_dataloader: + with torch.no_grad(): + eval_loss_mean, eval_metrics = self.compute_loss( + eval_batch, add_load_balancing_loss=False + ) + eval_num_tokens = eval_metrics["num_tokens"] + eval_loss_weight = eval_metrics["loss_weight"] + eval_loss_sum = eval_loss_mean * eval_loss_weight + group_eval_loss_sum.append(eval_loss_sum) + group_eval_loss_weight.append(eval_loss_weight) + group_eval_num_tokens.append(eval_num_tokens) + + total_eval_loss_sum = sum(group_eval_loss_sum) + total_eval_loss_weight = sum(group_eval_loss_weight) + total_eval_num_tokens = sum(group_eval_num_tokens) + group_eval_metrics = { + "eval/loss": (total_eval_loss_sum / total_eval_loss_weight), + "eval/num_tokens": total_eval_num_tokens, + "eval/total_weights": total_eval_loss_weight, + } + return group_eval_metrics + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. 
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        logs = {
+            k: v.cpu().item() if isinstance(v, torch.Tensor) else v
+            for k, v in logs.items()
+        }
+        logs = {**logs, **{"step": self.state.global_step}}
+        logger.info(f"{logs}")
+        for callback in self.callbacks:
+            callback.on_log(self.config, self.state, self.control, logs=logs)
+        self.state.log_history.append(logs)
+
+    def update_step(self):
+        self.state.global_step += 1
+
+    def train(self):
+        # Train!
+        for callback in self.callbacks:
+            callback.on_train_begin(self.config, self.state, self.control)
+
+        logger.info("***** Running training *****")
+        logger.info(" Num Epochs = 1")
+        logger.info(
+            f" Instantaneous batch size per device = {self.config.per_device_train_batch_size}"
+        )
+        logger.info(
+            f" Total train batch size (w. parallel, distributed & accumulation) = {self.config.global_train_batch_size}"
+        )
+        logger.info(
+            f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}"
+        )
+        logger.info(f" Total optimization steps = {self.config.max_steps}")
+        # Only show the progress bar once on each machine.
+        progress_bar = tqdm(
+            range(self.config.max_steps), disable=xr.process_index() > 0
+        )
+
+        train_loss_list = []
+        train_num_tokens_list = []
+        eval_first = self.config.do_first_eval
+        last_step_completion = datetime.datetime.now()
+        for batch_idx, batch in enumerate(self.train_dataloader):
+            if eval_first:
+                eval_metrics = self.eval_loop()
+                xm.add_step_closure(self.log, args=(eval_metrics,))
+                eval_first = False
+
+            if (
+                self.control.should_training_stop
+                or self.state.global_step >= self.config.max_steps
+            ):
+                xm.mark_step()
+                break
+
+            self.model.train()
+            train_loss_step, train_metrics_step = self.compute_loss(
+                batch, add_load_balancing_loss=True
+            )
+            train_num_tokens_step = train_metrics_step["num_tokens"]
+
+            train_loss_step /= self.config.gradient_accumulation_steps
+            train_loss_step.backward()
+            train_loss_list.append(train_loss_step)
+            train_num_tokens_list.append(train_num_tokens_step)
+            if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
+                # the global-step update is wrapped in a step closure below so
+                # that logging stays in sync with XLA's lazy execution
+                logs: Dict[str, float] = {}
+                if self.config.max_grad_norm > 0:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(
+                        self.model.parameters(), self.config.max_grad_norm
+                    )
+                    logs["train/grad_norm"] = grad_norm.detach()
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self.scheduler.step()
+
+                train_loss = sum(train_loss_list)
+                train_num_tokens = sum(train_num_tokens_list)
+                logs["train/loss"] = train_loss.detach()
+                logs["train/num_tokens"] = train_num_tokens.detach()
+                logs["train/lr"] = self.scheduler.get_last_lr()[0]
+                if (self.state.global_step + 1) % self.state.eval_steps == 0:
+                    eval_metrics = self.eval_loop()
+                    logs.update(eval_metrics)
+
+                # add tflops per second
+                new_time = datetime.datetime.now()
+                step_time_delta = (new_time - last_step_completion).total_seconds()
+                logs["perf/step_time_seconds"] = step_time_delta
+                logs["perf/per_device_tflops"] = self.per_device_tflops
+                logs["perf/per_device_tflops_per_sec"] = 
( + self.per_device_tflops / step_time_delta + ) + logs["perf/per_device_tokens_per_sec"] = ( + logs["train/num_tokens"] / step_time_delta + ) + last_step_completion = new_time + + xm.add_step_closure(self.update_step) + if (self.state.global_step + 1) % self.config.log_frequency == 0: + xm.add_step_closure(self.log, args=(logs,)) + for callback in self.callbacks: + xm.add_step_closure( + callback.on_step_end, + args=(self.config, self.state, self.control), + ) + + train_loss_list = [] + train_num_tokens_list = [] + progress_bar.update(1) + + if self.state.global_step == self.config.xla_profile_step: + xm.wait_device_ops() + duration_ms = 20000 + xp.trace_detached( + f"localhost:{PROFILE_PORT}", + os.path.join(self.config.run_dir, "profile"), + duration_ms=duration_ms, + ) + + logger.info("train finished.")
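For reference, the throughput figures logged above (`perf/per_device_tflops`, `perf/per_device_tflops_per_sec`) come from the `6 * B * P` approximation in `calculate_tflops_training_per_device`. A minimal standalone sketch of that estimate; the batch size, sequence length, and parameter count below are illustrative placeholders, not the benchmark configuration:

```python
def estimate_tflops_per_device(
    per_device_batch: int, seq_len: int, effective_n_params: float
) -> float:
    """Approximate training cost as 6 * tokens * parameters, in TFLOP."""
    tokens = per_device_batch * seq_len
    return 6 * tokens * effective_n_params / 1e12

# Illustrative numbers only: one 32,768-token sequence per device and a
# hypothetical 140e9 effective parameters.
print(estimate_tflops_per_device(1, 32768, 140e9))
```

Dividing this per-step figure by the measured `perf/step_time_seconds` yields `perf/per_device_tflops_per_sec`, as done in the training loop above.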