Skip to content

ROCm DLM Performance Evaluations #323

ROCm DLM Performance Evaluations

ROCm DLM Performance Evaluations #323

Workflow file for this run

# Workflow identity and triggers: runs on a nightly schedule and can also be
# started manually from the Actions tab.
name: ROCm DLM Performance Evaluations
on:
  schedule:
    # Nightly at 3:00 AM UTC
    - cron: '0 3 * * *'
  workflow_dispatch:
jobs:
  # Builds the plugin wheels and Docker images, runs the MaxText training
  # configs inside a container, then publishes logs, a summary and the image.
  build-and-test-jax-perf:
    runs-on: "linux-x86-64-8gpu-amd"
    # Least-privilege token: read the repository, push images to ghcr.io.
    permissions:
      contents: read
      packages: write
    strategy:
      matrix:
        python-version: ["3.12"]
        rocm-version: ["7.1.1"]
    # Consumed by the upload-summary-to-db job below.
    # NOTE(review): job outputs are not per-matrix-entry; if the matrix grows
    # beyond one combination these (and the fixed artifact name) need rework.
    outputs:
      python_version: ${{ steps.meta.outputs.python }}
      rocm_version: ${{ steps.meta.outputs.rocm }}
    env:
      # Folded scalar: evaluates to the same single-line expression result.
      WORKSPACE_DIR: >-
        ${{ format('jax_rocm_perf_{0}_{1}_{2}',
        github.run_id, github.run_number, github.run_attempt) }}
      PYTHON_VERSION: ${{ matrix.python-version }}
      ROCM_VERSION: ${{ matrix.rocm-version }}
    steps:
      - name: Get job metadata
        id: meta
        # PYTHON_VERSION / ROCM_VERSION are the job-level env copies of the
        # matrix values, so the outputs are unchanged.
        run: |
          echo "python=${PYTHON_VERSION}" >> "$GITHUB_OUTPUT"
          echo "rocm=${ROCM_VERSION}" >> "$GITHUB_OUTPUT"

      - name: Clean up old workdirs
        # Files left behind by earlier container runs may not be owned by the
        # runner user, so ownership is reset from inside a container first.
        # dotglob is enabled on the host as well so hidden leftovers are
        # removed too (the chown already covered them).
        run: |
          ls -la
          docker run --rm -v "$(pwd):/rocm-jax" ubuntu \
            bash -c "shopt -s dotglob; chown -R $UID /rocm-jax/* || true"
          shopt -s dotglob
          rm -rf -- * || true
          ls -la

      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi || true

      - name: Checkout source
        uses: actions/checkout@v4

      - name: Checkout JAX repo for jaxlib build
        uses: actions/checkout@v4
        with:
          repository: ROCm/jax
          # Keep in sync with --jax-version in the upload-summary-to-db job.
          ref: rocm-jaxlib-v0.8.2
          path: jax

      - name: Build plugin wheels
        run: |
          python3 build/ci_build \
            --compiler clang \
            --python-versions "$PYTHON_VERSION" \
            --rocm-version "$ROCM_VERSION" \
            --jax-source-dir="./jax" \
            dist_wheels

      - name: Copy wheels for Docker build context
        run: |
          mkdir -p wheelhouse
          cp ./jax_rocm_plugin/wheelhouse/*.whl ./wheelhouse/

      - name: Build JAX docker image
        run: |
          python3 build/ci_build \
            --rocm-version "$ROCM_VERSION" \
            build_dockers \
            --filter ubu24

      - name: Build Docker image for MaxText
        # ${ROCM_VERSION//.} strips the dots (7.1.1 -> 711) for the image name.
        # GITHUB_SHA / GITHUB_RUN_ID are the runner-provided equivalents of
        # github.sha / github.run_id.
        run: |
          IMAGE="ghcr.io/rocm/maxtext-jax-rocm${ROCM_VERSION//.}"
          docker build \
            --build-arg BASE_IMAGE="jax-ubu24.rocm${ROCM_VERSION//.}" \
            --build-arg MAXTEXT_BRANCH=rv_jax \
            -f ci/Dockerfile.maxtext \
            -t "${IMAGE}:nightly" \
            -t "${IMAGE}:${GITHUB_SHA}" \
            -t "${IMAGE}:run${GITHUB_RUN_ID}" \
            .

      - name: Launch container
        # Remove any container left over from an aborted earlier run; the
        # fixed name would otherwise make `docker run --name` fail.
        run: |
          docker rm -f maxtext_container 2>/dev/null || true
          docker run -d --name maxtext_container \
            --network=host \
            --device=/dev/kfd \
            --device=/dev/dri \
            --ipc=host \
            --shm-size=64G \
            --group-add=video \
            --cap-add=SYS_PTRACE \
            --security-opt seccomp=unconfined \
            -w /maxtext \
            "ghcr.io/rocm/maxtext-jax-rocm${ROCM_VERSION//.}:run${GITHUB_RUN_ID}" \
            tail -f /dev/null

      - name: Run MaxText training and save logs
        # The default run shell is `bash -e` without pipefail, so piping into
        # `tee` would hide a failing training run. pipefail is enabled, every
        # model is still attempted, and the step fails at the end if any of
        # them returned a non-zero status.
        run: |
          set -o pipefail
          failed=()
          for config in \
            MaxText/configs/models/gpu/llama2_7b_rocm.yml \
            MaxText/configs/models/gpu/gemma_2b_rocm.yml \
            MaxText/configs/models/gpu/gpt3_6b_rocm.yml \
            MaxText/configs/models/gpu/mixtral_8x1b_rocm.yml; do
            model_name=$(basename "$config" _rocm.yml)
            echo "Running $model_name"
            train_cmd="python3 -m MaxText.train $config"
            if [[ "$model_name" == "mixtral_8x1b" ]]; then
              # Only this config runs with the raised client memory fraction.
              train_cmd="export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 && $train_cmd"
            fi
            if ! docker exec maxtext_container bash -c "$train_cmd" \
                | tee "logs_${model_name}.log"; then
              echo "::error::MaxText training failed for ${model_name}"
              failed+=("$model_name")
            fi
          done
          if (( ${#failed[@]} > 0 )); then
            echo "Failed models: ${failed[*]}"
            exit 1
          fi

      - name: Analyze logs to compute median step time
        run: |
          pip install numpy --break-system-packages
          python3 build/analyze_maxtext_logs.py
          cat summary.json

      - name: Upload logs and summary
        # Keep the logs available for debugging when an earlier step failed.
        if: ${{ !cancelled() }}
        uses: actions/upload-artifact@v4
        with:
          name: training-results
          path: |
            logs_*.log
            summary.json

      - name: Authenticate to GitHub Container Registry
        # The token is passed through the environment rather than being
        # interpolated into the script text.
        env:
          GHCR_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          echo "$GHCR_TOKEN" \
            | docker login ghcr.io -u "$GITHUB_ACTOR" --password-stdin

      - name: Push Docker image to GHCR
        run: |
          IMAGE="ghcr.io/rocm/maxtext-jax-rocm${ROCM_VERSION//.}"
          docker push "${IMAGE}:nightly"
          docker push "${IMAGE}:${GITHUB_SHA}"
          docker push "${IMAGE}:run${GITHUB_RUN_ID}"

      - name: Cleanup container
        if: always()
        run: |
          docker stop maxtext_container || true
          docker rm maxtext_container || true

  # Downloads the training-results artifact and passes run metadata to
  # ci/upload_to_db.py. Skipped automatically when the job above fails.
  upload-summary-to-db:
    name: Upload Summary to MySQL
    needs: build-and-test-jax-perf
    runs-on: mysqldb
    permissions:
      contents: read
    steps:
      - name: Checkout source
        uses: actions/checkout@v4

      - name: Download training summary artifact
        uses: actions/download-artifact@v4
        with:
          name: training-results

      - name: Upload summary.json to MySQL database
        # NOTE(review): the ROCM_JAX_DB_* variables are presumably read by
        # ci/upload_to_db.py from the environment - confirm in that script.
        env:
          ROCM_JAX_DB_HOSTNAME: ${{ secrets.ROCM_JAX_DB_HOSTNAME }}
          ROCM_JAX_DB_USERNAME: ${{ secrets.ROCM_JAX_DB_USERNAME }}
          ROCM_JAX_DB_PASSWORD: ${{ secrets.ROCM_JAX_DB_PASSWORD }}
          ROCM_JAX_DB_NAME: ${{ secrets.ROCM_JAX_DB_NAME }}
          PYTHON_VERSION: ${{ needs.build-and-test-jax-perf.outputs.python_version }}
          ROCM_VERSION: ${{ needs.build-and-test-jax-perf.outputs.rocm_version }}
        # --jax-version is hard-coded; keep it in sync with the
        # rocm-jaxlib-v0.8.2 ref checked out by the build job.
        run: |
          python3 -m venv venv
          source venv/bin/activate
          pip install --upgrade pip
          pip install mysql-connector-python
          python3 ci/upload_to_db.py \
            --github-run-id "${GITHUB_RUN_ID}" \
            --python-version "$PYTHON_VERSION" \
            --rocm-version "$ROCM_VERSION" \
            --gfx-version gfx942 \
            --jax-version 0.8.2