From 0654441ccf6e984d3675068e2a230a774d2b1a3a Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Thu, 18 Sep 2025 22:41:57 +0400 Subject: [PATCH] Some initial tests Signed-off-by: Vladimir Suvorov Some initial tests Signed-off-by: Vladimir Suvorov Some initial tests Signed-off-by: Vladimir Suvorov Some initial tests Signed-off-by: Vladimir Suvorov Some initial tests Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov Fix for linux-x86-ct5lp-224-8tpu Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov TPU simple tests Signed-off-by: Vladimir Suvorov 
TPU simple tests Signed-off-by: Vladimir Suvorov Refactored Signed-off-by: Vladimir Suvorov Refactored Signed-off-by: Vladimir Suvorov --- .github/workflows/build_and_test_tunix.yml | 6 +- .github/workflows/tpu-tests.yml | 164 +++++++++++++++++++++ 2 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/tpu-tests.yml diff --git a/.github/workflows/build_and_test_tunix.yml b/.github/workflows/build_and_test_tunix.yml index ef4b522a4..dffd13803 100644 --- a/.github/workflows/build_and_test_tunix.yml +++ b/.github/workflows/build_and_test_tunix.yml @@ -45,9 +45,13 @@ jobs: needs: build_tunix_package uses: ./.github/workflows/run_tests_against_package.yml + tunix_tpu_tests: + needs: build_tunix_package + uses: ./.github/workflows/tpu-tests.yml + notify_failure: name: Notify failed build # creates an issue or modifies last open existing issue for failed build - needs: [tunix_cpu_unit_tests] + needs: [tunix_cpu_unit_tests, tunix_tpu_tests] if: ${{ always() }} runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml new file mode 100644 index 000000000..83c09d3ff --- /dev/null +++ b/.github/workflows/tpu-tests.yml @@ -0,0 +1,164 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Reusable workflow (triggered via workflow_call) that installs tunix from source and runs its pytest suites on a TPU runner. +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: TPU Tests + +on: + workflow_call: + +concurrency: + # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else + group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }} + cancel-in-progress: true + +jobs: + tpu_unit_tests: + runs-on: [linux-x86-ct5lp-224-8tpu] + container: + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest + options: --privileged + env: + CLOUD_TPU_ACCELERATOR: v5e-8 + JAX_PLATFORMS: tpu + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install tunix dependencies + run: | + pip install -e . + pip install pytest pytest-xdist + + - name: Verify TPU availability + run: | + python -c " + import jax + print(f'JAX version: {jax.__version__}') + print(f'JAX devices: {jax.devices()}') + + # Check if we have TPU devices specifically + devices = jax.devices() + has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices) + print(f'TPU available: {has_tpu}') + + if not has_tpu: + print('ERROR: No TPU devices found! 
Expected TPU devices but got:', [device.platform for device in devices]) + exit(1) + else: + print(f'SUCCESS: Found {len(devices)} TPU device(s)') + " + + - name: Run tunix model tests + run: | + python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only" + + - name: Run tunix generation tests (PASSED only) + run: | + python -m pytest tests/generate/utils_test.py -v --tb=short + + - name: Run tunix SFT tests (PASSED only) + run: | + # Config tests that passed + python -m pytest tests/sft/config_test.py -v --tb=short + + # Checkpoint manager test that passed + python -m pytest tests/sft/checkpoint_manager_test.py::CheckpointManagerTest::test_empty_root_directory -v --tb=short + + # PEFT trainer tests that passed + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_hooks -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_profiler -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_grad_accu -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume_and_grad_accu -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_without_grad_accu -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_custom_loss_fn -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_invalid_config -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_loss_fn_with_aux -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_reusing_trainer -v --tb=short + python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_shard_input_on_already_sharded_input_is_noop -v --tb=short + + # System 
metrics tests that passed + python -m pytest tests/sft/system_metrics_calculator_test.py -v --tb=short + + - name: Run tunix distillation tests (PASSED only) + run: | + # Distillation trainer tests that passed + python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_basic_training -v --tb=short + python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_complex_strategy_training -v --tb=short + python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_with_loss_fn_raises_exception -v --tb=short + + # Feature extraction pooling tests that passed (all 12 tests) + python -m pytest tests/distillation/feature_extraction/pooling_test.py -v --tb=short + + # Feature extraction projection tests that passed (all 6 tests) + python -m pytest tests/distillation/feature_extraction/projection_test.py -v --tb=short + + # Feature extraction sowed module tests that passed (all 5 tests) + python -m pytest tests/distillation/feature_extraction/sowed_module_test.py -v --tb=short + + # Attention strategy tests that passed + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_get_eval_loss -v --tb=short + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_neg -v --tb=short + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_over -v --tb=short + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_half -v --tb=short + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_one -v --tb=short + python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_zero -v --tb=short + + # Feature pooling strategy 
tests that passed + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_compute_loss_default_cosine_distance_alpha_one -v --tb=short + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_get_eval_loss -v --tb=short + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_neg -v --tb=short + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_over -v --tb=short + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_half -v --tb=short + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_one -v --tb=short + python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_zero -v --tb=short + + # Feature projection strategy tests that passed + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_neg -v --tb=short + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_over -v --tb=short + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_half -v --tb=short + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_one -v --tb=short + python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_zero -v --tb=short + + # Logit strategy tests that passed + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_eval_loss -v --tb=short + 
python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_train_loss -v --tb=short + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_neg -v --tb=short + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_over -v --tb=short + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_neg -v --tb=short + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_zero -v --tb=short + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_one -v --tb=short + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_zero -v --tb=short + python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_valid -v --tb=short + + - name: Run tunix RL tests (PASSED only) + run: | + # RL common tests that passed + python -m pytest tests/rl/common_test.py::CommonTest::test_build_positions_from_mask -v --tb=short + python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_kl -v --tb=short + python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_mse_kl -v --tb=short + python -m pytest tests/rl/common_test.py::CommonTest::test_make_causal_attn_mask -v --tb=short + python -m pytest tests/rl/common_test.py::CommonTest::test_make_completion_mask -v --tb=short + python -m pytest tests/rl/common_test.py::CommonTest::test_pad_to_length -v --tb=short + + + - name: Test basic model loading + run: | + python -m pytest tests/models/llama3/params_test.py -v --tb=short +