Adding basic tpu github tests #28
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Copyright 2025 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # This workflow will install Python dependencies, run tests and lint with a variety of Python versions | |
| # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python | |
name: TPU Tests

on:
  # Validate every PR targeting main.
  pull_request:
    branches: [ main ]
  # Allow manual runs from the Actions tab.
  workflow_dispatch:
  schedule:
    # Run the job every 6 hours
    - cron: '0 */6 * * *'

concurrency:
  # Dedup pull requests (canceling previous runs of the same workflow for the
  # same PR) and scheduled runs, but nothing else: manual dispatches fall
  # through to the unique run_id and are never cancelled.
  group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
  cancel-in-progress: true
jobs:
  # Runs the curated tunix test suite on a self-hosted v5e-8 TPU runner.
  tpu_unit_tests:
    runs-on: [linux-x86-ct5lp-224-8tpu]
    # Bound the job so a hung TPU test cannot hold the scarce self-hosted
    # runner for the 360-minute GitHub default.
    timeout-minutes: 120
    container:
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Install tunix dependencies
        run: |
          pip install -e .
          pip install pytest pytest-xdist
      - name: Verify TPU availability
        run: |
          python -c "
          import jax
          print(f'JAX version: {jax.__version__}')
          print(f'JAX devices: {jax.devices()}')
          print(f'TPU available: {len(jax.devices()) > 0}')
          "
      - name: Run tunix model tests
        run: |
          python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Run tunix generation tests (PASSED only)
        run: |
          python -m pytest tests/generate/utils_test.py -v --tb=short
      - name: Run tunix SFT tests (PASSED only)
        # All previously-passing node IDs are handed to a single pytest
        # invocation: pytest accepts multiple test IDs, so the interpreter and
        # the JAX/TPU runtime start once instead of once per test.
        run: |
          python -m pytest -v --tb=short \
            tests/sft/config_test.py \
            tests/sft/dpo/dpo_trainer_test.py::DpoTrainerTest::test_dpo_trainer_chosen_reject_equal_length \
            tests/sft/checkpoint_manager_test.py::CheckpointManagerTest::test_empty_root_directory \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_hooks \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_profiler \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_grad_accu \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume_and_grad_accu \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_without_grad_accu \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_custom_loss_fn \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_invalid_config \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_loss_fn_with_aux \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_reusing_trainer \
            tests/sft/peft_trainer_test.py::PeftTrainerTest::test_shard_input_on_already_sharded_input_is_noop \
            tests/sft/system_metrics_calculator_test.py
      - name: Run tunix distillation tests (PASSED only)
        # Same consolidation as the SFT step: one pytest invocation for all
        # previously-passing distillation node IDs.
        run: |
          python -m pytest -v --tb=short \
            tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_basic_training \
            tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_complex_strategy_training \
            tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_with_loss_fn_raises_exception \
            tests/distillation/feature_extraction/pooling_test.py \
            tests/distillation/feature_extraction/projection_test.py \
            tests/distillation/feature_extraction/sowed_module_test.py \
            tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_get_eval_loss \
            tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_neg \
            tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_over \
            tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_half \
            tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_one \
            tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_zero \
            tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_compute_loss_default_cosine_distance_alpha_one \
            tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_get_eval_loss \
            tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_neg \
            tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_over \
            tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_half \
            tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_one \
            tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_zero \
            tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_neg \
            tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_over \
            tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_half \
            tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_one \
            tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_zero \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_eval_loss \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_train_loss \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_neg \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_over \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_neg \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_zero \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_one \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_zero \
            tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_valid
      - name: Run tunix RL tests (basic)
        run: |
          python -m pytest tests/rl/common_test.py -v --tb=short -m "not cpu_only and not gpu_only"
      - name: Test tunix imports
        run: |
          python -c "
          import tunix
          import tunix.models
          import tunix.generate
          import tunix.sft
          import tunix.distillation
          import tunix.rl
          print('All tunix modules imported successfully')
          "
      - name: Test basic model loading
        run: |
          python -c "
          from tunix.models.llama3.params import Llama3Params
          print('Llama3 params loaded successfully')
          "
