Test CI scripts and workflows #290
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Runs the CPU Pytest job. Triggered by pull requests against main, manually | |
| # (workflow_dispatch), or from another workflow (workflow_call). | |
| name: CI - Pytest CPU | |
| on: | |
| pull_request: | |
| branches: | |
| - main | |
| workflow_dispatch: | |
| inputs: | |
| # Passed to the "Wait For Connection" step of the run-tests job. | |
| halt-for-connection: | |
| description: 'Should this workflow run wait for a remote connection?' | |
| type: choice | |
| required: true | |
| default: 'no' | |
| options: | |
| - 'yes' | |
| - 'no' | |
| workflow_call: | |
| inputs: | |
| # The inputs below each declare a default, so they are marked optional: with | |
| # `required: true` a caller must always pass a value, which leaves the default unused. | |
| # Selects the runner (runs-on) and, through its name, the container image of the job. | |
| runner: | |
| description: "Which runner should the workflow run on?" | |
| type: string | |
| required: false | |
| default: "linux-x86-n2-16" | |
| # Kept as a quoted string so that a version such as 3.10 is not read as the number 3.1. | |
| python: | |
| description: "Which python version should the artifact be built for?" | |
| type: string | |
| required: false | |
| default: "3.12" | |
| # Exported to the job as JAXCI_ENABLE_X64. | |
| enable-x64: | |
| description: "Should x64 mode be enabled?" | |
| type: string | |
| required: false | |
| default: "0" | |
| # Base GCS path; the download steps append a jaxlib wheel wildcard to it. | |
| gcs_download_uri: | |
| description: "GCS location URI from where the artifacts should be downloaded" | |
| required: false | |
| default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' | |
| type: string | |
| jobs: | |
| # Single job of this workflow; it runs on the runner named by the `runner` input. | |
| run-tests: | |
| defaults: | |
| run: | |
| # Explicitly set the shell to bash to override the default Windows environment, i.e., cmd. | |
| shell: bash | |
| runs-on: ${{ inputs.runner }} | |
| # The container image is chosen from the runner name: ml-build for linux-x86-n2 runners and | |
| # ml-build-arm64 for linux-x86-t2a runners. The windows-x86 branch evaluates to null, | |
| # presumably so that Windows jobs run without a container (NOTE(review): confirm). | |
| container: ${{ (contains(inputs.runner, 'linux-x86-n2') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | |
| (contains(inputs.runner, 'linux-x86-t2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || | |
| (contains(inputs.runner, 'windows-x86') && null) }} | |
| name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" | |
| # Job-level environment derived from the workflow inputs. | |
| env: | |
| # Python version exactly as passed in (e.g. 3.12); the "Set env vars" step strips the dot from it. | |
| JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" | |
| # Interpreter command name (e.g. python3.12); used by the "Install dependencies" step. | |
| JAXCI_PYTHON: "python${{ inputs.python }}" | |
| # Value of the enable-x64 input; not read by any step visible in this file, | |
| # presumably consumed by ci/run_pytest_cpu.sh (NOTE(review): confirm). | |
| JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} | |
| steps: | |
| # Check out the repository; later steps reference build/requirements.in and ci/run_pytest_cpu.sh. | |
| - uses: actions/checkout@v4 | |
| # Optional halt so that someone can connect to the runner. The step receives the | |
| # workflow_dispatch `halt-for-connection` input; what it does with it is defined by the | |
| # ci_connection action, which is not visible in this file. | |
| - name: Wait For Connection | |
| uses: google-ml-infra/actions/ci_connection@main | |
| with: | |
| halt-dispatch-input: ${{ inputs.halt-for-connection }} | |
| - name: Set env vars for use in artifact download URL | |
| run: | | |
| # Lower-cased kernel name and machine architecture, as reported by uname. | |
| platform=$(uname -s | awk '{print tolower($0)}') | |
| machine=$(uname -m) | |
| # Windows x86: uname reports msys_nt / x86_64, which are replaced by win / amd64. | |
| if [[ $platform == *msys_nt* && $machine == *x86_64* ]]; then | |
| platform="win" | |
| machine="amd64" | |
| fi | |
| # Python version with the dot removed, | |
| # e.g. JAXCI_HERMETIC_PYTHON_VERSION=3.10 gives 310. | |
| py_tag=${JAXCI_HERMETIC_PYTHON_VERSION//./} | |
| # Publish the three values to later steps through GITHUB_ENV. | |
| { | |
| echo "OS=${platform}" | |
| echo "ARCH=${machine}" | |
| echo "PYTHON_MAJOR_MINOR=${py_tag}" | |
| } >> $GITHUB_ENV | |
| # Non-Windows runners: copy the jaxlib wheel whose file name matches PYTHON_MAJOR_MINOR, OS | |
| # and ARCH from the location given by the gcs_download_uri input into ./dist. | |
| - name: Download the jaxlib wheel from GCS (non-Windows runs) | |
| if: ${{ !contains(inputs.runner, 'windows-x86') }} | |
| run: | | |
| # The trailing && continues the command on the next line, so gsutil only runs if mkdir succeeds. | |
| mkdir -p $(pwd)/dist && | |
| gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ | |
| # Windows runners: copy the jaxlib wheel whose file name matches PYTHON_MAJOR_MINOR, OS and | |
| # ARCH from the location given by the gcs_download_uri input into dist\. | |
| # This step runs under cmd, which expands environment variables written as %NAME%; the | |
| # bash form ${NAME} would be passed to gsutil unexpanded and match no wheel. | |
| # The folded scalar (>-) joins the two lines into one `mkdir dist && gsutil ...` command line. | |
| - name: Download the jaxlib wheel from GCS (Windows runs) | |
| if: ${{ contains(inputs.runner, 'windows-x86') }} | |
| shell: cmd | |
| run: >- | |
| mkdir dist && | |
| gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ | |
| # Install the packages listed in build/requirements.in with the interpreter named by JAXCI_PYTHON. | |
| - name: Install dependencies | |
| run: ${JAXCI_PYTHON} -m pip install --requirement build/requirements.in | |
| # Hand over to the repository's CPU test script. | |
| - name: Run Pytest CPU tests | |
| run: ./ci/run_pytest_cpu.sh | |