Skip to content

Adding basic tpu github tests #26

Adding basic tpu github tests

Adding basic tpu github tests #26

Workflow file for this run

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow installs the project's Python dependencies and runs the TPU test suite on a self-hosted TPU runner.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: TPU Tests

# Triggers: every PR to main, manual dispatch, and a 6-hourly schedule.
on:
  pull_request:
    branches: [main]
  workflow_dispatch:
  schedule:
    # Run the job every 6 hours
    - cron: '0 */6 * * *'

concurrency:
  # Dedup pull requests (canceling previous runs of the same workflow for same PR),
  # and scheduled runs — but nothing else (manual runs get a unique group via run_id).
  group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
  cancel-in-progress: true
jobs:
  tpu_unit_tests:
    # Self-hosted TPU runner; the container image provides the JAX stable stack.
    runs-on: [linux-x86-ct5lp-224-8tpu]
    container:
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
      # Privileged mode is required for the container to access the TPU devices.
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          # Full history checkout (fetch-depth 0).
          fetch-depth: 0
      - name: Install tunix dependencies
        run: |
          pip install -e .
          pip install pytest pytest-xdist
      - name: Verify TPU availability
        # Fail fast with a readable device listing if JAX cannot see the TPU.
        run: |
          python -c "
          import jax
          print(f'JAX version: {jax.__version__}')
          print(f'JAX devices: {jax.devices()}')
          print(f'TPU available: {len(jax.devices()) > 0}')
          "
      - name: Run tunix model tests
        run: |
          python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix generation tests (PASSED only)
        run: |
          python -m pytest tests/generate/utils_test.py -v --tb=short
      - name: Run tunix SFT tests (PASSED only)
        # Individually selected test IDs: only the cases known to pass on TPU
        # are run here (see step name); the rest of each file is skipped.
        run: |
          # Config tests that passed
          python -m pytest tests/sft/config_test.py -v --tb=short
          # DPO tests that passed
          python -m pytest tests/sft/dpo/dpo_trainer_test.py::DpoTrainerTest::test_dpo_loss_fn -v --tb=short
          python -m pytest tests/sft/dpo/dpo_trainer_test.py::DpoTrainerTest::test_dpo_trainer_chosen_reject_equal_length -v --tb=short
          # Checkpoint manager test that passed
          python -m pytest tests/sft/checkpoint_manager_test.py::CheckpointManagerTest::test_empty_root_directory -v --tb=short
          # PEFT trainer tests that passed
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_hooks -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_profiler -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_grad_accu -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume_and_grad_accu -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_without_grad_accu -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_custom_loss_fn -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_invalid_config -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_loss_fn_with_aux -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_reusing_trainer -v --tb=short
          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_shard_input_on_already_sharded_input_is_noop -v --tb=short
          # System metrics tests that passed
          python -m pytest tests/sft/system_metrics_calculator_test.py -v --tb=short
      - name: Run tunix distillation tests
        run: |
          python -m pytest tests/distillation/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix RL tests (basic)
        run: |
          python -m pytest tests/rl/common_test.py -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Test tunix imports
        run: |
          python -c "
          import tunix
          import tunix.models
          import tunix.generate
          import tunix.sft
          import tunix.distillation
          import tunix.rl
          print('All tunix modules imported successfully')
          "
      - name: Test basic model loading
        run: |
          python -c "
          from tunix.models.llama3.params import Llama3Params
          print('Llama3 params loaded successfully')
          "
notify_failure:
name: Notify failed build
needs: [tpu_unit_tests]
if: ${{ always() }}
runs-on: ubuntu-latest
permissions:
issues: write
steps:
- name: Check whether one of the jobs failed
id: report_failure
if: ${{ contains(needs.*.result, 'failure') && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
labels: "failed-build"
- name: Reset consecutive success counter on failure
if: ${{ steps.report_failure.outputs.issue-number != '' }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
ISSUE_NUMBER: ${{ steps.report_failure.outputs.issue-number }}
run: |
echo "A failure occurred. Resetting success counter on issue #${ISSUE_NUMBER}."
gh issue remove-label $ISSUE_NUMBER "success-run-1" "success-run-2" --repo $GH_REPO || echo "No success labels to remove."
notify_success_and_close:
name: Close issue after 3 successful builds
if: ${{ success() && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
needs: [tpu_unit_tests]
runs-on: ubuntu-latest
permissions:
issues: write
steps:
- name: Find existing failure issue
id: find_issue
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
run: |
ISSUE_NUMBER=$(gh issue list --label "failed-build" --state open --limit 1 --json number -q '.[0].number')
if [[ -z "$ISSUE_NUMBER" ]]; then
echo "No open build failure issue found. Nothing to do."
echo "issue_number=" >> $GITHUB_OUTPUT
else
echo "Found open build failure issue: #${ISSUE_NUMBER}"
echo "issue_number=${ISSUE_NUMBER}" >> $GITHUB_OUTPUT
fi
- name: Add success label or close issue
if: steps.find_issue.outputs.issue_number != ''
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_REPO: ${{ github.repository }}
run: |
ISSUE_NUMBER=${{ steps.find_issue.outputs.issue_number }}
LABELS=$(gh issue view $ISSUE_NUMBER --json labels -q '.labels[].name')
if echo "$LABELS" | grep -q "success-run-2"; then
echo "Third consecutive success. Closing issue #${ISSUE_NUMBER}."
gh issue comment $ISSUE_NUMBER --body "Build succeeded for the third consecutive time. Closing this issue automatically."
gh issue close $ISSUE_NUMBER
gh issue remove-label $ISSUE_NUMBER "failed-build" "success-run-2" --repo $GH_REPO
elif echo "$LABELS" | grep -q "success-run-1"; then
echo "Second consecutive success. Updating label on issue #${ISSUE_NUMBER}."
gh issue comment $ISSUE_NUMBER --body "Build succeeded for the second time. One more successful run will close this issue."
gh issue remove-label $ISSUE_NUMBER "success-run-1" --repo $GH_REPO
gh issue add-label $ISSUE_NUMBER "success-run-2" --repo $GH_REPO
else
echo "First consecutive success since failure. Adding label to issue #${ISSUE_NUMBER}."
gh issue comment $ISSUE_NUMBER --body "Build succeeded. This issue will be auto-closed after two more consecutive successful runs."
gh issue add-label $ISSUE_NUMBER "success-run-1" --repo $GH_REPO
fi