@@ -39,59 +39,116 @@ jobs:
3939 - name : Cleanup old docker images
4040 run : docker system prune --all --force
4141
42- tpu_image :
43- needs : prelim
44- uses : ./.github/workflows/build_upload_internal.yml
45- with :
46- device_type : tpu
47- device_name : v5litepod-8
48- cloud_runner : linux-x86-ct5lp-224-8tpu
49- build_mode : jax_ai_image
50- base_image : us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
51-
5242 tpu_unit_tests :
53- needs : tpu_image
54- uses : ./.github/workflows/run_tests_internal.yml
55- with :
56- device_type : tpu
57- device_name : v5litepod-8
58- cloud_runner : linux-x86-ct5lp-224-8tpu
59- image_type : tpu
60- pytest_marker : ' not cpu_only and not gpu_only and not integration_test'
61- xla_python_client_mem_fraction : 0.75
62- tf_force_gpu_allow_growth : false
63- container_resource_option : " --privileged"
64- is_scheduled_run : ${{ github.event_name == 'schedule' }}
43+ needs : prelim
44+ runs-on : [self-hosted, linux-x86-ct5lp-224-8tpu]
45+ steps :
46+ - name : Checkout code
47+ uses : actions/checkout@v4
48+ with :
49+ fetch-depth : 0
50+
51+ - name : Set up Python
52+ uses : actions/setup-python@v4
53+ with :
54+ python-version : ' 3.12'
55+
56+ - name : Install system dependencies
57+ run : |
58+ sudo apt-get update
59+ sudo apt-get install -y git curl
60+
61+ - name : Set up JAX for TPU
62+ run : |
63+ pip install --upgrade pip
64+ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
65+
66+ - name : Install tunix dependencies
67+ run : |
68+ pip install -e .
69+ pip install pytest pytest-xdist
70+
71+ - name : Verify TPU availability
72+ run : |
73+ python -c "
74+ import jax
75+ print(f'JAX version: {jax.__version__}')
76+ print(f'JAX devices: {jax.devices()}')
77+ print(f'TPU available: {len(jax.devices()) > 0}')
78+ "
79+
80+ - name : Run tunix model tests
81+ run : |
82+ python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
83+
84+ - name : Run tunix generation tests
85+ run : |
86+ python -m pytest tests/generate/ -v --tb=short -m "not cpu_only and not gpu_only"
87+
88+ - name : Run tunix SFT tests
89+ run : |
90+ python -m pytest tests/sft/ -v --tb=short -m "not cpu_only and not gpu_only" -k "not test_common"
91+
92+ - name : Run tunix distillation tests
93+ run : |
94+ python -m pytest tests/distillation/ -v --tb=short -m "not cpu_only and not gpu_only"
95+
96+ - name : Run tunix RL tests (basic)
97+ run : |
98+ python -m pytest tests/rl/common_test.py -v --tb=short -m "not cpu_only and not gpu_only"
99+
100+ - name : Test tunix imports
101+ run : |
102+ python -c "
103+ import tunix
104+ import tunix.models
105+ import tunix.generate
106+ import tunix.sft
107+ import tunix.distillation
108+ import tunix.rl
109+ print('All tunix modules imported successfully')
110+ "
111+
112+ - name : Test basic model loading
113+ run : |
114+ python -c "
115+ from tunix.models.llama3.params import Llama3Params
116+ print('Llama3 params loaded successfully')
117+ "
65118
66119 tpu_integration_tests :
67- needs : tpu_image
68- uses : ./.github/workflows/run_tests_internal.yml
69- with :
70- device_type : tpu
71- device_name : v5litepod-8
72- cloud_runner : linux-x86-ct5lp-224-8tpu
73- pytest_marker : ' not cpu_only and not gpu_only and integration_test'
74- xla_python_client_mem_fraction : 0.75
75- tf_force_gpu_allow_growth : false
76- container_resource_option : " --privileged"
77- is_scheduled_run : ${{ github.event_name == 'schedule' }}
78-
79- clean_up :
80- if : ${{ always() }}
81- needs : [tpu_unit_tests, tpu_integration_tests]
82- name : " Clean up"
83- runs-on : ["self-hosted"]
84- permissions :
85- contents : read
86- issues : write
120+ needs : prelim
121+ runs-on : [self-hosted, linux-x86-ct5lp-224-8tpu]
87122 steps :
88- - name : Authenticate gcloud
123+ - name : Checkout code
124+ uses : actions/checkout@v4
125+ with :
126+ fetch-depth : 0
127+
128+ - name : Set up Python
129+ uses : actions/setup-python@v4
130+ with :
131+ python-version : ' 3.12'
132+
133+ - name : Install system dependencies
134+ run : |
135+ sudo apt-get update
136+ sudo apt-get install -y git curl
137+
138+ - name : Set up JAX for TPU
139+ run : |
140+ pip install --upgrade pip
141+ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
142+
143+ - name : Install tunix dependencies
144+ run : |
145+ pip install -e .
146+ pip install pytest pytest-xdist
147+
148+ - name : Run integration tests
89149 run : |
90- # configure registries as root and as runner
91- gcloud auth configure-docker --quiet
92- gcloud auth configure-docker us-docker.pkg.dev --quiet
93- - name : Delete the tpu image
94- run : gcloud container images delete "gcr.io/tpu-prod-env-multipod/tunix_${{ github.run_id }}:tpu" --force-delete-tags --quiet
150+ # Run more comprehensive tests that might take longer
151+ python -m pytest tests/ -v --tb=short -m "integration_test" --timeout=300
95152
96153 notify_failure :
97154 name : Notify failed build
0 commit comments