# Llama Performance Benchmarks #230
# Note: this file may contain hidden or bidirectional Unicode text that can be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
# (See GitHub's documentation on bidirectional Unicode characters.)
name: Llama Performance Benchmarks
# Runs Llama performance benchmarks against two JAX versions: 0.6.0 and 0.8.2.
# - JAX 0.6.0: official rocm-jax 0.6.0 release wheels. There is no ROCm 7.2.0
#   image that ships JAX 0.6.0, so the same ROCm 7.2.0 nightly Docker image is
#   reused and JAX is downgraded inside it for the TransformerEngine build.
# - JAX 0.8.2: wheels from the most recent scheduled rocm-jax nightly run and
#   the ROCm 7.2.0 nightly Docker image.
# Environment: Ubuntu 24 & ROCm 7.2.0.
env:
  # TransformerEngine commit that is built when no `te-ref` input is given
  # (always the case for scheduled runs, which have no inputs).
  TE_REF_DEFAULT: "0e7a1a9e1d997ea42f48028f652453e6f9390b42"
on:
  schedule:
    # Daily at 22:00 UTC; intended to pick up the previous nightly builds.
    - cron: "0 22 * * *"
  workflow_dispatch:
    inputs:
      runner-label:
        description: "GPU runner for the benchmark (also selects the TE build arch)"
        required: false
        type: choice
        options:
          - linux-jax-mi325-8
          - linux-jax-mi355-8
        default: "linux-jax-mi355-8"
      te-ref:
        description: "TransformerEngine git ref to build (defaults to TE_REF_DEFAULT)"
        required: false
        type: string
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true
jobs:
  build-te-wheels:
    # Builds one TransformerEngine wheel per JAX version. The build host is
    # always an MI355 runner; the target GPU arch is chosen in the "Define
    # ROCm arch & runner-label" step.
    runs-on: linux-jax-mi355-1
    outputs:
      # Both matrix legs write identical values (same TE ref, same runner
      # input), so it does not matter which leg's output is kept.
      runner_label: ${{ steps.arch.outputs.label }}
      te_commit_sha: ${{ steps.tsha.outputs.hash }}
    strategy:
      fail-fast: false
      matrix:
        jax-version: ["0.6.0", "0.8.2"]
        include:
          - jax-version: "0.6.0"
            jaxlib-version: "0.6.0"
            # There is no Docker image with ROCm 7.2.0 and JAX 0.6.0, so use
            # the ROCm 7.2.0 / JAX 0.8.2 image and downgrade JAX before
            # building TE. For the application workload Bazel installs its
            # own JAX.
            docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
          - jax-version: "0.8.2"
            jaxlib-version: "0.8.2"
            docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
    env:
      NVTE_FRAMEWORK: jax
      NVTE_USE_ROCM: "1"
      HIP_PLATFORM: amd
      NVTE_FUSED_ATTN_AOTRITON: "0"
      CU_NUM: "64"
    steps:
      - name: Checkout TE source
        uses: actions/checkout@v4
        with:
          repository: rocm/transformerengine
          ref: ${{ inputs.te-ref || env.TE_REF_DEFAULT }}
          submodules: recursive
          fetch-depth: 0
          path: transformerengine
      - name: Print workspace
        run: |
          ls -la "${{ github.workspace }}"
      - name: Get TE commit SHA
        id: tsha
        run: echo "hash=$(git -C transformerengine rev-parse HEAD)" >> "$GITHUB_OUTPUT"
      - name: Define ROCm arch & runner-label
        id: arch
        # Maps the benchmark runner to the GPU arch TE is compiled for and to
        # the short label passed to the DB upload job. Scheduled runs have no
        # inputs and therefore fall through to MI355.
        run: |
          if [ "${{ inputs.runner-label }}" == "linux-jax-mi325-8" ]; then
            echo "ROCM_ARCH=gfx942" >> "$GITHUB_ENV"
            echo "label=MI325" >> "$GITHUB_OUTPUT"
          else
            echo "ROCM_ARCH=gfx950" >> "$GITHUB_ENV"
            echo "label=MI355" >> "$GITHUB_OUTPUT"
          fi
      - name: Authenticate to GitHub Container Registry
        # The token is passed through the environment instead of being
        # expanded into the script text.
        env:
          GHCR_TOKEN: ${{ secrets.SECRET_TOKEN }}
        run: |
          printf '%s' "$GHCR_TOKEN" \
            | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
      - name: Run Container for ${{ matrix.jax-version }}
        run: |
          # Remove a leftover container with the same name (e.g. from an
          # aborted earlier run) so `docker run --name` does not conflict.
          docker rm -f te-build-jax${{ matrix.jax-version }} 2>/dev/null || true
          docker run -d \
            --name te-build-jax${{ matrix.jax-version }} \
            --network=host \
            --device=/dev/kfd \
            --device=/dev/dri \
            --ipc=host \
            --shm-size 32G \
            --group-add video \
            --cap-add=SYS_PTRACE \
            --security-opt seccomp=unconfined \
            -e NVTE_FRAMEWORK=${{ env.NVTE_FRAMEWORK }} \
            -e NVTE_USE_ROCM=${{ env.NVTE_USE_ROCM }} \
            -e HIP_PLATFORM=${{ env.HIP_PLATFORM }} \
            -e NVTE_ROCM_ARCH=${{ env.ROCM_ARCH }} \
            -e NVTE_FUSED_ATTN_AOTRITON=${{ env.NVTE_FUSED_ATTN_AOTRITON }} \
            -e CU_NUM=${{ env.CU_NUM }} \
            -e PIP_BREAK_SYSTEM_PACKAGES=1 \
            -v "${{ github.workspace }}:/workspace" \
            ${{ matrix.docker-image }} \
            tail -f /dev/null
      - name: Build TE-wheel
        run: |
          docker exec te-build-jax${{ matrix.jax-version }} bash -lc '
            # Stop at the first failure so a wheel is never built against
            # the wrong JAX after a failed install.
            set -e
            pip install setuptools ninja pybind11
            pip install cmake==4.1.0
            pip install packaging psutil
            if [ "${{ matrix.jax-version }}" == "0.6.0" ]; then
              # The image ships JAX 0.8.2; swap in the 0.6.0 release wheels.
              export DEBIAN_FRONTEND=noninteractive
              apt-get update
              # -y is required: there is no TTY to answer the apt prompt.
              apt-get install -y libdw1 libglib2.0-0
              ROCM_JAX_BASE=https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0
              PJRT_URL="$ROCM_JAX_BASE/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl"
              PLUGIN_URL="$ROCM_JAX_BASE/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl"
              pip install --force-reinstall jax==0.6.0 \
                jaxlib==0.6.0 \
                "$PJRT_URL" \
                "$PLUGIN_URL"
            fi
            cd /workspace/transformerengine
            # Rewrite jax.extend references to jax in the TE build helper.
            sed -i '\''s/jax\.extend/jax/g'\'' build_tools/jax.py
            python3 setup.py bdist_wheel'
      - name: Upload TE wheels artifacts
        uses: actions/upload-artifact@v4
        with:
          name: te-wheel-jax${{ matrix.jax-version }}
          path: transformerengine/dist/*whl
          # Fail here instead of in the benchmark job's artifact download.
          if-no-files-found: error
      - name: Cleanup containers
        if: always()
        run: |
          docker stop te-build-jax${{ matrix.jax-version }} || true
          docker rm te-build-jax${{ matrix.jax-version }} || true
| run-llama-perf-model: | |
| needs: build-te-wheels | |
| runs-on: ${{ inputs.runner-label || 'linux-jax-mi355-8' }} | |
| strategy: | |
| fail-fast: false | |
| max-parallel: 1 | |
| matrix: | |
| jax-version: ["0.6.0", "0.8.2"] | |
| model-name: ["train_dense"] | |
| include: | |
| - jax-version: "0.6.0" | |
| model-name: "train_dense" | |
| jaxlib-version: "0.6.0" | |
| docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly" | |
| - jax-version: "0.8.2" | |
| model-name: "train_dense" | |
| jaxlib-version: "0.8.2" | |
| docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly" | |
| steps: | |
| - name: Checkout source repo | |
| uses: actions/checkout@v4 | |
| with: | |
| repository: AMD-AGI/scif_repro | |
| token: ${{ secrets.SECRET_TOKEN }} | |
| - name: Download JAX wheels | |
| env: | |
| GH_TOKEN: ${{ secrets.SECRET_TOKEN }} | |
| run: | | |
| if [ "${{ matrix.jax-version }}" == "0.6.0" ]; then | |
| mkdir -p wheelhouse | |
| REL_URL="https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0" | |
| curl -L --output-dir wheelhouse -O \ | |
| "$REL_URL/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl" | |
| curl -L --output-dir wheelhouse -O \ | |
| "$REL_URL/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl" | |
| else | |
| gh run download \ | |
| $(gh run list \ | |
| -R ROCm/rocm-jax \ | |
| -w nightly.yml \ | |
| -L 2000 \ | |
| --json databaseId,event,createdAt \ | |
| -q ' | |
| map(select( | |
| .event == "schedule")) | |
| | max_by(.createdAt) | |
| | .databaseId') \ | |
| -R ROCm/rocm-jax \ | |
| -n plugin_wheels_r7.2.0 \ | |
| -D wheelhouse | |
| rm -v wheelhouse/*311* | |
| fi | |
| - name: Download TE wheel artifact | |
| uses: actions/download-artifact@v4 | |
| with: | |
| name: te-wheel-jax${{ matrix.jax-version }} | |
| path: wheelhouse | |
| - name: Print wheels | |
| run: | | |
| ls -la ${{ github.workspace }}/wheelhouse | |
| - name: Disable non-root check | |
| run: | | |
| sed -i '/python.toolchain(/a\ ignore_root_user_error = True,' MODULE.bazel | |
| - name: Patch source files | |
| # these patches will be removed when the repo is updated/tested accordingly | |
| run: | | |
| find . -type f -exec sed -i 's/jax_rocm60_plugin/jax_rocm7_plugin/g' {} + | |
| find . -type f -exec sed -i 's/jax_rocm60_pjrt/jax_rocm7_pjrt/g' {} + | |
| echo 'jax==${{ matrix.jax-version }}' >> kylix/third_party/requirements.lock | |
| echo 'jaxlib==${{ matrix.jaxlib-version }}' >> kylix/third_party/requirements.lock | |
| PJRT_WHL=$(ls wheelhouse/jax_rocm7_pjrt-*.whl | head -n1 || true) | |
| PLUGIN_WHL=$(ls wheelhouse/jax_rocm7_plugin-*.whl | head -n1 || true) | |
| TE_WHL=$(ls wheelhouse/transformer_engine-*.whl | head -n1 || true) | |
| PJRT_BASE=$(basename "$PJRT_WHL") | |
| PLUGIN_BASE=$(basename "$PLUGIN_WHL") | |
| TE_BASE=$(basename "$TE_WHL") | |
| { | |
| echo "jax_rocm7_pjrt @ file:///workspace/wheelhouse/$PJRT_BASE" | |
| echo "jax_rocm7_plugin @ file:///workspace/wheelhouse/$PLUGIN_BASE" | |
| echo "transformer_engine @ file:///workspace/wheelhouse/$TE_BASE" | |
| } >> kylix/third_party/requirements.lock | |
| sed -i '11,36s/^/#/' bazel_run_anynode.sh | |
| sed -i '/^TRAIN_SETTING/s/^/# /' bazel_run_anynode.sh | |
| sed -i '/pip/s/^/#/' kylix/third_party/requirements.lock | |
| sed -i '/orbax-checkpoint/ s/^/# /' kylix/third_party/requirements.lock | |
| echo 'orbax-checkpoint==0.11.25' >> kylix/third_party/requirements.lock | |
| echo 'aiofiles==25.1.0' >> kylix/third_party/requirements.lock | |
| sed -i '/^wheel/s/^/# /; /^# wheel/a wheel==0.45.1' kylix/third_party/requirements.lock | |
| sed -i 's/\bfun\s*=\s*self.create_train_step()/self.create_train_step()/g' \ | |
| kylix/max/execution/executors.py | |
| sed -i '/train_dense",/a\ "//kylix/max/projects/llama/config:train_dense_8b",' \ | |
| kylix/max/projects/llama/BUILD | |
| sed -i 's/operations\[4\]/operations[5]/; /num_operation_workers=16;/a \ | |
| model.mha_use_transformer_engine=True;optimizer.total_steps=200;' bazel_run_anynode.sh | |
| - name: Authenticate to GitHub Container Registry | |
| run: | | |
| echo "${{ secrets.SECRET_TOKEN }}" \ | |
| | docker login ghcr.io -u ${{ github.actor }} --password-stdin | |
| - name: Run Container for ${{ matrix.jax-version }} and ${{ matrix.model-name }} | |
| run: | | |
| docker run -d \ | |
| --name model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} \ | |
| --network=host \ | |
| --device=/dev/kfd \ | |
| --device=/dev/dri \ | |
| --ipc=host \ | |
| --shm-size 32G \ | |
| --group-add video \ | |
| --cap-add=SYS_PTRACE \ | |
| --security-opt seccomp=unconfined \ | |
| -v "${{ github.workspace }}:/workspace" \ | |
| ${{ matrix.docker-image }} \ | |
| tail -f /dev/null | |
| - name: Run model | |
| run: | | |
| docker exec model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} bash -lc ' | |
| BIN_URL="https://github.com/bazelbuild/bazelisk/releases/download/" | |
| wget -O /usr/local/bin/bazel "$BIN_URL/v1.26.0/bazelisk-linux-amd64" | |
| chmod +x /usr/local/bin/bazel | |
| echo "8.5.0" > /workspace/.bazelversion | |
| cd /workspace | |
| bazel version | |
| apt update && apt install -y libglib2.0-0 | |
| export EXECUTOR="base_executor_config" | |
| export TRAIN_SETTING="${{ matrix.model-name }}" | |
| export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=True \ | |
| --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True \ | |
| --xla_gpu_enable_command_buffer='' --xla_gpu_autotune_level=4" | |
| export PROJECT=byllama | |
| bash -i bazel_run_anynode.sh 2>&1 | tee logs.log | |
| tail -n 25 logs.log > training_summary.txt' | |
| - name: Upload training summary | |
| uses: actions/upload-artifact@v4 | |
| with: | |
| name: training-summary-${{ matrix.jax-version }}-${{ matrix.model-name }} | |
| path: training_summary.txt | |
| - name: Cleanup container | |
| if: always() | |
| run: | | |
| docker stop model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} || true | |
| docker rm model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} || true | |
| upload-result-db: | |
| needs: ['run-llama-perf-model', 'build-te-wheels'] | |
| runs-on: mysqldb | |
| strategy: | |
| fail-fast: false | |
| max-parallel: 1 | |
| matrix: | |
| include: | |
| - jax-version: "0.6.0" | |
| model-name: "train_dense" | |
| rocm-version: "7.2.0" | |
| python-version: "3.12" | |
| - jax-version: "0.8.2" | |
| model-name: "train_dense" | |
| rocm-version: "7.2.0" | |
| python-version: "3.12" | |
| steps: | |
| - name: Checkout plugin repo | |
| uses: actions/checkout@v4 | |
| - name: Download training summary | |
| uses: actions/download-artifact@v4 | |
| with: | |
| name: training-summary-${{ matrix.jax-version }}-${{ matrix.model-name }} | |
| path: . | |
| - name: Upload result to DB | |
| env: | |
| ROCM_JAX_DB_HOSTNAME: ${{ secrets.ROCM_JAX_DB_HOSTNAME }} | |
| ROCM_JAX_DB_USERNAME: ${{ secrets.ROCM_JAX_DB_USERNAME }} | |
| ROCM_JAX_DB_PASSWORD: ${{ secrets.ROCM_JAX_DB_PASSWORD }} | |
| ROCM_JAX_DB_NAME: ${{ secrets.ROCM_JAX_DB_NAME }} | |
| ROCM_VERSION: ${{ matrix.rocm-version }} | |
| PYTHON_VERSION: ${{ matrix.python-version }} | |
| run: | | |
| python3 -m venv venv | |
| source venv/bin/activate | |
| pip install --upgrade pip | |
| pip install mysql-connector-python | |
| python3 ci/upload_to_llama_db.py \ | |
| --github-run-id "${{ github.run_id }}" \ | |
| --run-tag "ci-run" \ | |
| --model-name "${{ matrix.model-name }}" \ | |
| --te-commit "${{ needs.build-te-wheels.outputs.te_commit_sha }}" \ | |
| --jax-version "${{ matrix.jax-version }}" \ | |
| --rocm-version "${ROCM_VERSION//.}" \ | |
| --python-version "${PYTHON_VERSION//.}" \ | |
| --runner-label "${{ needs.build-te-wheels.outputs.runner_label }}" \ | |
| --github-ref "${{ github.ref_name }}" \ | |
| --trig-event "${{ github.event_name }}" \ | |
| --actor-name "${{ github.actor }}" |