Skip to content

Commit db22d4f

Browse files
committed
Test multi-host runner
1 parent 17b4717 commit db22d4f

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

.github/workflows/test.yml

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
name: build
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
8+
permissions:
9+
contents: read
10+
actions: write # to cancel previous workflows
11+
12+
concurrency:
13+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
14+
cancel-in-progress: true
15+
16+
jobs:
17+
build-checkpoint:
18+
name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
19+
runs-on: linux-g2-16-l4-1gpu-x4
20+
container: python:3.11
21+
defaults:
22+
run:
23+
working-directory: checkpoint
24+
strategy:
25+
matrix:
26+
python-version: ["3.10", "3.11", "3.12"]
27+
jax-version: ["newest"]
28+
include:
29+
- python-version: "3.10"
30+
jax-version: "0.5.0" # keep in sync with minimum version in checkpoint/pyproject.toml
31+
# TODO(b/401258175) Re-enable once JAX nightlies are fixed.
32+
# - python-version: "3.13"
33+
# jax-version: "nightly"
34+
steps:
35+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
36+
- name: Set up Python ${{ matrix.python-version }}
37+
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
38+
with:
39+
python-version: ${{ matrix.python-version }}
40+
- name: Extract branch name
41+
shell: bash
42+
run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
43+
id: extract_branch
44+
- name: Install dependencies
45+
run: |
46+
sudo apt-get update
47+
sudo apt-get install -y protobuf-compiler
48+
49+
pip install tensorflow
50+
51+
protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto")
52+
53+
pip install -e .
54+
55+
pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
56+
57+
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
58+
pip install -U jax jaxlib
59+
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
60+
pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
61+
else
62+
pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
63+
fi
64+
- name: Test with pytest
65+
run: |
66+
pytest orbax/experimental/model/core/python/*_test.py
67+
68+
pytest orbax/experimental/model/tf2obm/*_test.py
69+
70+
pytest orbax/experimental/model/jax2obm/ \
71+
--ignore=orbax/experimental/model/jax2obm/main_lib_test.py \
72+
--ignore=orbax/experimental/model/jax2obm/sharding_test.py \
73+
--ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py

0 commit comments

Comments
 (0)