1212
1313permissions :
1414 contents : read
15+ id-token : write
16+
17+ env :
18+ PYTHON : ${{ matrix.python-version }}
19+ KERAS_HOME : .github/workflows/config/${{ matrix.backend }}
20+ KERAS_BACKEND : jax
21+ PROJECT_ID : gtech-rmi-dev # Replace with your GCP project ID
22+ GAR_LOCATION : us-central1 # Replace with your Artifact Registry location (e.g., us-central1)
23+ IMAGE_REPO : keras-docker-images
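24+ IMAGE_NAME : keras-jax-tpu-amd64:latest # Docker image name including its tag (note: the commented-out build/push steps below reference IMAGE_NAME_BASE, which is not defined here)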
25+ TPU_VM_NAME : kharshith-jax-tpu # Replace with your TPU VM instance name
26+ TPU_VM_ZONE : us-central1-b # Replace with your TPU VM zone
1527
1628jobs :
1729 build-and-test-on-tpu :
@@ -22,105 +34,78 @@ jobs:
2234 backend : [jax]
2335 name : Run TPU tests
2436 runs-on :
25- - linux-x86-ct5lp-112-4tpu
37+ # - keras-jax-tpu-runner
38+ # - linux-x86-ct5lp-112-4tpu
2639 # - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n
27- # - linux-x86-ct6e-44-1tpu
40+ - linux-x86-ct6e-44-1tpu
2841 # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4
2942 # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc
3043
44+ container : us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest
45+
46+ # container:
47+ # image: docker:latest # Provides the Docker CLI within the job container
48+ # volumes:
49+ # - /var/run/docker.sock:/var/run/docker.sock # Mounts host's Docker socket for control
50+ # options: --privileged
51+
52+ # steps:
53+ # - name: Checkout Repository
54+ # uses: actions/checkout@v4
55+
56+ # - name: Set up Docker BuildX
57+ # uses: docker/setup-buildx-action@v3
58+
59+ # - name: Authenticate to Google Cloud (Workload Identity Federation)
60+ # id: 'auth'
61+ # uses: 'google-github-actions/auth@v2'
62+ # with:
63+ # # Replace with your Workload Identity Federation provider details.
64+ # # This service account needs the 'Artifact Registry Writer' role.
65+ # workload_identity_provider: 'projects/YOUR_PROJECT_NUMBER/locations/global/workloadIdentityPools/YOUR_POOL_ID/providers/YOUR_PROVIDER_ID'
66+ # service_account: 'your-github-actions-sa@${{ env.PROJECT_ID }}.iam.gserviceaccount.com'
67+
68+ # - name: Configure Docker to use Google Artifact Registry
69+ # run: gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev
70+
71+ # - name: Build Docker Image
72+ # run: |
73+ # IMAGE_TAG="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:${{ github.sha }}"
74+ # echo "Building Docker image: $IMAGE_TAG"
75+ # docker build \
76+ # --platform=linux/amd64 \
77+ # -f .github/workflows/tpu/Dockerfile \
78+ # -t "$IMAGE_TAG" \
79+ # -t "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" \
80+ # .
81+ # echo "Built Docker image: $IMAGE_TAG"
82+ # echo "LOCAL_TEST_IMAGE_TAG=$IMAGE_TAG" >> $GITHUB_ENV # Store for immediate use in run step
83+
84+ # - name: Push Docker Image to Artifact Registry
85+ # run: |
86+ # echo "Pushing Docker image to Artifact Registry: ${{ env.LOCAL_TEST_IMAGE_TAG }}"
87+ # docker push "${{ env.LOCAL_TEST_IMAGE_TAG }}"
88+ # docker push "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest"
89+ # echo "Pushed Docker image."
90+
91+ # - name: Run Docker container and execute tests on TPU
92+ # run: |
93+ # echo "Running Docker container with TPU access and executing tests..." # NOTE: for PYTHON below, use a specific version or derive it from the matrix
94+ # docker run --rm \
95+ # --privileged \
96+ # --network host \
97+ # -e PYTHON=3.10 \
98+ # -e KERAS_HOME=.github/workflows/config/jax \
99+ # -e KERAS_BACKEND=jax \
100+ # ${{ env.LOCAL_TEST_IMAGE_TAG }} \
101+ # /bin/bash -c ' \
102+ # echo "Verifying JAX TPU backend inside container..." && \
103+ # python3 -c "import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")" \
104+ # # Add your actual pytest command here. Ensure pytest is installed inside your Docker image.
105+ # # && pytest keras --ignore keras/src/applications --ignore keras/src/layers/merging/merging_test.py --cov=keras --cov-config=pyproject.toml
106+ # '
107+ # echo "Docker container finished running tests."
31108
32- container :
33- # Use an official Docker image that includes the Docker CLI.
34- # This allows you to run 'docker' commands from within this job's container.
35- # 'docker:latest' is a good choice. You could also specify a version like 'docker:24.0.5'.
36- image : docker:latest
37- # Mount the host's Docker socket into this container.
38- # This is CRUCIAL: It allows 'docker' commands executed *inside* this container
39- # to control the *host's* Docker daemon.
40- # volumes:
41- # - /var/run/docker.sock:/var/run/docker.sock
42- # Running this "controlling" container in privileged mode is often necessary
43- # when it needs to manage other containers and access host resources like TPUs
44- # through the host's Docker daemon.
45- options : --privileged
46-
47- env :
48- PYTHON : ${{ matrix.python-version }}
49- KERAS_HOME : .github/workflows/config/${{ matrix.backend }}
50- KERAS_BACKEND : jax
51- PROJECT_ID : gtech-rmi-dev # Replace with your GCP project ID
52- GAR_LOCATION : us-central1 # Replace with your Artifact Registry location (e.g., us-central1)
53- IMAGE_NAME : keras-jax-tpu-amd64:latest # Name of your Docker image
54- TPU_VM_NAME : kharshith-jax-tpu # Replace with your TPU VM instance name
55- TPU_VM_ZONE : us-central1-b # Replace with your TPU VM zone
56-
57- steps :
58-
59- - name : Checkout Repository
60- uses : actions/checkout@v4
61-
62- - name : Install Docker (if not present)
63- run : |
64- # Check if docker is already installed
65- if ! command -v docker &> /dev/null
66- then
67- echo "Docker not found. Installing Docker..."
68- # Update apt package index
69- sudo apt-get update
70- # Install packages to allow apt to use a repository over HTTPS
71- sudo apt-get install -y ca-certificates curl gnupg
72- # Add Docker's official GPG key
73- sudo install -m 0755 -d /etc/apt/keyrings
74- curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
75- sudo chmod a+r /etc/apt/keyrings/docker.gpg
76- # Add the repository to Apt sources
77- echo \
78- "deb [arch=\"$(dpkg --print-architecture)\" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
79- \"$(. /etc/os-release && echo \"$VERSION_CODENAME\")\" stable" | \
80- sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
81- # Install Docker Engine, containerd, and Docker Compose
82- sudo apt-get update
83- sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
84- # Add the current user (runner user) to the docker group to run docker without sudo
85- sudo usermod -aG docker $USER
86- # You might need to log out and log back in for group changes to take effect,
87- # or restart the Docker daemon and the runner agent.
88- # For a CI environment, `newgrp docker` might work temporarily or a restart is implied.
89- sudo systemctl start docker
90- sudo systemctl enable docker
91- echo "Docker installed."
92- else
93- echo "Docker is already installed."
94- fi
95-
96- - name : Set up Docker BuildX
97- uses : docker/setup-buildx-action@v3
98-
99-
100- - name : Build Docker image for TPU tests
101- run : |
102- echo "Building Docker image using Dockerfile at .github/workflows/tpu/Dockerfile..."
103- # Use 'sudo docker' if the 'docker' group membership hasn't fully applied yet.
104- docker build -f .github/workflows/tpu/Dockerfile -t keras-tpu-test .
105- echo "Docker image built successfully."
106-
107- - name : Run Docker container and execute tests on TPU
108- run : |
109- echo "Running Docker container with TPU access and executing tests..."
110- # Use 'sudo docker' if the 'docker' group membership hasn't fully applied yet.
111- docker run --rm \
112- --privileged \
113- --network host \
114- -e PYTHON=${{ env.PYTHON }} \
115- -e KERAS_HOME=${{ env.KERAS_HOME }} \
116- -e KERAS_BACKEND=${{ env.KERAS_BACKEND }} \
117- keras-tpu-test \
118- /bin/bash -c "\
119- echo 'Verifying JAX TPU backend inside container...' && \
120- python3 -c 'import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")' \
121- # Add your actual pytest command here. Ensure pytest is installed inside your Docker image.
122- "
123- echo "Docker container finished running tests."
124109
125110 build :
126111 strategy :
0 commit comments