|
| 1 | +name: build |
| 2 | + |
| 3 | +on: |
| 4 | + pull_request: |
| 5 | + branches: |
| 6 | + - main |
| 7 | + |
| 8 | +permissions: |
| 9 | + contents: read |
| 10 | + actions: write # to cancel previous workflows |
| 11 | + |
| 12 | +concurrency: |
| 13 | + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} |
| 14 | + cancel-in-progress: true |
| 15 | + |
| 16 | +jobs: |
| 17 | + build-checkpoint: |
| 18 | + name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" |
| 19 | + runs-on: linux-g2-16-l4-1gpu-x4 |
| 20 | + container: python:3.11 |
| 21 | + defaults: |
| 22 | + run: |
| 23 | + working-directory: checkpoint |
| 24 | + strategy: |
| 25 | + matrix: |
| 26 | + python-version: ["3.10", "3.11", "3.12"] |
| 27 | + jax-version: ["newest"] |
| 28 | + include: |
| 29 | + - python-version: "3.10" |
| 30 | + jax-version: "0.5.0" # keep in sync with minimum version in checkpoint/pyproject.toml |
| 31 | + # TODO(b/401258175) Re-enable once JAX nightlies are fixed. |
| 32 | + # - python-version: "3.13" |
| 33 | + # jax-version: "nightly" |
| 34 | + steps: |
| 35 | + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 |
| 36 | + - name: Set up Python ${{ matrix.python-version }} |
| 37 | + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 |
| 38 | + with: |
| 39 | + python-version: ${{ matrix.python-version }} |
| 40 | + - name: Extract branch name |
| 41 | + shell: bash |
| 42 | + run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT |
| 43 | + id: extract_branch |
| 44 | + - name: Install dependencies |
| 45 | + run: | |
| 46 | + sudo apt-get update |
| 47 | + sudo apt-get install -y protobuf-compiler |
| 48 | +
|
| 49 | + pip install tensorflow |
| 50 | +
|
| 51 | + protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") |
| 52 | +
|
| 53 | + pip install -e . |
| 54 | +
|
| 55 | + pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
| 56 | +
|
| 57 | + if [[ "${{ matrix.jax-version }}" == "newest" ]]; then |
| 58 | + pip install -U jax jaxlib |
| 59 | + elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then |
| 60 | + pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ |
| 61 | + else |
| 62 | + pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" |
| 63 | + fi |
| 64 | + - name: Test with pytest |
| 65 | + run: | |
| 66 | + pytest orbax/experimental/model/core/python/*_test.py |
| 67 | +
|
| 68 | + pytest orbax/experimental/model/tf2obm/*_test.py |
| 69 | +
|
| 70 | + pytest orbax/experimental/model/jax2obm/ \ |
| 71 | + --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ |
| 72 | + --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ |
| 73 | + --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py |
0 commit comments