# ROCm DLM Performance Evaluations #323
# Note copied from GitHub's file view: this file may contain hidden or
# bidirectional Unicode text that can be interpreted or compiled differently
# than what appears on screen. To review it, open the file in an editor that
# reveals hidden Unicode characters (see GitHub's guidance on bidirectional
# Unicode characters).
---
name: ROCm DLM Performance Evaluations

# Triggers: started manually from the Actions tab, or automatically once a day.
on:  # yamllint disable-line rule:truthy
  workflow_dispatch: {}
  schedule:
    # Every day at 03:00 UTC.
    - cron: '0 3 * * *'
jobs:
  # Builds the ROCm JAX plugin wheels and the MaxText Docker image, runs the
  # MaxText training configs listed below inside that image, uploads the logs
  # and summary.json as the "training-results" artifact, and pushes the image
  # to GHCR.
  build-and-test-jax-perf:
    runs-on: "linux-x86-64-8gpu-amd"
    # GITHUB_TOKEN must be allowed to write packages for the GHCR push below;
    # the remaining steps only read repository contents.
    permissions:
      contents: read
      packages: write
    strategy:
      matrix:
        # Keep one entry per axis: the job outputs are not per-leg and the
        # "training-results" artifact name is fixed, so extra legs would
        # overwrite each other's outputs and collide on the artifact name.
        python-version: ["3.12"]
        rocm-version: ["7.1.1"]
        # Selects the ROCm/jax branch (rocm-jaxlib-v<version>) and is the
        # JAX version reported to the database by upload-summary-to-db.
        jax-version: ["0.8.2"]
    outputs:
      python_version: ${{ steps.meta.outputs.python }}
      rocm_version: ${{ steps.meta.outputs.rocm }}
      jax_version: ${{ steps.meta.outputs.jax }}
    env:
      # Not referenced by the steps below.
      # NOTE(review): presumably read by build/ci_build — confirm before removing.
      WORKSPACE_DIR: jax_rocm_perf_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
      PYTHON_VERSION: ${{ matrix.python-version }}
      ROCM_VERSION: ${{ matrix.rocm-version }}
      JAX_VERSION: ${{ matrix.jax-version }}
    steps:
      - name: Get job metadata
        id: meta
        run: |
          echo "python=${PYTHON_VERSION}" >> "$GITHUB_OUTPUT"
          echo "rocm=${ROCM_VERSION}" >> "$GITHUB_OUTPUT"
          echo "jax=${JAX_VERSION}" >> "$GITHUB_OUTPUT"
          # Single definition of the MaxText image repository used by the
          # build, launch and push steps (dots stripped: 7.1.1 -> rocm711).
          echo "MAXTEXT_IMAGE=ghcr.io/rocm/maxtext-jax-rocm${ROCM_VERSION//.}" >> "$GITHUB_ENV"

      - name: Clean up old workdirs
        run: |
          ls -l
          # Files written by earlier containers may not be owned by the runner
          # user; chown them back from inside a throwaway container so that the
          # rm below can delete them.
          docker run --rm -v "$(pwd):/rocm-jax" ubuntu \
            bash -c "shopt -s dotglob; chown -R $UID /rocm-jax/* || true"
          rm -rf * || true
          ls -l

      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi || true

      - name: Checkout source
        uses: actions/checkout@v4

      - name: Checkout JAX repo for jaxlib build
        uses: actions/checkout@v4
        with:
          repository: ROCm/jax
          ref: rocm-jaxlib-v${{ matrix.jax-version }}
          path: jax

      - name: Build plugin wheels
        run: |
          python3 build/ci_build \
            --compiler clang \
            --python-versions "$PYTHON_VERSION" \
            --rocm-version "$ROCM_VERSION" \
            --jax-source-dir="./jax" \
            dist_wheels

      - name: Copy wheels for Docker build context
        run: |
          mkdir -p wheelhouse
          cp ./jax_rocm_plugin/wheelhouse/*.whl ./wheelhouse/

      - name: Build JAX docker image
        run: |
          python3 build/ci_build \
            --rocm-version "$ROCM_VERSION" \
            build_dockers \
            --filter ubu24

      - name: Build Docker image for MaxText
        run: |
          # NOTE(review): BASE_IMAGE is presumably the local image produced by
          # the previous build_dockers step — confirm.
          docker build \
            --build-arg BASE_IMAGE="jax-ubu24.rocm${ROCM_VERSION//.}" \
            --build-arg MAXTEXT_BRANCH=rv_jax \
            -f ci/Dockerfile.maxtext \
            -t "$MAXTEXT_IMAGE:nightly" \
            -t "$MAXTEXT_IMAGE:${{ github.sha }}" \
            -t "$MAXTEXT_IMAGE:run${{ github.run_id }}" \
            .

      - name: Launch container
        run: |
          # A container left behind by an interrupted earlier run would make
          # "docker run --name" fail because the name is already in use.
          docker rm -f maxtext_container 2>/dev/null || true
          docker run -d --name maxtext_container \
            --network=host \
            --device=/dev/kfd \
            --device=/dev/dri \
            --ipc=host \
            --shm-size=64G \
            --group-add=video \
            --cap-add=SYS_PTRACE \
            --security-opt seccomp=unconfined \
            -w /maxtext \
            "$MAXTEXT_IMAGE:run${{ github.run_id }}" \
            tail -f /dev/null

      - name: Run MaxText training and save logs
        id: train
        run: |
          # The default run shell is "bash -e" without pipefail, so a failing
          # "docker exec" would otherwise be hidden behind tee's exit status.
          set -o pipefail
          failed=()
          for config in \
            MaxText/configs/models/gpu/llama2_7b_rocm.yml \
            MaxText/configs/models/gpu/gemma_2b_rocm.yml \
            MaxText/configs/models/gpu/gpt3_6b_rocm.yml \
            MaxText/configs/models/gpu/mixtral_8x1b_rocm.yml; do
            model_name=$(basename "$config" _rocm.yml)
            echo "Running $model_name"
            extra_env=()
            if [[ "$model_name" == "mixtral_8x1b" ]]; then
              # Only Mixtral runs with XLA_PYTHON_CLIENT_MEM_FRACTION=0.95;
              # the other models use the container's default.
              extra_env=(-e XLA_PYTHON_CLIENT_MEM_FRACTION=0.95)
            fi
            # Keep going after a failure so every model still produces a log;
            # the step is failed after the loop instead.
            if ! docker exec "${extra_env[@]}" maxtext_container \
                python3 -m MaxText.train "$config" | tee "logs_${model_name}.log"; then
              failed+=("$model_name")
            fi
          done
          if (( ${#failed[@]} > 0 )); then
            echo "::error::MaxText training failed for: ${failed[*]}"
            exit 1
          fi

      - name: Analyze logs to compute median step time
        # Also runs when training failed, so partial results are summarized;
        # skipped when an earlier step failed and training never ran.
        if: ${{ !cancelled() && steps.train.outcome != 'skipped' }}
        run: |
          pip install numpy --break-system-packages
          python3 build/analyze_maxtext_logs.py
          cat summary.json

      - name: Upload logs and summary
        # Keep the logs for debugging even if training or analysis failed.
        if: ${{ !cancelled() && steps.train.outcome != 'skipped' }}
        uses: actions/upload-artifact@v4
        with:
          name: training-results
          path: |
            logs_*.log
            summary.json

      - name: Authenticate to GitHub Container Registry
        # Pass values through env instead of expanding ${{ }} inside the script.
        env:
          GHCR_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GHCR_USER: ${{ github.actor }}
        run: |
          echo "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USER" --password-stdin

      - name: Push Docker image to GHCR
        run: |
          for tag in nightly "${{ github.sha }}" "run${{ github.run_id }}"; do
            docker push "$MAXTEXT_IMAGE:$tag"
          done

      - name: Cleanup container
        if: always()
        run: |
          docker stop maxtext_container || true
          docker rm maxtext_container || true

  # Sends the summary produced by build-and-test-jax-perf to the MySQL
  # database, from the self-hosted runner labelled "mysqldb".
  upload-summary-to-db:
    name: Upload Summary to MySQL
    needs: build-and-test-jax-perf
    runs-on: mysqldb
    permissions:
      contents: read
    steps:
      - name: Checkout source
        uses: actions/checkout@v4

      - name: Download training summary artifact
        uses: actions/download-artifact@v4
        with:
          name: training-results

      - name: Upload summary.json to MySQL database
        env:
          ROCM_JAX_DB_HOSTNAME: ${{ secrets.ROCM_JAX_DB_HOSTNAME }}
          ROCM_JAX_DB_USERNAME: ${{ secrets.ROCM_JAX_DB_USERNAME }}
          ROCM_JAX_DB_PASSWORD: ${{ secrets.ROCM_JAX_DB_PASSWORD }}
          ROCM_JAX_DB_NAME: ${{ secrets.ROCM_JAX_DB_NAME }}
          PYTHON_VERSION: ${{ needs.build-and-test-jax-perf.outputs.python_version }}
          ROCM_VERSION: ${{ needs.build-and-test-jax-perf.outputs.rocm_version }}
          JAX_VERSION: ${{ needs.build-and-test-jax-perf.outputs.jax_version }}
        run: |
          python3 -m venv venv
          source venv/bin/activate
          pip install --upgrade pip
          pip install mysql-connector-python
          # NOTE(review): gfx942 is hard-coded and must match the GPUs of the
          # build-and-test-jax-perf runner — confirm when that runner changes.
          python3 ci/upload_to_db.py \
            --github-run-id "${{ github.run_id }}" \
            --python-version "$PYTHON_VERSION" \
            --rocm-version "$ROCM_VERSION" \
            --gfx-version gfx942 \
            --jax-version "$JAX_VERSION"