Skip to content

Test multi-host runner #1

Test multi-host runner

Test multi-host runner #1

Workflow file for this run

# CI workflow: runs on pull requests targeting main.
name: build

# `on` is a YAML 1.1 boolean for generic parsers; GitHub's loader reads it as the key.
on:  # yamllint disable-line rule:truthy
  pull_request:
    branches: [main]

# Token scopes granted to every job in this workflow.
permissions:
  contents: read
  actions: write  # to cancel previous workflows

# One active run per workflow + PR head ref (or ref); a newer push cancels the older run.
concurrency:
  group: "${{ github.workflow }}-${{ github.head_ref || github.ref }}"
  cancel-in-progress: true
jobs:
  # NOTE(review): the job id/name say "checkpoint", but every step below builds
  # and tests the orbax/experimental/model package. The id and display name are
  # left unchanged so existing status-check names keep matching; consider
  # renaming to build-model.
  build-checkpoint:
    name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
    runs-on: linux-g2-16-l4-1gpu-x4
    # Match the container interpreter to the matrix entry instead of a fixed
    # 3.11. Official python images run as root and ship without sudo.
    container: python:${{ matrix.python-version }}
    defaults:
      run:
        # The protoc inputs and pytest targets (orbax/experimental/model/...)
        # live under the repo's model/ package, not checkpoint/.
        working-directory: model
        # The install script uses bash-only `[[ ... ]]`; pin bash explicitly
        # (GitHub then also runs it with -eo pipefail).
        shell: bash
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12"]
        jax-version: ["newest"]
        include:
          # Keep in sync with the minimum jax version in pyproject.toml.
          # NOTE(review): originally pointed at checkpoint/pyproject.toml, but
          # this job tests model/ — confirm which file is authoritative.
          - python-version: "3.10"
            jax-version: "0.5.0"
          # TODO(b/401258175) Re-enable once JAX nightlies are fixed.
          # - python-version: "3.13"
          #   jax-version: "nightly"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      # NOTE(review): the container already provides the matrix Python; this
      # step may be redundant inside a non-Ubuntu container — confirm it works.
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b  # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}
      # Exposes steps.extract_branch.outputs.branch (PR head ref, else the ref
      # with the refs/heads/ prefix stripped).
      - name: Extract branch name
        id: extract_branch
        shell: bash
        run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> "$GITHUB_OUTPUT"
      - name: Install dependencies
        run: |
          # No sudo: the python:* container runs as root and has no sudo binary.
          apt-get update
          apt-get install -y protobuf-compiler
          pip install tensorflow
          # Generate Python bindings for every .proto in the model package.
          protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto")
          pip install -e .
          # Quoted so bash does not treat [testing] as a glob pattern.
          pip install ".[testing]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          # NOTE(review): these are the default (CPU) jax wheels; the runner's
          # GPU is unused unless a CUDA plugin is installed — confirm intent.
          if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
            pip install -U jax jaxlib
          elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
            pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
          else
            pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
          fi
      - name: Test with pytest
        run: |
          pytest orbax/experimental/model/core/python/*_test.py
          pytest orbax/experimental/model/tf2obm/*_test.py
          pytest orbax/experimental/model/jax2obm/ \
            --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \
            --ignore=orbax/experimental/model/jax2obm/sharding_test.py \
            --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py