Skip to content

Commit 8a2d279

Browse files
Added caching optimization in building Maxtext + JAII Image
1 parent e67180e commit 8a2d279

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

docker_build_dependency_image.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ build_ai_image() {
7575
COMMIT_HASH=$(git rev-parse --short HEAD)
7676
echo "Building JAX AI MaxText Imageat commit hash ${COMMIT_HASH}..."
7777

78-
docker build --no-cache \
78+
docker build \
7979
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
8080
--build-arg COMMIT_HASH=${COMMIT_HASH} \
8181
--build-arg DEVICE="$DEVICE" \

maxtext_jax_ai_image.Dockerfile

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ RUN mkdir -p /deps
1313
# Set the working directory in the container
1414
WORKDIR /deps
1515

16-
# Copy all files from local workspace into docker container
17-
COPY . .
18-
RUN ls .
16+
# Copy setup files and dependency files separately for better caching
17+
COPY setup.sh ./
18+
COPY requirements.txt requirements_with_jax_ai_image.txt ./
1919

2020

2121
# For JAX AI tpu training images 0.4.37 AND 0.4.35
@@ -37,5 +37,9 @@ RUN apt-get update && apt-get install --yes && apt-get install --yes dnsutils
3737
RUN pip install google-cloud-monitoring
3838
RUN python3 -m pip install -r /deps/requirements_with_jax_ai_image.txt
3939

40+
# Now copy the remaining code (source files that may change frequently)
41+
COPY . .
42+
RUN ls .
43+
4044
# Run the script available in JAX AI base image to generate the manifest file
4145
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH

0 commit comments

Comments
 (0)