Skip to content

Jax-Metal CI

Jax-Metal CI #533

# JAX-Metal plugin CI
#
# Runs the jax-metal plugin test job on a daily schedule, on manual dispatch,
# and for pull requests against main that touch this workflow file.
name: Jax-Metal CI
on:
  schedule:
    - cron: "0 12 * * *"  # Daily at 12:00 UTC
  workflow_dispatch:  # allows triggering the workflow run manually
  pull_request:  # Automatically trigger on pull requests affecting this file
    branches:
      - main
    paths:
      - '**workflows/metal_plugin_ci.yml'
# Runs are grouped by workflow name plus PR head ref (falling back to the
# pushed/scheduled ref); a newer run in the same group cancels the older one.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
# Empty mapping: no token permissions are granted at the workflow level.
permissions: {}
jobs:
  # Builds a fresh virtualenv, installs jax from the checked-out sources
  # together with the jax-metal plugin, and runs tests/lax_metal_test.py.
  # The job is run once per entry of the jaxlib-version matrix.
  jax-metal-plugin-test:
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        jaxlib-version: ["pypi_latest", "nightly"]
    name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
    runs-on: [self-hosted, macOS, ARM64]
    steps:
      - name: Get repo
        # Action pinned to a full commit SHA; the trailing comment records the
        # human-readable tag that SHA corresponds to.
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8  # v5.0.0
        with:
          # Check out into ./jax so later steps can `cd jax`.
          path: jax
          # Do not leave the checkout token configured in the local git repo.
          persist-credentials: false
      - name: Setup build and test environment
        env:
          # Hand the matrix value to the script as an environment variable
          # instead of expanding the expression inside the shell source.
          JAXLIB_VERSION: ${{ matrix.jaxlib-version }}
        run: |
          # Always start from an empty venv so nothing from a previous run
          # is reused. Paths are quoted to survive whitespace in the workspace.
          rm -rf "${GITHUB_WORKSPACE}/jax-metal-venv"
          python3 -m venv "${GITHUB_WORKSPACE}/jax-metal-venv"
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          pip install uv~=0.5.30
          uv pip install -U pip numpy wheel absl-py pytest
          # Only the "nightly" variant installs a pre-release jaxlib from the
          # JAX package index. For "pypi_latest" no jaxlib is installed here;
          # presumably it is resolved as a dependency of the install below
          # (TODO confirm).
          if [[ "${JAXLIB_VERSION}" == "nightly" ]]; then
            uv pip install --pre jaxlib \
              -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
          fi
          cd jax
          # Install jax from this checkout plus the jax-metal plugin package.
          uv pip install . jax-metal
      - name: Run test
        run: |
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          # NOTE(review): presumably lets jax-metal load against a jaxlib
          # newer than the one it was built for; confirm in jax-metal docs.
          export ENABLE_PJRT_COMPATIBILITY=1
          cd jax
          pytest tests/lax_metal_test.py