Skip to content

Adding basic tpu github tests #2

Adding basic tpu github tests

Adding basic tpu github tests #2

Workflow file for this run

---
name: TPU Tests

# `on` is a YAML 1.1 boolean for generic parsers; GitHub's loader reads it as
# the trigger map, so only the linter needs to be told to ignore it.
on:  # yamllint disable-line rule:truthy
  pull_request:
    branches: [main]
  push:
    branches: [main]
  workflow_dispatch:

# Least-privilege GITHUB_TOKEN: no step in this workflow writes to the repo.
permissions:
  contents: read

# TPU runners are scarce: when a PR gets a newer commit, cancel the stale run.
# Pushes to main are never cancelled so every main commit is tested.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.event_name == 'pull_request' }}
jobs:
  tpu-tests:
    runs-on: ["self-hosted", "linux-x86-ct5lp-224-8tpu"]
    # NOTE(review): with a self-hosted runner and a `pull_request` trigger,
    # PRs from forks execute their code on this machine. Confirm the repo
    # requires maintainer approval for outside-contributor workflow runs.
    # Stop a hung TPU job from holding the runner for the 360-minute default;
    # raise this if the suite legitimately needs longer.
    timeout-minutes: 120
    steps:
      # git must be on PATH *before* checkout: without it actions/checkout
      # falls back to a REST API download (no .git directory), and the
      # fetch-depth setting below has no effect.
      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y git curl

      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0  # Full history; checkout defaults to a depth-1 clone.

      - name: Set up Python
        # v5 runs on the Node 20 runtime; v4 used the deprecated Node 16.
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Set up JAX for TPU
        run: |
          python -m pip install --upgrade pip
          # Quote the extras spec so the shell can never glob-expand `[tpu]`.
          python -m pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

      - name: Install tunix dependencies
        run: |
          python -m pip install -e .
          # NOTE(review): pytest-xdist is installed, but no step passes `-n`.
          # Parallel workers would also contend for the same TPU chips.
          python -m pip install pytest pytest-xdist

      - name: Verify TPU availability
        # jax.devices() always returns at least one device (CPU fallback), so
        # counting devices cannot detect a missing TPU. Check the default
        # backend instead, and fail the job here if it is not a TPU.
        run: |
          python -c "
          import jax
          print(f'JAX version: {jax.__version__}')
          print(f'JAX devices: {jax.devices()}')
          backend = jax.default_backend()
          print(f'JAX default backend: {backend}')
          assert backend == 'tpu', f'Expected TPU backend, got: {backend}'
          "

      # Cheap import smoke tests run before the long pytest suites, so a
      # broken install fails fast.
      - name: Test tunix imports
        run: |
          python -c "
          import tunix
          import tunix.models
          import tunix.generate
          import tunix.sft
          import tunix.distillation
          import tunix.rl
          print('All tunix modules imported successfully')
          "

      - name: Test basic model loading
        # This step only imports Llama3Params; it does not instantiate it or
        # load any checkpoint.
        run: |
          python -c "
          from tunix.models.llama3.params import Llama3Params
          print('Llama3 params loaded successfully')
          "

      - name: Run basic tunix tests
        run: |
          # Run a subset of tests that are suitable for TPU testing
          python -m pytest tests/models/ -v --tb=short
          python -m pytest tests/generate/ -v --tb=short
          python -m pytest tests/sft/ -v --tb=short -k "not test_common"

      - name: Run distillation tests
        run: |
          python -m pytest tests/distillation/ -v --tb=short

      - name: Run RL tests (basic)
        run: |
          python -m pytest tests/rl/common_test.py -v --tb=short

      - name: Cleanup
        if: always()
        run: |
          echo "TPU test completed"
          # Add any cleanup commands here if needed