| notify_failure: | |
| name: Notify failed build | |
| needs: [tpu_unit_tests] | |
| if: ${{ always() }} | |
| runs-on: ubuntu-latest | |
| permissions: | |
| issues: write | |
| steps: | |
| - name: Check whether one of the jobs failed | |
| id: report_failure | |
| if: ${{ contains(needs.*.result, 'failure') && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }} | |
| uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0 | |
| with: | |
| github-token: ${{ secrets.GITHUB_TOKEN }} | |
| labels: "failed-build" | |
| - name: Reset consecutive success counter on failure | |
| if: ${{ steps.report_failure.outputs.issue-number != '' }} | |
| env: | |
| GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} | |
| GH_REPO: ${{ github.repository }} | |
| ISSUE_NUMBER: ${{ steps.report_failure.outputs.issue-number }} | |
| run: | | |
| echo "A failure occurred. Resetting success counter on issue #${ISSUE_NUMBER}." | |
| gh issue remove-label $ISSUE_NUMBER "success-run-1" "success-run-2" --repo $GH_REPO || echo "No success labels to remove." | |
| notify_success_and_close: | |
| name: Close issue after 3 successful builds | |
| if: ${{ success() && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }} | |
| needs: [tpu_unit_tests] | |
| runs-on: ubuntu-latest | |
| permissions: | |
| issues: write | |
| steps: | |
| - name: Find existing failure issue | |
| id: find_issue | |
| env: | |
| GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} | |
| GH_REPO: ${{ github.repository }} | |
| run: | | |
| ISSUE_NUMBER=$(gh issue list --label "failed-build" --state open --limit 1 --json number -q '.[0].number') | |
| if [[ -z "$ISSUE_NUMBER" ]]; then | |
| echo "No open build failure issue found. Nothing to do." | |
| echo "issue_number=" >> $GITHUB_OUTPUT | |
| else | |
| echo "Found open build failure issue: #${ISSUE_NUMBER}" | |
| echo "issue_number=${ISSUE_NUMBER}" >> $GITHUB_OUTPUT | |
| fi | |
| - name: Add success label or close issue | |
| if: steps.find_issue.outputs.issue_number != '' | |
| env: | |
| GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} | |
| GH_REPO: ${{ github.repository }} | |
| run: | | |
| ISSUE_NUMBER=${{ steps.find_issue.outputs.issue_number }} | |
| LABELS=$(gh issue view $ISSUE_NUMBER --json labels -q '.labels[].name') | |
| if echo "$LABELS" | grep -q "success-run-2"; then | |
| echo "Third consecutive success. Closing issue #${ISSUE_NUMBER}." | |
| gh issue comment $ISSUE_NUMBER --body "Build succeeded for the third consecutive time. Closing this issue automatically." | |
| gh issue close $ISSUE_NUMBER | |
| gh issue remove-label $ISSUE_NUMBER "failed-build" "success-run-2" --repo $GH_REPO | |
| elif echo "$LABELS" | grep -q "success-run-1"; then | |
| echo "Second consecutive success. Updating label on issue #${ISSUE_NUMBER}." | |
| gh issue comment $ISSUE_NUMBER --body "Build succeeded for the second time. One more successful run will close this issue." | |
| gh issue remove-label $ISSUE_NUMBER "success-run-1" --repo $GH_REPO | |
| gh issue add-label $ISSUE_NUMBER "success-run-2" --repo $GH_REPO | |
| else | |
| echo "First consecutive success since failure. Adding label to issue #${ISSUE_NUMBER}." | |
| gh issue comment $ISSUE_NUMBER --body "Build succeeded. This issue will be auto-closed after two more consecutive successful runs." | |
| gh issue add-label $ISSUE_NUMBER "success-run-1" --repo $GH_REPO | |
| fi |