Skip to content

Commit c5fae82

Browse files
committed
Smoke test docker
1 parent b55b975 commit c5fae82

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

.github/workflows/ci.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ jobs:
9797
strategy:
9898
matrix:
9999
rocm-version: ["7.1.1"]
100+
python-version: ["3.11", "3.12", "3.13", "3.14"]
100101
steps:
101102
- uses: actions/checkout@v4
102103

@@ -114,3 +115,24 @@ jobs:
114115
115116
- name: Build JAX docker image
116117
run: bash jax_rocm_plugin/build/rocm/build_docker.sh ${{ matrix.rocm-version }}
118+
119+
- name: Smoke test docker image
120+
env:
121+
ROCM_VERSION: ${{ matrix.rocm-version }}
122+
run: |
123+
ROCM_VERSION_TAG="${ROCM_VERSION//./}"
124+
JAX_TAG="jax-ubu24.rocm${ROCM_VERSION_TAG}"
125+
eval "$(bash jax_rocm_plugin/build/rocm/get_commits.sh)"
126+
docker run --rm \
127+
--device=/dev/kfd \
128+
--device=/dev/dri \
129+
--group-add video \
130+
--cap-add=SYS_PTRACE \
131+
--security-opt seccomp=unconfined \
132+
--shm-size 16G \
133+
"${JAX_TAG}" \
134+
bash -c "
135+
curl -sL https://github.com/ROCm/jax/archive/${JAX_COMMIT}.tar.gz | tar xz -C /tmp && \
136+
cd /tmp/jax-${JAX_COMMIT} && \
137+
python${{ matrix.python-version }} -m pytest -x -v tests/nn_test.py tests/random_test.py
138+
"

jax_rocm_plugin/build/rocm/build_docker.sh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@ REGISTRY="ghcr.io/rocm"
1010
BASE_TAG="jax-base-ubu24.rocm${ROCM_VERSION_TAG}"
1111
JAX_TAG="jax-ubu24.rocm${ROCM_VERSION_TAG}"
1212

13-
# Extract commit hashes from workspace files
14-
XLA_COMMIT=$(grep -oP 'XLA_COMMIT = "\K[0-9a-f]+' jax_rocm_plugin/third_party/xla/workspace.bzl)
15-
JAX_COMMIT=$(grep -oP 'COMMIT = "\K[0-9a-f]+' jax_rocm_plugin/third_party/jax/workspace.bzl)
16-
ROCM_JAX_COMMIT=$(git rev-parse HEAD)
13+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
14+
eval "$(bash "${SCRIPT_DIR}/get_commits.sh")"
1715

1816
# Build base image if not available in the registry
1917
if ! docker manifest inspect "${REGISTRY}/${BASE_TAG}:latest" >/dev/null 2>&1; then
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
# Extracts commit hashes from workspace files and prints them as
3+
# shell-evaluable assignments. Usage:
4+
# eval "$(bash jax_rocm_plugin/build/rocm/get_commits.sh)"
5+
set -euo pipefail
6+
7+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
8+
THIRD_PARTY="${SCRIPT_DIR}/../../third_party"
9+
10+
XLA_COMMIT=$(grep -oP 'XLA_COMMIT = "\K[0-9a-f]+' "${THIRD_PARTY}/xla/workspace.bzl")
11+
JAX_COMMIT=$(grep -oP 'COMMIT = "\K[0-9a-f]+' "${THIRD_PARTY}/jax/workspace.bzl")
12+
ROCM_JAX_COMMIT=$(git rev-parse HEAD)
13+
14+
echo "XLA_COMMIT=${XLA_COMMIT}"
15+
echo "JAX_COMMIT=${JAX_COMMIT}"
16+
echo "ROCM_JAX_COMMIT=${ROCM_JAX_COMMIT}"

0 commit comments

Comments
 (0)