Commit 0654441

Some initial tests
Fix for linux-x86-ct5lp-224-8tpu
TPU simple tests
Refactored

Signed-off-by: Vladimir Suvorov <suvorovv@google.com>
1 parent 688c395 commit 0654441

2 files changed, +169 -1 lines changed

.github/workflows/build_and_test_tunix.yml

Lines changed: 5 additions & 1 deletion
@@ -45,9 +45,13 @@ jobs:
     needs: build_tunix_package
     uses: ./.github/workflows/run_tests_against_package.yml
 
+  tunix_tpu_tests:
+    needs: build_tunix_package
+    uses: ./.github/workflows/tpu-tests.yml
+
   notify_failure:
     name: Notify failed build # creates an issue or modifies last open existing issue for failed build
-    needs: [tunix_cpu_unit_tests]
+    needs: [tunix_cpu_unit_tests, tunix_tpu_tests]
     if: ${{ always() }}
     runs-on: ubuntu-latest
     permissions:
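
Note on the hunk above: the job dependencies after this change can be sketched in Python. This is an illustration only. The job names come from the diff; the assumption that the context lines at 45-46 belong to `tunix_cpu_unit_tests` is inferred from the old `needs:` list, and `job_order` is a hypothetical helper that does not exist in the repository.

```python
from graphlib import TopologicalSorter

# Job -> jobs listed under its `needs:` key after this commit.
# Assumption: tunix_cpu_unit_tests is the job that needs build_tunix_package
# in the unchanged context lines of the hunk.
NEEDS = {
    "build_tunix_package": [],
    "tunix_cpu_unit_tests": ["build_tunix_package"],
    "tunix_tpu_tests": ["build_tunix_package"],
    "notify_failure": ["tunix_cpu_unit_tests", "tunix_tpu_tests"],
}


def job_order(needs):
    """Return one valid execution order in which every job follows its needs."""
    return list(TopologicalSorter(needs).static_order())


if __name__ == "__main__":
    print(job_order(NEEDS))
```

The point of the sketch is that `notify_failure` now waits for both test jobs, so a TPU failure can reach the notification step the same way a CPU failure already did.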

.github/workflows/tpu-tests.yml

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: TPU Tests
+
+on:
+  workflow_call:
+
+concurrency:
+  # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else
+  group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  tpu_unit_tests:
+    runs-on: [linux-x86-ct5lp-224-8tpu]
+    container:
+      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
+      options: --privileged
+    env:
+      CLOUD_TPU_ACCELERATOR: v5e-8
+      JAX_PLATFORMS: tpu
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Install tunix dependencies
+        run: |
+          pip install -e .
+          pip install pytest pytest-xdist
+
+      - name: Verify TPU availability
+        run: |
+          python -c "
+          import jax
+          print(f'JAX version: {jax.__version__}')
+          print(f'JAX devices: {jax.devices()}')
+
+          # Check if we have TPU devices specifically
+          devices = jax.devices()
+          has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
+          print(f'TPU available: {has_tpu}')
+
+          if not has_tpu:
+              print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
+              exit(1)
+          else:
+              print(f'SUCCESS: Found {len(devices)} TPU device(s)')
+          "
+
+      - name: Run tunix model tests
+        run: |
+          python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
+
+      - name: Run tunix generation tests (PASSED only)
+        run: |
+          python -m pytest tests/generate/utils_test.py -v --tb=short
+
+      - name: Run tunix SFT tests (PASSED only)
+        run: |
+          # Config tests that passed
+          python -m pytest tests/sft/config_test.py -v --tb=short
+
+          # Checkpoint manager test that passed
+          python -m pytest tests/sft/checkpoint_manager_test.py::CheckpointManagerTest::test_empty_root_directory -v --tb=short
+
+          # PEFT trainer tests that passed
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_hooks -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_basic_training_with_profiler -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_grad_accu -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_with_resume_and_grad_accu -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_checkpointing_without_grad_accu -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_custom_loss_fn -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_invalid_config -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_loss_fn_with_aux -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_reusing_trainer -v --tb=short
+          python -m pytest tests/sft/peft_trainer_test.py::PeftTrainerTest::test_shard_input_on_already_sharded_input_is_noop -v --tb=short
+
+          # System metrics tests that passed
+          python -m pytest tests/sft/system_metrics_calculator_test.py -v --tb=short
+
+      - name: Run tunix distillation tests (PASSED only)
+        run: |
+          # Distillation trainer tests that passed
+          python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_basic_training -v --tb=short
+          python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_complex_strategy_training -v --tb=short
+          python -m pytest tests/distillation/distillation_trainer_test.py::DistillationTrainerTest::test_with_loss_fn_raises_exception -v --tb=short
+
+          # Feature extraction pooling tests that passed (all 12 tests)
+          python -m pytest tests/distillation/feature_extraction/pooling_test.py -v --tb=short
+
+          # Feature extraction projection tests that passed (all 6 tests)
+          python -m pytest tests/distillation/feature_extraction/projection_test.py -v --tb=short
+
+          # Feature extraction sowed module tests that passed (all 5 tests)
+          python -m pytest tests/distillation/feature_extraction/sowed_module_test.py -v --tb=short
+
+          # Attention strategy tests that passed
+          python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_get_eval_loss -v --tb=short
+          python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_neg -v --tb=short
+          python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_invalid_alpha_over -v --tb=short
+          python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_half -v --tb=short
+          python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_one -v --tb=short
+          python -m pytest tests/distillation/strategies/attention_test.py::AttentionTransferStrategyTest::test_init_valid_alpha_zero -v --tb=short
+
+          # Feature pooling strategy tests that passed
+          python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_compute_loss_default_cosine_distance_alpha_one -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_get_eval_loss -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_neg -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_invalid_alpha_over -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_half -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_one -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_pooling_test.py::FeaturePoolingStrategyTest::test_init_valid_alpha_zero -v --tb=short
+
+          # Feature projection strategy tests that passed
+          python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_neg -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_invalid_alpha_over -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_half -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_one -v --tb=short
+          python -m pytest tests/distillation/strategies/feature_projection_test.py::FeatureProjectionStrategyTest::test_init_valid_alpha_zero -v --tb=short
+
+          # Logit strategy tests that passed
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_eval_loss -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_get_train_loss -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_neg -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_alpha_over -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_neg -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_invalid_temp_zero -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_one -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_alpha_zero -v --tb=short
+          python -m pytest tests/distillation/strategies/logit_test.py::LogitStrategyTest::test_init_valid_valid -v --tb=short
+
+      - name: Run tunix RL tests (PASSED only)
+        run: |
+          # RL common tests that passed
+          python -m pytest tests/rl/common_test.py::CommonTest::test_build_positions_from_mask -v --tb=short
+          python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_kl -v --tb=short
+          python -m pytest tests/rl/common_test.py::CommonTest::test_compute_kl_divergence_mse_kl -v --tb=short
+          python -m pytest tests/rl/common_test.py::CommonTest::test_make_causal_attn_mask -v --tb=short
+          python -m pytest tests/rl/common_test.py::CommonTest::test_make_completion_mask -v --tb=short
+          python -m pytest tests/rl/common_test.py::CommonTest::test_pad_to_length -v --tb=short
+
+
+      - name: Test basic model loading
+        run: |
+          python -m pytest tests/models/llama3/params_test.py -v --tb=short
+
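
Note on the "Verify TPU availability" step: the inline `python -c` check could also live in a small script, which makes the predicate testable without TPU hardware. A minimal sketch, assuming a hypothetical file such as `check_tpu.py` (not part of this commit); `all_tpu` and `main` are illustrative names, and only `jax.devices()`, `device.platform` and `jax.__version__` are taken from the workflow itself.

```python
def all_tpu(platforms):
    """True only if there is at least one device and every one is a TPU."""
    return len(platforms) > 0 and all(p == "tpu" for p in platforms)


def main():
    # Imported lazily so all_tpu() can be unit-tested on machines without JAX.
    import jax

    devices = jax.devices()
    platforms = [d.platform for d in devices]
    print(f"JAX version: {jax.__version__}")
    print(f"JAX devices: {devices}")
    if not all_tpu(platforms):
        print("ERROR: expected only TPU devices but got:", platforms)
        return 1
    print(f"SUCCESS: Found {len(platforms)} TPU device(s)")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```

The logic mirrors the condition in the workflow (`len(devices) > 0 and all(...)`). With `JAX_PLATFORMS: tpu` set, JAX may already raise during backend initialization when no TPU is present, in which case the step would fail before this check is reached.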

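Note on the "(PASSED only)" steps: pytest accepts several node IDs in one invocation, so the long runs of single-test commands could be collapsed into one call per step. The sketch below is one possible shape, not what the commit does. `NODE_IDS` holds only three IDs copied from the workflow as a sample, and `build_pytest_command` and `run` are hypothetical helpers.

```python
import subprocess
import sys

# Sample of node IDs copied from the workflow; the full allow-list would go here.
NODE_IDS = [
    "tests/sft/config_test.py",
    "tests/sft/checkpoint_manager_test.py::CheckpointManagerTest::test_empty_root_directory",
    "tests/rl/common_test.py::CommonTest::test_pad_to_length",
]


def build_pytest_command(node_ids, extra_args=("-v", "--tb=short")):
    """Build a single `python -m pytest` command covering all given node IDs."""
    return [sys.executable, "-m", "pytest", *node_ids, *extra_args]


def run(node_ids):
    """Run pytest once for all node IDs and return its exit code."""
    return subprocess.run(build_pytest_command(node_ids)).returncode


if __name__ == "__main__":
    raise SystemExit(run(NODE_IDS))
```

A single invocation pays interpreter and JAX start-up cost once and reports every failing test in one summary. The per-line form in the workflow most likely stops at the first failing command, assuming the runner's default shell exits on error.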