---

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow builds the MaxText TPU and GPU Docker images and pushes them to
# gcr.io/tpu-prod-env-multipod. It runs nightly and can also be started manually
# for a chosen target device.
# For more information see: https://github.com/docker/build-push-action
name: Build Images

on:
  # Nightly build: every day at 00:00 UTC.
  schedule:
    - cron: "0 0 * * *"
  # Manual run; the caller chooses which device family to build images for.
  workflow_dispatch:
    inputs:
      target_device:
        description: "Specify target device (all, tpu, or gpu)"
        type: choice
        required: true
        default: "all"
        options:
          - all
          - tpu
          - gpu
jobs:
  # Builds and pushes the TPU images. Runs for scheduled events, pull_request
  # events, or a manual dispatch whose target_device is 'all' or 'tpu'.
  # NOTE(review): the `on:` section of this file declares no pull_request
  # trigger, so that clause is currently inert - confirm it is intentional.
  build-tpu:
    if: >-
      github.event_name == 'schedule' ||
      github.event_name == 'pull_request' ||
      github.event.inputs.target_device == 'all' ||
      github.event.inputs.target_device == 'tpu'
    runs-on: linux-x86-n2-16-buildkit
    container: google/cloud-sdk:524.0.0
    # Use the GitHub Actions matrix to run the image builds in parallel;
    # fail-fast is off so one failing image does not cancel the others.
    strategy:
      fail-fast: false
      matrix:
        include:
          # TPU image builds from the MaxText dependencies Dockerfile
          - image_name: maxtext_jax_stable
            dockerfile: ./maxtext_dependencies.Dockerfile
            build_args: |
              MODE=stable
              JAX_VERSION=NONE
              LIBTPU_GCS_PATH=NONE
          - image_name: maxtext_jax_nightly
            dockerfile: ./maxtext_dependencies.Dockerfile
            build_args: |
              MODE=nightly
              JAX_VERSION=NONE
              LIBTPU_GCS_PATH=NONE
          # TPU image builds on top of a JAX AI Image base
          - image_name: maxtext_jax_stable_stack
            dockerfile: ./maxtext_jax_ai_image.Dockerfile
            base_image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
          - image_name: maxtext_stable_stack_nightly_jax
            dockerfile: ./maxtext_jax_ai_image.Dockerfile
            base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest
          - image_name: maxtext_stable_stack_candidate
            dockerfile: ./maxtext_jax_ai_image.Dockerfile
            base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
    # Setup for GKE runners per b/412986220#comment82 and b/412986220#comment90
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
      # Needed so the later `git rev-parse` works inside the job container.
      - name: Mark git repository as safe
        run: git config --global --add safe.directory "${GITHUB_WORKSPACE}"
      - name: Configure Docker
        run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q
      - name: Set up Docker BuildX
        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2
        with:
          driver: remote
          endpoint: tcp://localhost:1234
      # Values passed to the Dockerfile / used in the image tags
      - name: Get short commit hash
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> "$GITHUB_OUTPUT"
      # -u pins the date tag to UTC regardless of the container's timezone.
      - name: Get current date
        id: date
        run: echo "image_date=$(date -u +%Y-%m-%d)" >> "$GITHUB_OUTPUT"
      # Docker BuildX command config.
      # NOTE(review): this action is referenced by a mutable tag while the
      # other actions in this file are pinned to a commit SHA - consider pinning.
      # Matrix entries that do not define build_args / base_image expand those
      # expressions to empty strings.
      - name: Build and Push Docker Image
        uses: docker/build-push-action@v6
        with:
          push: true
          context: .
          file: ${{ matrix.dockerfile }}
          tags: |
            gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.date.outputs.image_date }}
            gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:latest
          # One cache scope per image: without an explicit scope every matrix
          # build shares the default scope and overwrites the others' cache.
          cache-from: type=gha,scope=${{ matrix.image_name }}
          cache-to: type=gha,mode=max,scope=${{ matrix.image_name }}
          provenance: false
          build-args: |
            ${{ matrix.build_args }}
            JAX_AI_IMAGE_BASEIMAGE=${{ matrix.base_image }}
            COMMIT_HASH=${{ steps.vars.outputs.sha_short }}
            DEVICE=tpu
            TEST_TYPE=xlml

  # Same as the build-tpu job but mirrored for GPUs. Runs for scheduled events,
  # pull_request events, or a manual dispatch whose target_device is 'all' or
  # 'gpu'.
  # NOTE(review): the `on:` section of this file declares no pull_request
  # trigger, so that clause is currently inert - confirm it is intentional.
  build-gpu:
    if: >-
      github.event_name == 'schedule' ||
      github.event_name == 'pull_request' ||
      github.event.inputs.target_device == 'all' ||
      github.event.inputs.target_device == 'gpu'
    runs-on: linux-x86-n2-16-buildkit
    container: google/cloud-sdk:524.0.0
    strategy:
      fail-fast: false
      matrix:
        # GPU image builds on top of a JAX AI Image base
        include:
          - image_name: maxtext_gpu_jax_stable_stack
            dockerfile: ./maxtext_jax_ai_image.Dockerfile
            base_image: us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:latest
          - image_name: maxtext_gpu_stable_stack_nightly_jax
            dockerfile: ./maxtext_jax_ai_image.Dockerfile
            base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest
          - image_name: maxtext_stable_stack_candidate_gpu
            dockerfile: ./maxtext_jax_ai_image.Dockerfile
            base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
      # Needed so the later `git rev-parse` works inside the job container.
      - name: Mark git repository as safe
        run: git config --global --add safe.directory "${GITHUB_WORKSPACE}"
      # The extra us-central1 registry hosts the GPU JAX AI base image.
      - name: Configure Docker
        run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io,us-central1-docker.pkg.dev -q
      - name: Set up Docker BuildX
        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2
        with:
          driver: remote
          endpoint: tcp://localhost:1234
      - name: Get short commit hash
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> "$GITHUB_OUTPUT"
      # -u pins the date tag to UTC regardless of the container's timezone.
      - name: Get current date
        id: date
        run: echo "image_date=$(date -u +%Y-%m-%d)" >> "$GITHUB_OUTPUT"
      # NOTE(review): this action is referenced by a mutable tag while the
      # other actions in this file are pinned to a commit SHA - consider pinning.
      # No GPU matrix entry defines build_args, so that expression expands to
      # an empty string here.
      - name: Build and Push Docker Image
        uses: docker/build-push-action@v6
        with:
          push: true
          context: .
          file: ${{ matrix.dockerfile }}
          tags: |
            gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.date.outputs.image_date }}
            gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:latest
          # One cache scope per image: without an explicit scope every matrix
          # build shares the default scope and overwrites the others' cache.
          cache-from: type=gha,scope=${{ matrix.image_name }}
          cache-to: type=gha,mode=max,scope=${{ matrix.image_name }}
          provenance: false
          build-args: |
            ${{ matrix.build_args }}
            JAX_AI_IMAGE_BASEIMAGE=${{ matrix.base_image }}
            COMMIT_HASH=${{ steps.vars.outputs.sha_short }}
            DEVICE=gpu
            TEST_TYPE=xlml