|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions |
| 16 | +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python |
| 17 | + |
| 18 | +name: TPU Tests |
| 19 | + |
| 20 | +on: |
| 21 | + workflow_call: |
| 22 | + |
| 23 | +concurrency: |
| 24 | + # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else |
| 25 | + group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }} |
| 26 | + cancel-in-progress: true |
| 27 | + |
| 28 | +jobs: |
| 29 | + tpu_unit_tests: |
| 30 | + runs-on: [linux-x86-ct5lp-224-8tpu] |
| 31 | + container: |
| 32 | + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest |
| 33 | + options: --privileged |
| 34 | + env: |
| 35 | + CLOUD_TPU_ACCELERATOR: v5e-8 |
| 36 | + JAX_PLATFORMS: tpu |
| 37 | + steps: |
| 38 | + - name: Checkout code |
| 39 | + uses: actions/checkout@v4 |
| 40 | + with: |
| 41 | + fetch-depth: 0 |
| 42 | + |
| 43 | + - name: Install tunix dependencies |
| 44 | + run: | |
| 45 | + pip install -e . |
| 46 | + pip install pytest pytest-xdist |
| 47 | + |
| 48 | + - name: Verify TPU availability |
| 49 | + run: | |
| 50 | + python -c " |
| 51 | + import jax |
| 52 | + print(f'JAX version: {jax.__version__}') |
| 53 | + print(f'JAX devices: {jax.devices()}') |
| 54 | + |
| 55 | + # Check if we have TPU devices specifically |
| 56 | + devices = jax.devices() |
| 57 | + has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices) |
| 58 | + print(f'TPU available: {has_tpu}') |
| 59 | + |
| 60 | + if not has_tpu: |
| 61 | + print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices]) |
| 62 | + exit(1) |
| 63 | + else: |
| 64 | + print(f'SUCCESS: Found {len(devices)} TPU device(s)') |
| 65 | + " |
| 66 | + |
| 67 | + - name: Run tunix model tests |
| 68 | + run: | |
| 69 | + python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only" |
| 70 | + |
| 71 | + - name: Run tunix generation tests (PASSED only) |
| 72 | + run: | |
| 73 | + python -m pytest tests/generate/utils_test.py -v --tb=short |
| 74 | +
|
| 75 | + - name: Run tunix SFT tests (PASSED only) |
| 76 | + run: | |
| 77 | + # Config tests that passed |
| 78 | + python -m pytest tests/sft/config_test.py -v --tb=short |
| 79 | + |
| 80 | + # Checkpoint manager test that passed |
| 81 | + python -m pytest tests/sft/checkpoint_manager_test.py::CheckpointManagerTest::test_empty_root_directory -v --tb=short |
| 82 | + |
| 83 | + # PEFT trainer tests that passed |
| 84 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_hooks -v --tb=short |
| 85 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_profiler -v --tb=short |
| 86 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_grad_accu -v --tb=short |
| 87 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume -v --tb=short |
| 88 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume_and_grad_accu -v --tb=short |
| 89 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_without_grad_accu -v --tb=short |
| 90 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_custom_loss_fn -v --tb=short |
| 91 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_invalid_config -v --tb=short |
| 92 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_loss_fn_with_aux -v --tb=short |
| 93 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_reusing_trainer -v --tb=short |
| 94 | + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_shard_input_on_already_sharded_input_is_noop -v --tb=short |
| 95 | + |
| 96 | + # System metrics tests that passed |
| 97 | + python -m pytest tests/sft/system_metrics_calculator_test.py -v --tb=short |
| 98 | + |
| 99 | + - name: Run tunix distillation tests (PASSED only) |
| 100 | + run: | |
| 101 | + # Distillation trainer tests that passed |
| 102 | + python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_basic_training -v --tb=short |
| 103 | + python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_complex_strategy_training -v --tb=short |
| 104 | + python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_with_loss_fn_raises_exception -v --tb=short |
| 105 | + |
| 106 | + # Feature extraction pooling tests that passed (all 12 tests) |
| 107 | + python -m pytest tests/distillation/feature_extraction/pooling_test.py -v --tb=short |
| 108 | + |
| 109 | + # Feature extraction projection tests that passed (all 6 tests) |
| 110 | + python -m pytest tests/distillation/feature_extraction/projection_test.py -v --tb=short |
| 111 | + |
| 112 | + # Feature extraction sowed module tests that passed (all 5 tests) |
| 113 | + python -m pytest tests/distillation/feature_extraction/sowed_module_test.py -v --tb=short |
| 114 | + |
| 115 | + # Attention strategy tests that passed |
| 116 | + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_get_eval_loss -v --tb=short |
| 117 | + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_neg -v --tb=short |
| 118 | + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_over -v --tb=short |
| 119 | + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_half -v --tb=short |
| 120 | + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_one -v --tb=short |
| 121 | + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_zero -v --tb=short |
| 122 | + |
| 123 | + # Feature pooling strategy tests that passed |
| 124 | + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_compute_loss_default_cosine_distance_alpha_one -v --tb=short |
| 125 | + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_get_eval_loss -v --tb=short |
| 126 | + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_neg -v --tb=short |
| 127 | + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_over -v --tb=short |
| 128 | + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_half -v --tb=short |
| 129 | + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_one -v --tb=short |
| 130 | + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_zero -v --tb=short |
| 131 | + |
| 132 | + # Feature projection strategy tests that passed |
| 133 | + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_neg -v --tb=short |
| 134 | + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_over -v --tb=short |
| 135 | + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_half -v --tb=short |
| 136 | + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_one -v --tb=short |
| 137 | + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_zero -v --tb=short |
| 138 | + |
| 139 | + # Logit strategy tests that passed |
| 140 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_eval_loss -v --tb=short |
| 141 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_train_loss -v --tb=short |
| 142 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_neg -v --tb=short |
| 143 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_over -v --tb=short |
| 144 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_neg -v --tb=short |
| 145 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_zero -v --tb=short |
| 146 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_one -v --tb=short |
| 147 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_zero -v --tb=short |
| 148 | + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_valid -v --tb=short |
| 149 | + |
| 150 | + - name: Run tunix RL tests (PASSED only) |
| 151 | + run: | |
| 152 | + # RL common tests that passed |
| 153 | + python -m pytest tests/rl/common_test.py::CommonTest::test_build_positions_from_mask -v --tb=short |
| 154 | + python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_kl -v --tb=short |
| 155 | + python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_mse_kl -v --tb=short |
| 156 | + python -m pytest tests/rl/common_test.py::CommonTest::test_make_causal_attn_mask -v --tb=short |
| 157 | + python -m pytest tests/rl/common_test.py::CommonTest::test_make_completion_mask -v --tb=short |
| 158 | + python -m pytest tests/rl/common_test.py::CommonTest::test_pad_to_length -v --tb=short |
| 159 | +
|
| 160 | + |
| 161 | + - name: Test basic model loading |
| 162 | + run: | |
| 163 | + python -m pytest tests/models/llama3/params_test.py -v --tb=short |
| 164 | +
|
0 commit comments