TPU simple tests #19
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Copyright 2025 Google LLC | ||
|
Check failure on line 1 in .github/workflows/tpu-tests.yml
|
||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
# Installs tunix inside a JAX TPU container and runs its model, generation,
# SFT, distillation and basic RL test suites on a v5e-8 TPU runner. Follow-up
# jobs track scheduled-build failures in a "failed-build" GitHub issue.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: TPU Tests
on:
  pull_request:
    branches:
      - main
  workflow_dispatch:
  schedule:
    # Fires at minute 0 of every sixth hour (00:00, 06:00, 12:00, 18:00 UTC).
    - cron: '0 */6 * * *'
concurrency:
  # A newer run cancels an older one only within the same group:
  #   - pull requests: one group per PR number;
  #   - scheduled runs: one shared group;
  #   - anything else (e.g. workflow_dispatch): a unique group keyed by run_id,
  #     so those runs never cancel each other.
  # The folded scalar joins these lines with single spaces into one expression.
  group: >-
    ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number)
    || github.event_name == 'schedule' && format('{0}-schedule', github.workflow)
    || github.run_id }}
  cancel-in-progress: true
jobs:
  # Runs the tunix test suites on a TPU v5e-8 runner inside the JAX stable-stack
  # TPU container.
  tpu_unit_tests:
    runs-on: [linux-x86-ct5lp-224-8tpu]
    container:
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Install tunix dependencies
        run: |
          pip install -e .
          pip install pytest pytest-xdist
      - name: Verify TPU availability
        run: |
          python -c "
          import jax
          print(f'JAX version: {jax.__version__}')
          print(f'JAX devices: {jax.devices()}')
          print(f'TPU available: {len(jax.devices()) > 0}')
          "
      - name: Run tunix model tests
        run: |
          python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix generation tests
        run: |
          python -m pytest tests/generate/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix SFT tests
        run: |
          python -m pytest tests/sft/ -v --tb=short -m "not cpu_only and not gpu_only" -k "not test_common"
      - name: Run tunix distillation tests
        run: |
          python -m pytest tests/distillation/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix RL tests (basic)
        run: |
          python -m pytest tests/rl/common_test.py -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Test tunix imports
        run: |
          python -c "
          import tunix
          import tunix.models
          import tunix.generate
          import tunix.sft
          import tunix.distillation
          import tunix.rl
          print('All tunix modules imported successfully')
          "
      - name: Test basic model loading
        run: |
          python -c "
          from tunix.models.llama3.params import Llama3Params
          print('Llama3 params loaded successfully')
          "

  # On a failed run that is neither a PR nor a manual dispatch, opens or updates
  # a "failed-build" issue, then clears any success-run-N labels so the
  # consecutive-success count starts over.
  notify_failure:
    name: Notify failed build
    # `needs` may only name jobs defined in this file; listing a non-existent
    # job (previously `tpu_integration_tests`) makes GitHub reject the workflow.
    needs: [tpu_unit_tests]
    if: ${{ always() }}
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Check whether one of the jobs failed
        id: report_failure
        if: ${{ contains(needs.*.result, 'failure') && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
        uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          labels: "failed-build"
      - name: Reset consecutive success counter on failure
        if: ${{ steps.report_failure.outputs.issue-number != '' }}
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          # gh reads GH_REPO, so no checkout is needed in this job.
          GH_REPO: ${{ github.repository }}
          ISSUE_NUMBER: ${{ steps.report_failure.outputs.issue-number }}
        run: |
          echo "A failure occurred. Resetting success counter on issue #${ISSUE_NUMBER}."
          # gh has no `issue remove-label` subcommand; labels are changed via `gh issue edit`.
          gh issue edit "$ISSUE_NUMBER" --remove-label "success-run-1,success-run-2" || echo "No success labels to remove."

  # After a successful run that is neither a PR nor a manual dispatch, advances
  # the success-run-1 -> success-run-2 label counter on the open "failed-build"
  # issue, and closes the issue on the third consecutive success.
  notify_success_and_close:
    name: Close issue after 3 successful builds
    if: ${{ success() && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
    # Must reference only jobs that exist in this workflow.
    needs: [tpu_unit_tests]
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Find existing failure issue
        id: find_issue
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
        run: |
          # `// empty` makes the lookup print nothing (not "null") when no issue matches.
          ISSUE_NUMBER=$(gh issue list --label "failed-build" --state open --limit 1 --json number -q '.[0].number // empty')
          if [[ -z "$ISSUE_NUMBER" ]]; then
            echo "No open build failure issue found. Nothing to do."
            echo "issue_number=" >> "$GITHUB_OUTPUT"
          else
            echo "Found open build failure issue: #${ISSUE_NUMBER}"
            echo "issue_number=${ISSUE_NUMBER}" >> "$GITHUB_OUTPUT"
          fi
      - name: Add success label or close issue
        if: ${{ steps.find_issue.outputs.issue_number != '' }}
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          # Passed via env rather than spliced into the script with ${{ }}.
          ISSUE_NUMBER: ${{ steps.find_issue.outputs.issue_number }}
        run: |
          LABELS=$(gh issue view "$ISSUE_NUMBER" --json labels -q '.labels[].name')
          # -x: match whole label names only, one label per line.
          if echo "$LABELS" | grep -qx "success-run-2"; then
            echo "Third consecutive success. Closing issue #${ISSUE_NUMBER}."
            gh issue comment "$ISSUE_NUMBER" --body "Build succeeded for the third consecutive time. Closing this issue automatically."
            gh issue close "$ISSUE_NUMBER"
            gh issue edit "$ISSUE_NUMBER" --remove-label "failed-build,success-run-2"
          elif echo "$LABELS" | grep -qx "success-run-1"; then
            echo "Second consecutive success. Updating label on issue #${ISSUE_NUMBER}."
            gh issue comment "$ISSUE_NUMBER" --body "Build succeeded for the second time. One more successful run will close this issue."
            gh issue edit "$ISSUE_NUMBER" --remove-label "success-run-1" --add-label "success-run-2"
          else
            echo "First consecutive success since failure. Adding label to issue #${ISSUE_NUMBER}."
            gh issue comment "$ISSUE_NUMBER" --body "Build succeeded. This issue will be auto-closed after two more consecutive successful runs."
            gh issue edit "$ISSUE_NUMBER" --add-label "success-run-1"
          fi