Skip to content

Adding basic tpu github tests #16

Adding basic tpu github tests

Adding basic tpu github tests #16

Workflow file for this run

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow installs JAX and tunix with Python 3.12 on a TPU runner, runs the tunix unit and
# integration tests, and files (or auto-closes) a "failed-build" issue for non-PR, non-manual runs.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: TPU Tests

# `on` is a YAML 1.1 boolean to generic parsers; GitHub's loader treats it as a key.
# yamllint disable-line rule:truthy
on:
  pull_request:
    branches:
      - main
  workflow_dispatch:
  schedule:
    # Run the job every 6 hours.
    - cron: '0 */6 * * *'

# Least-privilege default for the GITHUB_TOKEN. Jobs that file or update
# issues request `issues: write` in their own `permissions` block.
permissions:
  contents: read

concurrency:
  # Dedup pull requests (canceling previous runs of this workflow for the same
  # PR) and scheduled runs, but nothing else: any other event falls through to
  # its unique run id, so it never cancels another run.
  group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
  cancel-in-progress: true
jobs:
  # Preliminary host checks.
  # NOTE(review): this job runs on a different runner (linux-x86-n2-32) than
  # the TPU jobs, so pruning Docker images here does not free space on the TPU
  # runner, and no later step uses gsutil -- confirm both steps are still wanted.
  prelim:
    runs-on: ["linux-x86-n2-32"]
    steps:
      - name: Test gsutil installation
        run: which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;}
      - name: Cleanup old docker images
        run: docker system prune --all --force

  tpu_unit_tests:
    needs: prelim
    runs-on: [linux-x86-ct5lp-224-8tpu]
    container:
      image: python:3.12-slim
      options: --privileged --cpus=2 --memory=4Gi
    env:
      # NOTE(review): JAX_PLATFORMS=cpu restricts JAX to the CPU backend even
      # though this job targets a TPU runner and installs jax[tpu]; confirm this
      # is intended. The "Verify TPU availability" step reports what JAX sees.
      TPU_ACCELERATOR_TYPE: ""
      JAX_PLATFORMS: "cpu"
      CUDA_VISIBLE_DEVICES: ""
      TF_CPP_MIN_LOG_LEVEL: "3"
    steps:
      # The python:3.12-slim image runs as root and ships neither sudo nor git.
      # Install git *before* checkout so actions/checkout makes a real clone
      # (honouring fetch-depth) instead of falling back to a REST API download.
      - name: Install system dependencies
        run: |
          apt-get update
          apt-get install -y --no-install-recommends git curl
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      # Python 3.12 is provided by the container image, so
      # actions/setup-python is not needed.
      - name: Show Python version
        run: python --version
      - name: Set up JAX for TPU
        run: |
          pip install --upgrade pip
          pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      - name: Install tunix dependencies
        run: |
          pip install -e .
          pip install pytest pytest-xdist
      # Report whether JAX actually sees a TPU device; a non-empty device list
      # alone says nothing, since the CPU backend always contributes a device.
      - name: Verify TPU availability
        run: |
          python -c "
          import jax
          devices = jax.devices()
          tpu_available = any(d.platform == 'tpu' for d in devices)
          print(f'JAX version: {jax.__version__}')
          print(f'JAX devices: {devices}')
          print(f'TPU available: {tpu_available}')
          "
      - name: Run tunix model tests
        run: |
          python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix generation tests
        run: |
          python -m pytest tests/generate/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix SFT tests
        run: |
          python -m pytest tests/sft/ -v --tb=short -m "not cpu_only and not gpu_only" -k "not test_common"
      - name: Run tunix distillation tests
        run: |
          python -m pytest tests/distillation/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix RL tests (basic)
        run: |
          python -m pytest tests/rl/common_test.py -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Test tunix imports
        run: |
          python -c "
          import tunix
          import tunix.models
          import tunix.generate
          import tunix.sft
          import tunix.distillation
          import tunix.rl
          print('All tunix modules imported successfully')
          "
      - name: Test basic model loading
        run: |
          python -c "
          from tunix.models.llama3.params import Llama3Params
          print('Llama3 params loaded successfully')
          "

  tpu_integration_tests:
    needs: prelim
    # Same single runner label as tpu_unit_tests.
    runs-on: [linux-x86-ct5lp-224-8tpu]
    container:
      image: python:3.12-slim
      options: --privileged --cpus=2 --memory=4Gi
    env:
      # NOTE(review): JAX_PLATFORMS=cpu restricts JAX to the CPU backend on a
      # TPU runner; confirm this is intended.
      TPU_ACCELERATOR_TYPE: ""
      JAX_PLATFORMS: "cpu"
      CUDA_VISIBLE_DEVICES: ""
      TF_CPP_MIN_LOG_LEVEL: "3"
    steps:
      # The python:3.12-slim image runs as root and ships neither sudo nor git;
      # install git before checkout so a real clone is made.
      - name: Install system dependencies
        run: |
          apt-get update
          apt-get install -y --no-install-recommends git curl
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      # Python 3.12 is provided by the container image.
      - name: Show Python version
        run: python --version
      - name: Set up JAX for TPU
        run: |
          pip install --upgrade pip
          pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      # pytest-timeout provides the --timeout option used below; without it
      # pytest aborts with "unrecognized arguments".
      - name: Install tunix dependencies
        run: |
          pip install -e .
          pip install pytest pytest-xdist pytest-timeout
      - name: Run integration tests
        run: |
          # Run more comprehensive tests that might take longer.
          # NOTE(review): pytest exits with status 5 (failing this step) if no
          # test carries the integration_test marker -- confirm such tests exist.
          python -m pytest tests/ -v --tb=short -m "integration_test" --timeout=300

  notify_failure:
    name: Notify failed build
    # prelim is listed so that its failure (which skips the TPU jobs, leaving
    # them "skipped" rather than "failure") is reported as well.
    needs: [prelim, tpu_unit_tests, tpu_integration_tests]
    if: ${{ always() }}
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Check whether one of the jobs failed
        id: report_failure
        if: ${{ contains(needs.*.result, 'failure') && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
        uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b  # v1.2.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          labels: "failed-build"
      - name: Reset consecutive success counter on failure
        if: ${{ steps.report_failure.outputs.issue-number != '' }}
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          ISSUE_NUMBER: ${{ steps.report_failure.outputs.issue-number }}
        run: |
          echo "A failure occurred. Resetting success counter on issue #${ISSUE_NUMBER}."
          # gh has no `issue remove-label` subcommand; labels are edited via `gh issue edit`.
          gh issue edit "$ISSUE_NUMBER" --remove-label "success-run-1" --remove-label "success-run-2" --repo "$GH_REPO" || echo "No success labels to remove."

  notify_success_and_close:
    name: Close issue after 3 successful builds
    if: ${{ success() && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
    needs: [prelim, tpu_unit_tests, tpu_integration_tests]
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Find existing failure issue
        id: find_issue
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
        run: |
          ISSUE_NUMBER=$(gh issue list --label "failed-build" --state open --limit 1 --json number -q '.[0].number')
          if [[ -z "$ISSUE_NUMBER" ]]; then
            echo "No open build failure issue found. Nothing to do."
            echo "issue_number=" >> "$GITHUB_OUTPUT"
          else
            echo "Found open build failure issue: #${ISSUE_NUMBER}"
            echo "issue_number=${ISSUE_NUMBER}" >> "$GITHUB_OUTPUT"
          fi
      # Advances the success-run-1 -> success-run-2 -> closed state machine.
      # NOTE(review): `gh issue edit --add-label` fails if the label does not
      # exist in the repository -- make sure success-run-1/success-run-2 exist.
      - name: Add success label or close issue
        if: steps.find_issue.outputs.issue_number != ''
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          ISSUE_NUMBER: ${{ steps.find_issue.outputs.issue_number }}
        run: |
          LABELS=$(gh issue view "$ISSUE_NUMBER" --json labels -q '.labels[].name')
          if echo "$LABELS" | grep -qxF "success-run-2"; then
            echo "Third consecutive success. Closing issue #${ISSUE_NUMBER}."
            gh issue comment "$ISSUE_NUMBER" --body "Build succeeded for the third consecutive time. Closing this issue automatically."
            gh issue close "$ISSUE_NUMBER"
            gh issue edit "$ISSUE_NUMBER" --remove-label "failed-build" --remove-label "success-run-2" --repo "$GH_REPO"
          elif echo "$LABELS" | grep -qxF "success-run-1"; then
            echo "Second consecutive success. Updating label on issue #${ISSUE_NUMBER}."
            gh issue comment "$ISSUE_NUMBER" --body "Build succeeded for the second time. One more successful run will close this issue."
            gh issue edit "$ISSUE_NUMBER" --remove-label "success-run-1" --add-label "success-run-2" --repo "$GH_REPO"
          else
            echo "First consecutive success since failure. Adding label to issue #${ISSUE_NUMBER}."
            gh issue comment "$ISSUE_NUMBER" --body "Build succeeded. This issue will be auto-closed after two more consecutive successful runs."
            gh issue edit "$ISSUE_NUMBER" --add-label "success-run-1" --repo "$GH_REPO"
          fi