# Workflow file for this run: "Introduce tsan pipeline" (#181)

name: CI Unit Tests

# Triggers: pushes and pull requests targeting master or any ROCm jaxlib
# release branch, plus manual runs from the Actions tab.
on:
  push:
    branches: [master, 'rocm-jaxlib-v*']
  pull_request:
    branches: [master, 'rocm-jaxlib-v*']
  workflow_dispatch:
# Keep at most one run per workflow and ref: a newer run for the same ref
# cancels the one still in progress.
concurrency:
  group: "${{ github.workflow }}-${{ github.ref }}"
  cancel-in-progress: true

# Least privilege: the GITHUB_TOKEN only needs read access to the repository.
permissions:
  contents: read
jobs:
  build-and-test:
    name: build-and-test (${{ matrix.mode.name }})
    runs-on: linux-x86-64-4gpu-amd-gfx942
    strategy:
      fail-fast: false
      matrix:
        # One job per mode: a hermetic Python version plus extra Bazel flags
        # (empty, or a sanitizer config) appended to both test invocations.
        mode:
          - {name: "py3.11", python_version: "3.11", config: ""}
          - {name: "py3.12", python_version: "3.12", config: ""}
          - {name: "py3.13", python_version: "3.13", config: ""}
          - {name: "py3.14", python_version: "3.14", config: ""}
          - {name: "asan", python_version: "3.11", config: "--config=asan"}
          # yamllint disable-line rule:line-length
          - {name: "tsan", python_version: "3.11", config: "--config=tsan --strategy=TestRunner=local"}
    container:
      # note this image shall match the one defined in platform/linux:tf_linux_gpu
      image: rocm/tensorflow-build@sha256:7fcfbd36b7ac8f6b0805b37c4248e929e31cf5ee3af766c8409dd70d5ab65faa
      # Expose the AMD GPU devices (/dev/kfd, /dev/dri) to the container.
      options: >-
        -w ${{ github.workspace }}/jax_rocm_plugin
        --device=/dev/kfd
        --device=/dev/dri
        --group-add video
        --cap-add=SYS_PTRACE
        --security-opt seccomp=unconfined
        --shm-size 16G
    defaults:
      run:
        working-directory: jax_rocm_plugin
    steps:
      - name: Checkout plugin repo
        uses: actions/checkout@v4
      - name: Get RBE cluster keys
        # Referenced by the test steps so they only run once credentials exist.
        id: rbe_keys
        env:
          RBE_CI_CERT: ${{ secrets.RBE_CI_CERT }}
          RBE_CI_KEY: ${{ secrets.RBE_CI_KEY }}
        run: |
          # Overwrite rather than append: a leftover file must never end up
          # holding two concatenated certificates/keys.
          printf '%s\n' "$RBE_CI_CERT" > ci-cert.crt
          printf '%s\n' "$RBE_CI_KEY" > ci-cert.key
      - name: Run single-GPU unit tests
        # !cancelled() instead of always(): always() also fires after the run
        # is cancelled (e.g. by cancel-in-progress), which would keep GPU
        # runners busy with superseded runs. Also skip when setup failed.
        if: ${{ !cancelled() && steps.rbe_keys.outcome == 'success' }}
        run: |
          bash build/rocm/ci_run_jax_ut.sh \
            --config=rocm_sgpu \
            --config=rocm_rbe \
            --repo_env=HERMETIC_PYTHON_VERSION=${{ matrix.mode.python_version }} \
            ${{ matrix.mode.config }} \
            --curses=no \
            --color=yes \
            -- \
            @jax//tests:gpu_tests \
            @jax//tests:backend_independent_tests \
            $(build/rocm/targets_to_ignore.sh)
      - name: Run multi-GPU unit tests
        # Runs even if the single-GPU step failed, but not after cancellation
        # or when credential setup did not succeed.
        if: ${{ !cancelled() && steps.rbe_keys.outcome == 'success' }}
        run: |
          bash build/rocm/ci_run_jax_ut.sh \
            --config=rocm_mgpu \
            --config=rocm_rbe \
            --repo_env=HERMETIC_PYTHON_VERSION=${{ matrix.mode.python_version }} \
            ${{ matrix.mode.config }} \
            --curses=no \
            --color=yes \
            --strategy=TestRunner=local \
            -- \
            @jax//tests:gpu_tests \
            @jax//tests:backend_independent_tests \
            $(build/rocm/targets_to_ignore.sh)
      - name: Upload logs to artifact
        # always(): keep logs from failed or cancelled runs for debugging.
        if: always()
        uses: actions/upload-artifact@v4
        with:
          # matrix.mode.name already carries its own prefix ("py3.11", "asan",
          # "tsan"), so don't prepend another "py".
          name: logs-rbe-${{ matrix.mode.name }}
          path: jax_rocm_plugin/logs/