Skip to content

Llama Performance Benchmarks #230

Llama Performance Benchmarks

Llama Performance Benchmarks #230

Workflow file for this run

name: Llama Performance Benchmarks

# Benchmarks Llama training performance against two JAX versions
# (0.6.0 and 0.8.2) on Ubuntu 24 with ROCm 7.2.0.
#   - JAX 0.6.0: officially released rocm-jax wheels. No ROCm 7.2.0 image
#     ships with JAX 0.6.0, so the nightly image is reused and JAX is
#     replaced inside it before Transformer Engine (TE) is built.
#   - JAX 0.8.2: wheels and Docker image produced by the latest nightly.

env:
  # TE commit that gets built when the `te-ref` input is not supplied.
  TE_REF_DEFAULT: '0e7a1a9e1d997ea42f48028f652453e6f9390b42'

on:
  schedule:
    # Daily at 22:00 UTC; the plan is to consume the previous nightly builds.
    - cron: '0 22 * * *'
  workflow_dispatch:
    inputs:
      runner-label:
        required: false
        type: choice
        options:
          - linux-jax-mi325-8
          - linux-jax-mi355-8
        default: linux-jax-mi355-8
      te-ref:
        required: false
        type: string

concurrency:
  # One active run per workflow and ref; a newer run cancels the older one.
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
build-te-wheels:
  # Builds one Transformer Engine (TE) wheel per JAX version inside the
  # nightly ROCm container and publishes it as a workflow artifact.
  runs-on: linux-jax-mi355-1
  outputs:
    # Both matrix legs check out the same TE ref and evaluate the same
    # input, so they write identical values to these outputs.
    runner_label: ${{ steps.arch.outputs.label }}
    te_commit_sha: ${{ steps.tsha.outputs.hash }}
  strategy:
    fail-fast: false
    matrix:
      jax-version: ["0.6.0", "0.8.2"]
      include:
        # No Docker image ships ROCm 7.2.0 together with JAX 0.6.0, so the
        # ROCm 7.2.0 / JAX 0.8.2 image is used and JAX is replaced before
        # TE is built. For the application workload Bazel installs its
        # own JAX.
        - jax-version: "0.6.0"
          jaxlib-version: "0.6.0"
          docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
        - jax-version: "0.8.2"
          jaxlib-version: "0.8.2"
          docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
  env:
    # Build settings forwarded into the build container via `docker run -e`.
    NVTE_FRAMEWORK: jax
    NVTE_USE_ROCM: "1"
    HIP_PLATFORM: amd
    NVTE_FUSED_ATTN_AOTRITON: "0"
    CU_NUM: "64"
  steps:
    - name: Checkout TE source
      uses: actions/checkout@v4
      with:
        repository: rocm/transformerengine
        ref: ${{ inputs.te-ref || env.TE_REF_DEFAULT }}
        submodules: recursive
        fetch-depth: 0
        path: transformerengine

    - name: Print workspace
      run: |
        ls -la ${{ github.workspace }}

    - name: Get TE commit SHA
      id: tsha
      run: echo "hash=$(git -C transformerengine rev-parse HEAD)" >> "$GITHUB_OUTPUT"

    # Maps the requested runner label to a GPU architecture (exported as
    # ROCM_ARCH for later steps) and to a short label (step output).
    # Scheduled runs carry no inputs and therefore take the MI355 branch.
    - name: Define ROCm arch & runner-label
      id: arch
      run: |
        if [ "${{ inputs.runner-label }}" == "linux-jax-mi325-8" ]; then
          echo "ROCM_ARCH=gfx942" >> "$GITHUB_ENV"
          echo "label=MI325" >> "$GITHUB_OUTPUT"
        else
          echo "ROCM_ARCH=gfx950" >> "$GITHUB_ENV"
          echo "label=MI355" >> "$GITHUB_OUTPUT"
        fi

    - name: Authenticate to GitHub Container Registry
      env:
        # Passed through the environment so the secret is never expanded
        # into the generated script text.
        GHCR_TOKEN: ${{ secrets.SECRET_TOKEN }}
      run: |
        echo "$GHCR_TOKEN" \
          | docker login ghcr.io -u "${{ github.actor }}" --password-stdin

    # Starts a long-lived container (kept alive by `tail -f /dev/null`)
    # with the workspace mounted at /workspace; the build step execs into it.
    - name: Run Container for ${{ matrix.jax-version }}
      run: |
        docker run -d \
          --name te-build-jax${{ matrix.jax-version }} \
          --network=host \
          --device=/dev/kfd \
          --device=/dev/dri \
          --ipc=host \
          --shm-size 32G \
          --group-add video \
          --cap-add=SYS_PTRACE \
          --security-opt seccomp=unconfined \
          -e NVTE_FRAMEWORK=${{ env.NVTE_FRAMEWORK }} \
          -e NVTE_USE_ROCM=${{ env.NVTE_USE_ROCM }} \
          -e HIP_PLATFORM=${{ env.HIP_PLATFORM }} \
          -e NVTE_ROCM_ARCH=${{ env.ROCM_ARCH }} \
          -e NVTE_FUSED_ATTN_AOTRITON=${{ env.NVTE_FUSED_ATTN_AOTRITON }} \
          -e CU_NUM=${{ env.CU_NUM }} \
          -e PIP_BREAK_SYSTEM_PACKAGES=1 \
          -v "${{ github.workspace }}:/workspace" \
          ${{ matrix.docker-image }} \
          tail -f /dev/null

    # `set -eo pipefail` makes any failed command abort the build; without
    # it only the exit status of the final command decided the step result,
    # so a failed JAX 0.6.0 install could still yield a wheel built against
    # the JAX already present in the image.
    # `apt install` gets `-y` because `docker exec` has no terminal to
    # answer the confirmation prompt.
    # The whole script is one single-quoted string: do not add apostrophes.
    - name: Build TE-wheel
      run: |
        docker exec te-build-jax${{ matrix.jax-version }} bash -lc '
          set -eo pipefail
          pip install setuptools ninja pybind11
          pip install cmake==4.1.0
          pip install packaging psutil
          if [ "${{ matrix.jax-version }}" == "0.6.0" ]; then
            apt update
            apt install -y libdw1 libglib2.0-0
            ROCM_JAX_BASE=https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0
            PJRT_URL="$ROCM_JAX_BASE/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl"
            PLUGIN_URL="$ROCM_JAX_BASE/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl"
            pip install --force-reinstall jax==0.6.0 \
              jaxlib==0.6.0 \
              "$PJRT_URL" \
              "$PLUGIN_URL"
          fi
          cd /workspace/transformerengine
          sed -i '\''s/jax\.extend/jax/g'\'' build_tools/jax.py
          python3 setup.py bdist_wheel'

    - name: Upload TE wheels artifacts
      uses: actions/upload-artifact@v4
      with:
        name: te-wheel-jax${{ matrix.jax-version }}
        path: transformerengine/dist/*whl
        # Fail here, not in the downstream download, if no wheel was built.
        if-no-files-found: error

    - name: Cleanup containers
      if: always()
      run: |
        docker stop te-build-jax${{ matrix.jax-version }} || true
        docker rm te-build-jax${{ matrix.jax-version }} || true
run-llama-perf-model:
  # Runs the Llama training benchmark once per JAX version, using the TE
  # wheel built by `build-te-wheels`, and publishes the tail of the
  # training log as an artifact.
  needs: build-te-wheels
  runs-on: ${{ inputs.runner-label || 'linux-jax-mi355-8' }}
  strategy:
    fail-fast: false
    # One leg at a time - presumably so benchmark runs do not compete for
    # the same GPUs (TODO confirm).
    max-parallel: 1
    matrix:
      jax-version: ["0.6.0", "0.8.2"]
      model-name: ["train_dense"]
      include:
        - jax-version: "0.6.0"
          model-name: "train_dense"
          jaxlib-version: "0.6.0"
          docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
        - jax-version: "0.8.2"
          model-name: "train_dense"
          jaxlib-version: "0.8.2"
          docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
  steps:
    - name: Checkout source repo
      uses: actions/checkout@v4
      with:
        repository: AMD-AGI/scif_repro
        token: ${{ secrets.SECRET_TOKEN }}

    # JAX 0.6.0: PJRT/plugin wheels come from the rocm-jax v0.6.0 release.
    # Otherwise: the `plugin_wheels_r7.2.0` artifact of the most recently
    # created scheduled `nightly.yml` run in ROCm/rocm-jax is downloaded
    # and the files matching *311* are removed.
    # `curl --fail` turns an HTTP error into a step failure instead of
    # saving the error page as a .whl; `rm -f` keeps the step alive when
    # the artifact contains no *311* file.
    - name: Download JAX wheels
      env:
        GH_TOKEN: ${{ secrets.SECRET_TOKEN }}
      run: |
        if [ "${{ matrix.jax-version }}" == "0.6.0" ]; then
          mkdir -p wheelhouse
          REL_URL="https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0"
          curl --fail -L --output-dir wheelhouse -O \
            "$REL_URL/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl"
          curl --fail -L --output-dir wheelhouse -O \
            "$REL_URL/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl"
        else
          gh run download \
            $(gh run list \
              -R ROCm/rocm-jax \
              -w nightly.yml \
              -L 2000 \
              --json databaseId,event,createdAt \
              -q '
                map(select(
                  .event == "schedule"))
                | max_by(.createdAt)
                | .databaseId') \
            -R ROCm/rocm-jax \
            -n plugin_wheels_r7.2.0 \
            -D wheelhouse
          rm -fv wheelhouse/*311*
        fi

    - name: Download TE wheel artifact
      uses: actions/download-artifact@v4
      with:
        name: te-wheel-jax${{ matrix.jax-version }}
        path: wheelhouse

    - name: Print wheels
      run: |
        ls -la ${{ github.workspace }}/wheelhouse

    - name: Disable non-root check
      run: |
        sed -i '/python.toolchain(/a\ ignore_root_user_error = True,' MODULE.bazel

    # These patches will be removed when the repo is updated/tested
    # accordingly. They rename the ROCm 6.0 plugin/PJRT packages to the
    # rocm7 names, pin JAX/jaxlib and the locally built wheels in
    # requirements.lock, and adjust the launch script and build targets.
    # The statements are order-dependent (later sed calls act on lines
    # appended by earlier ones); keep the sequence intact.
    - name: Patch source files
      run: |
        find . -type f -exec sed -i 's/jax_rocm60_plugin/jax_rocm7_plugin/g' {} +
        find . -type f -exec sed -i 's/jax_rocm60_pjrt/jax_rocm7_pjrt/g' {} +
        echo 'jax==${{ matrix.jax-version }}' >> kylix/third_party/requirements.lock
        echo 'jaxlib==${{ matrix.jaxlib-version }}' >> kylix/third_party/requirements.lock
        PJRT_WHL=$(ls wheelhouse/jax_rocm7_pjrt-*.whl | head -n1 || true)
        PLUGIN_WHL=$(ls wheelhouse/jax_rocm7_plugin-*.whl | head -n1 || true)
        TE_WHL=$(ls wheelhouse/transformer_engine-*.whl | head -n1 || true)
        # Stop early with a clear message instead of writing a requirement
        # that points at the bare wheelhouse directory.
        for whl in "$PJRT_WHL" "$PLUGIN_WHL" "$TE_WHL"; do
          if [ -z "$whl" ]; then
            echo "::error::Expected PJRT, plugin and TE wheels in wheelhouse/"
            exit 1
          fi
        done
        PJRT_BASE=$(basename "$PJRT_WHL")
        PLUGIN_BASE=$(basename "$PLUGIN_WHL")
        TE_BASE=$(basename "$TE_WHL")
        {
          echo "jax_rocm7_pjrt @ file:///workspace/wheelhouse/$PJRT_BASE"
          echo "jax_rocm7_plugin @ file:///workspace/wheelhouse/$PLUGIN_BASE"
          echo "transformer_engine @ file:///workspace/wheelhouse/$TE_BASE"
        } >> kylix/third_party/requirements.lock
        sed -i '11,36s/^/#/' bazel_run_anynode.sh
        sed -i '/^TRAIN_SETTING/s/^/# /' bazel_run_anynode.sh
        sed -i '/pip/s/^/#/' kylix/third_party/requirements.lock
        sed -i '/orbax-checkpoint/ s/^/# /' kylix/third_party/requirements.lock
        echo 'orbax-checkpoint==0.11.25' >> kylix/third_party/requirements.lock
        echo 'aiofiles==25.1.0' >> kylix/third_party/requirements.lock
        sed -i '/^wheel/s/^/# /; /^# wheel/a wheel==0.45.1' kylix/third_party/requirements.lock
        sed -i 's/\bfun\s*=\s*self.create_train_step()/self.create_train_step()/g' \
          kylix/max/execution/executors.py
        sed -i '/train_dense",/a\ "//kylix/max/projects/llama/config:train_dense_8b",' \
          kylix/max/projects/llama/BUILD
        sed -i 's/operations\[4\]/operations[5]/; /num_operation_workers=16;/a \
          model.mha_use_transformer_engine=True;optimizer.total_steps=200;' bazel_run_anynode.sh

    - name: Authenticate to GitHub Container Registry
      env:
        # Passed through the environment so the secret is never expanded
        # into the generated script text.
        GHCR_TOKEN: ${{ secrets.SECRET_TOKEN }}
      run: |
        echo "$GHCR_TOKEN" \
          | docker login ghcr.io -u "${{ github.actor }}" --password-stdin

    # Starts a long-lived container (kept alive by `tail -f /dev/null`)
    # with the workspace mounted at /workspace; the next step execs into it.
    - name: Run Container for ${{ matrix.jax-version }} and ${{ matrix.model-name }}
      run: |
        docker run -d \
          --name model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} \
          --network=host \
          --device=/dev/kfd \
          --device=/dev/dri \
          --ipc=host \
          --shm-size 32G \
          --group-add video \
          --cap-add=SYS_PTRACE \
          --security-opt seccomp=unconfined \
          -v "${{ github.workspace }}:/workspace" \
          ${{ matrix.docker-image }} \
          tail -f /dev/null

    # Installs Bazelisk as `bazel`, pins Bazel 8.5.0, runs the benchmark and
    # keeps the last 25 log lines as the training summary.
    # `set -eo pipefail` is required: the benchmark is piped into `tee`, so
    # without it the pipeline always reported tee's success and a failed
    # run still fed its log tail to the result database.
    # NOTE: the '' after --xla_gpu_enable_command_buffer= closes and reopens
    # the outer single-quoted script, so the inner shell sees an empty value.
    # Do not add any other apostrophes inside the script.
    - name: Run model
      run: |
        docker exec model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} bash -lc '
          set -eo pipefail
          BIN_URL="https://github.com/bazelbuild/bazelisk/releases/download"
          wget -O /usr/local/bin/bazel "$BIN_URL/v1.26.0/bazelisk-linux-amd64"
          chmod +x /usr/local/bin/bazel
          echo "8.5.0" > /workspace/.bazelversion
          cd /workspace
          bazel version
          apt update && apt install -y libglib2.0-0
          export EXECUTOR="base_executor_config"
          export TRAIN_SETTING="${{ matrix.model-name }}"
          export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=True \
            --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True \
            --xla_gpu_enable_command_buffer='' --xla_gpu_autotune_level=4"
          export PROJECT=byllama
          bash -i bazel_run_anynode.sh 2>&1 | tee logs.log
          tail -n 25 logs.log > training_summary.txt'

    - name: Upload training summary
      uses: actions/upload-artifact@v4
      with:
        name: training-summary-${{ matrix.jax-version }}-${{ matrix.model-name }}
        path: training_summary.txt
        # A missing summary is an error, not a warning.
        if-no-files-found: error

    - name: Cleanup container
      if: always()
      run: |
        docker stop model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} || true
        docker rm model-run-jax${{ matrix.jax-version }}-${{ matrix.model-name }} || true
upload-result-db:
  # Uploads each training summary produced by `run-llama-perf-model` to the
  # results database via ci/upload_to_llama_db.py.
  needs:
    - run-llama-perf-model
    - build-te-wheels
  runs-on: mysqldb
  strategy:
    fail-fast: false
    max-parallel: 1
    matrix:
      include:
        - jax-version: "0.6.0"
          model-name: "train_dense"
          rocm-version: "7.2.0"
          python-version: "3.12"
        - jax-version: "0.8.2"
          model-name: "train_dense"
          rocm-version: "7.2.0"
          python-version: "3.12"
  steps:
    - name: Checkout plugin repo
      uses: actions/checkout@v4

    - name: Download training summary
      uses: actions/download-artifact@v4
      with:
        name: training-summary-${{ matrix.jax-version }}-${{ matrix.model-name }}
        path: .

    - name: Upload result to DB
      env:
        ROCM_JAX_DB_HOSTNAME: ${{ secrets.ROCM_JAX_DB_HOSTNAME }}
        ROCM_JAX_DB_USERNAME: ${{ secrets.ROCM_JAX_DB_USERNAME }}
        ROCM_JAX_DB_PASSWORD: ${{ secrets.ROCM_JAX_DB_PASSWORD }}
        ROCM_JAX_DB_NAME: ${{ secrets.ROCM_JAX_DB_NAME }}
        ROCM_VERSION: ${{ matrix.rocm-version }}
        PYTHON_VERSION: ${{ matrix.python-version }}
        # Branch and user names are free-form text. They are handed over as
        # environment variables so they cannot be interpreted as shell
        # syntax in a script that also holds the database credentials.
        REF_NAME: ${{ github.ref_name }}
        ACTOR_NAME: ${{ github.actor }}
      run: |
        python3 -m venv venv
        source venv/bin/activate
        pip install --upgrade pip
        pip install mysql-connector-python
        # ${VAR//.} removes the dots: 7.2.0 -> 720, 3.12 -> 312.
        python3 ci/upload_to_llama_db.py \
          --github-run-id "${{ github.run_id }}" \
          --run-tag "ci-run" \
          --model-name "${{ matrix.model-name }}" \
          --te-commit "${{ needs.build-te-wheels.outputs.te_commit_sha }}" \
          --jax-version "${{ matrix.jax-version }}" \
          --rocm-version "${ROCM_VERSION//.}" \
          --python-version "${PYTHON_VERSION//.}" \
          --runner-label "${{ needs.build-te-wheels.outputs.runner_label }}" \
          --github-ref "$REF_NAME" \
          --trig-event "${{ github.event_name }}" \
          --actor-name "$ACTOR_NAME"