2727# works with any custom wheels.
2828# bash docker_build_dependency_image.sh MODE=custom_wheels
2929
30- # bash docker_build_dependency_image.sh MODE=grpo
30+ # bash docker_build_dependency_image.sh MODE=post-training
3131
3232# Enable "exit immediately if any command fails" option
3333set -e
@@ -68,17 +68,17 @@ if [[ -z ${MODE} ]]; then
6868 export MODE=stable
6969 echo " Default MODE=${MODE} "
7070 export CUSTOM_JAX=0
71- export INSTALL_GRPO =0
71+ export INSTALL_POST_TRAINING =0
7272elif [[ ${MODE} == " custom_wheels" ]] ; then
7373 export MODE=nightly
7474 export CUSTOM_JAX=1
75- export INSTALL_GRPO =0
76- elif [[ ${MODE} == " grpo " || ${MODE} == " grpo -experimental" ]] ; then
77- export INSTALL_GRPO =1
75+ export INSTALL_POST_TRAINING =0
76+ elif [[ ${MODE} == " post-training " || ${MODE} == " post-training -experimental" ]] ; then
77+ export INSTALL_POST_TRAINING =1
7878 export CUSTOM_JAX=0
7979else
8080 export CUSTOM_JAX=0
81- export INSTALL_GRPO =0
81+ export INSTALL_POST_TRAINING =0
8282fi
8383
8484if [[ -z ${DEVICE} ]]; then
@@ -124,8 +124,8 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
124124 elif [[ ${MANTARAY} == " true" ]]; then
125125 echo " Building with benchmark-db"
126126 docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
127- elif [[ ${INSTALL_GRPO } -eq 1 && ${DEVICE} == " tpu" ]]; then
128- echo " Installing MaxText stable mode dependencies for GRPO "
127+ elif [[ ${INSTALL_POST_TRAINING } -eq 1 && ${DEVICE} == " tpu" ]]; then
128+ echo " Installing MaxText stable mode dependencies for Post-Training "
129129 docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
130130 else
131131 docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
136136 docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} .
137137fi
138138
139- if [[ ${INSTALL_GRPO } -eq 1 ]] ; then
139+ if [[ ${INSTALL_POST_TRAINING } -eq 1 ]] ; then
140140 if [[ ${DEVICE} != " tpu" ]] ; then
141- echo " Error: MODE=grpo is only supported for DEVICE=tpu"
141+ echo " Error: MODE=post-training is only supported for DEVICE=tpu"
142142 exit 1
143143 fi
144144
@@ -158,7 +158,7 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
158158 --network host \
159159 --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
160160 --build-arg MODE=${MODE} \
161- -f ./maxtext_grpo_dependencies .Dockerfile \
161+ -f ./maxtext_post_training_dependencies .Dockerfile \
162162 -t ${LOCAL_IMAGE_NAME} .
163163fi
164164
0 commit comments