Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/build_and_test_tunix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@ jobs:
needs: build_tunix_package
uses: ./.github/workflows/run_tests_against_package.yml

tunix_tpu_tests:
needs: build_tunix_package
uses: ./.github/workflows/tpu-tests.yml

notify_failure:
name: Notify failed build # creates an issue or modifies last open existing issue for failed build
needs: [tunix_cpu_unit_tests]
needs: [tunix_cpu_unit_tests, tunix_tpu_tests]
if: ${{ always() }}
runs-on: ubuntu-latest
permissions:
Expand Down
164 changes: 164 additions & 0 deletions .github/workflows/tpu-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: TPU Tests

on:
workflow_call:

concurrency:
# Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else
group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
cancel-in-progress: true

jobs:
tpu_unit_tests:
runs-on: [linux-x86-ct5lp-224-8tpu]
container:
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
options: --privileged
env:
CLOUD_TPU_ACCELERATOR: v5e-8
JAX_PLATFORMS: tpu
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Install tunix dependencies
run: |
pip install -e .
pip install pytest pytest-xdist

- name: Verify TPU availability
run: |
python -c "
import jax
print(f'JAX version: {jax.__version__}')
print(f'JAX devices: {jax.devices()}')

# Check if we have TPU devices specifically
devices = jax.devices()
has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
print(f'TPU available: {has_tpu}')

if not has_tpu:
print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
exit(1)
else:
print(f'SUCCESS: Found {len(devices)} TPU device(s)')
"

- name: Run tunix model tests
run: |
python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"

- name: Run tunix generation tests (PASSED only)
run: |
python -m pytest tests/generate/utils_test.py -v --tb=short

- name: Run tunix SFT tests (PASSED only)
run: |
# Config tests that passed
python -m pytest tests/sft/config_test.py -v --tb=short

# Checkpoint manager test that passed
python -m pytest tests/sft/checkpoint_manager_test.py::CheckpointManagerTest::test_empty_root_directory -v --tb=short

# PEFT trainer tests that passed
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_hooks -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_profiler -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_grad_accu -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume_and_grad_accu -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_without_grad_accu -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_custom_loss_fn -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_invalid_config -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_loss_fn_with_aux -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_reusing_trainer -v --tb=short
python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_shard_input_on_already_sharded_input_is_noop -v --tb=short

# System metrics tests that passed
python -m pytest tests/sft/system_metrics_calculator_test.py -v --tb=short

- name: Run tunix distillation tests (PASSED only)
run: |
# Distillation trainer tests that passed
python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_basic_training -v --tb=short
python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_complex_strategy_training -v --tb=short
python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_with_loss_fn_raises_exception -v --tb=short

# Feature extraction pooling tests that passed (all 12 tests)
python -m pytest tests/distillation/feature_extraction/pooling_test.py -v --tb=short

# Feature extraction projection tests that passed (all 6 tests)
python -m pytest tests/distillation/feature_extraction/projection_test.py -v --tb=short

# Feature extraction sowed module tests that passed (all 5 tests)
python -m pytest tests/distillation/feature_extraction/sowed_module_test.py -v --tb=short

# Attention strategy tests that passed
python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_get_eval_loss -v --tb=short
python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_neg -v --tb=short
python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_over -v --tb=short
python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_half -v --tb=short
python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_one -v --tb=short
python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_zero -v --tb=short

# Feature pooling strategy tests that passed
python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_compute_loss_default_cosine_distance_alpha_one -v --tb=short
python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_get_eval_loss -v --tb=short
python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_neg -v --tb=short
python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_over -v --tb=short
python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_half -v --tb=short
python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_one -v --tb=short
python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_zero -v --tb=short

# Feature projection strategy tests that passed
python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_neg -v --tb=short
python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_over -v --tb=short
python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_half -v --tb=short
python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_one -v --tb=short
python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_zero -v --tb=short

# Logit strategy tests that passed
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_eval_loss -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_train_loss -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_neg -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_over -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_neg -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_zero -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_one -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_zero -v --tb=short
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_valid -v --tb=short

- name: Run tunix RL tests (PASSED only)
run: |
# RL common tests that passed
python -m pytest tests/rl/common_test.py::CommonTest::test_build_positions_from_mask -v --tb=short
python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_kl -v --tb=short
python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_mse_kl -v --tb=short
python -m pytest tests/rl/common_test.py::CommonTest::test_make_causal_attn_mask -v --tb=short
python -m pytest tests/rl/common_test.py::CommonTest::test_make_completion_mask -v --tb=short
python -m pytest tests/rl/common_test.py::CommonTest::test_pad_to_length -v --tb=short


- name: Test basic model loading
run: |
python -m pytest tests/models/llama3/params_test.py -v --tb=short