Skip to content

Commit dbd92b1

Browse files
committed
Some initial tests
Signed-off-by: Vladimir Suvorov <suvorovv@google.com>
1 parent 74aa4c8 commit dbd92b1

File tree

3 files changed

+105
-225
lines changed

3 files changed

+105
-225
lines changed

.github/workflows/build_upload_internal.yml

Lines changed: 0 additions & 62 deletions
This file was deleted.

.github/workflows/run_tests_internal.yml

Lines changed: 0 additions & 115 deletions
This file was deleted.

.github/workflows/tpu-tests.yml

Lines changed: 105 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -39,59 +39,116 @@ jobs:
3939
- name: Cleanup old docker images
4040
run: docker system prune --all --force
4141

42-
tpu_image:
43-
needs: prelim
44-
uses: ./.github/workflows/build_upload_internal.yml
45-
with:
46-
device_type: tpu
47-
device_name: v5litepod-8
48-
cloud_runner: linux-x86-ct5lp-224-8tpu
49-
build_mode: jax_ai_image
50-
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
51-
5242
tpu_unit_tests:
53-
needs: tpu_image
54-
uses: ./.github/workflows/run_tests_internal.yml
55-
with:
56-
device_type: tpu
57-
device_name: v5litepod-8
58-
cloud_runner: linux-x86-ct5lp-224-8tpu
59-
image_type: tpu
60-
pytest_marker: 'not cpu_only and not gpu_only and not integration_test'
61-
xla_python_client_mem_fraction: 0.75
62-
tf_force_gpu_allow_growth: false
63-
container_resource_option: "--privileged"
64-
is_scheduled_run: ${{ github.event_name == 'schedule' }}
43+
needs: prelim
44+
runs-on: [self-hosted, linux-x86-ct5lp-224-8tpu]
45+
steps:
46+
- name: Checkout code
47+
uses: actions/checkout@v4
48+
with:
49+
fetch-depth: 0
50+
51+
- name: Set up Python
52+
uses: actions/setup-python@v4
53+
with:
54+
python-version: '3.12'
55+
56+
- name: Install system dependencies
57+
run: |
58+
sudo apt-get update
59+
sudo apt-get install -y git curl
60+
61+
- name: Set up JAX for TPU
62+
run: |
63+
pip install --upgrade pip
64+
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html  # quoted so the shell never glob-expands [tpu]
65+
66+
- name: Install tunix dependencies
67+
run: |
68+
pip install -e .
69+
pip install pytest pytest-xdist
70+
71+
- name: Verify TPU availability
72+
run: |
73+
python -c "
74+
import jax
75+
print(f'JAX version: {jax.__version__}')
76+
print(f'JAX devices: {jax.devices()}')
77+
assert jax.default_backend() == 'tpu', f'TPU not available; default backend is {jax.default_backend()}'  # fail the step on CPU fallback
78+
"
79+
80+
- name: Run tunix model tests
81+
run: |
82+
python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
83+
84+
- name: Run tunix generation tests
85+
run: |
86+
python -m pytest tests/generate/ -v --tb=short -m "not cpu_only and not gpu_only"
87+
88+
- name: Run tunix SFT tests
89+
run: |
90+
python -m pytest tests/sft/ -v --tb=short -m "not cpu_only and not gpu_only" -k "not test_common"
91+
92+
- name: Run tunix distillation tests
93+
run: |
94+
python -m pytest tests/distillation/ -v --tb=short -m "not cpu_only and not gpu_only"
95+
96+
- name: Run tunix RL tests (basic)
97+
run: |
98+
python -m pytest tests/rl/common_test.py -v --tb=short -m "not cpu_only and not gpu_only"
99+
100+
- name: Test tunix imports
101+
run: |
102+
python -c "
103+
import tunix
104+
import tunix.models
105+
import tunix.generate
106+
import tunix.sft
107+
import tunix.distillation
108+
import tunix.rl
109+
print('All tunix modules imported successfully')
110+
"
111+
112+
- name: Test basic model loading
113+
run: |
114+
python -c "
115+
from tunix.models.llama3.params import Llama3Params
116+
print('Llama3 params loaded successfully')
117+
"
65118
66119
tpu_integration_tests:
67-
needs: tpu_image
68-
uses: ./.github/workflows/run_tests_internal.yml
69-
with:
70-
device_type: tpu
71-
device_name: v5litepod-8
72-
cloud_runner: linux-x86-ct5lp-224-8tpu
73-
pytest_marker: 'not cpu_only and not gpu_only and integration_test'
74-
xla_python_client_mem_fraction: 0.75
75-
tf_force_gpu_allow_growth: false
76-
container_resource_option: "--privileged"
77-
is_scheduled_run: ${{ github.event_name == 'schedule' }}
78-
79-
clean_up:
80-
if: ${{ always() }}
81-
needs: [tpu_unit_tests, tpu_integration_tests]
82-
name: "Clean up"
83-
runs-on: ["self-hosted"]
84-
permissions:
85-
contents: read
86-
issues: write
120+
needs: prelim
121+
runs-on: [self-hosted, linux-x86-ct5lp-224-8tpu]
87122
steps:
88-
- name: Authenticate gcloud
123+
- name: Checkout code
124+
uses: actions/checkout@v4
125+
with:
126+
fetch-depth: 0
127+
128+
- name: Set up Python
129+
uses: actions/setup-python@v4
130+
with:
131+
python-version: '3.12'
132+
133+
- name: Install system dependencies
134+
run: |
135+
sudo apt-get update
136+
sudo apt-get install -y git curl
137+
138+
- name: Set up JAX for TPU
139+
run: |
140+
pip install --upgrade pip
141+
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html  # quoted so the shell never glob-expands [tpu]
142+
143+
- name: Install tunix dependencies
144+
run: |
145+
pip install -e .
146+
pip install pytest pytest-xdist pytest-timeout  # pytest-timeout provides the --timeout flag used by the integration test step
147+
148+
- name: Run integration tests
89149
run: |
90-
# configure registries as root and as runner
91-
gcloud auth configure-docker --quiet
92-
gcloud auth configure-docker us-docker.pkg.dev --quiet
93-
- name: Delete the tpu image
94-
run: gcloud container images delete "gcr.io/tpu-prod-env-multipod/tunix_${{ github.run_id }}:tpu" --force-delete-tags --quiet
150+
# Run more comprehensive tests that might take longer
151+
python -m pytest tests/ -v --tb=short -m "not cpu_only and not gpu_only and integration_test" --timeout=300
95152
96153
notify_failure:
97154
name: Notify failed build

0 commit comments

Comments
 (0)