|
1 | 1 | name: Run Pytest CPU tests |
2 | 2 |
|
3 | 3 | on: |
4 | | - # pull_request: |
5 | | - # branches: |
6 | | - # - main |
| 4 | + pull_request: |
| 5 | + branches: |
| 6 | + - main |
7 | 7 | workflow_dispatch: |
8 | 8 | inputs: |
9 | 9 | halt-for-connection: |
|
13 | 13 | default: 'no' |
14 | 14 | options: |
15 | 15 | - 'yes' |
16 | | - - 'no' |
17 | 16 |
|
18 | 17 | jobs: |
19 | | - build: |
| 18 | + build_jaxlib_artifact: |
| 19 | + name: "Build the jaxlib aritfact using latest XLA" |
| 20 | + uses: ./.github/workflows/build_artifacts.yml |
| 21 | + with: |
| 22 | + wheel_list: "jaxlib" |
| 23 | + python_list: "3.10" |
| 24 | + platform_list: "linux_x86,linux_arm64" |
| 25 | + clone_main_xla: 1 |
| 26 | + upload_artifacts: true |
| 27 | + upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' |
| 28 | + |
| 29 | + run_pytest: |
| 30 | + name: "Run CPU tests with Pytest" |
| 31 | + needs: build_jaxlib_artifact |
20 | 32 | continue-on-error: true |
21 | 33 | defaults: |
22 | 34 | run: |
23 | 35 | # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. |
24 | 36 | shell: bash |
25 | 37 | strategy: |
26 | 38 | matrix: |
27 | | - runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"] |
| 39 | + runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] |
28 | 40 | python: ["3.10"] |
29 | 41 |
|
30 | 42 | runs-on: ${{ matrix.runner }} |
31 | 43 | container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || |
32 | | - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || |
33 | | - (contains(matrix.runner, 'windows-x86') && null) }} |
| 44 | + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} |
34 | 45 |
|
35 | 46 | env: |
36 | | - JAXCI_CLONE_MAIN_XLA: 1 |
37 | 47 | JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} |
38 | 48 |
|
39 | 49 | steps: |
|
43 | 53 | uses: google-ml-infra/actions/ci_connection@main |
44 | 54 | with: |
45 | 55 | halt-dispatch-input: ${{ inputs.halt-for-connection }} |
46 | | - - name: Build jaxlib |
47 | | - run: ./ci/build_artifacts.sh "jaxlib" |
| 56 | + - name: Set Platform |
| 57 | + run: | |
| 58 | + os=$(uname -s | awk '{print tolower($0)}') |
| 59 | + arch=$(uname -m) |
| 60 | + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV |
| 61 | + - name: Download the artifacts built in the "build_artifacts" job |
| 62 | + run: >- |
| 63 | + mkdir -p $(pwd)/dist && |
| 64 | + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ |
48 | 65 | - name: Install pytest |
49 | 66 | env: |
50 | 67 | JAXCI_PYTHON: python${{ matrix.python }} |
|
0 commit comments