Skip to content

Commit 72979a3

Browse files
Merge pull request #2540 from AI-Hypercomputer:grpo_docker_rename
PiperOrigin-RevId: 823673423
2 parents 823a63f + 0c5af95 commit 72979a3

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

docker_build_dependency_image.sh

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
# works with any custom wheels.
2828
# bash docker_build_dependency_image.sh MODE=custom_wheels
2929

30-
# bash docker_build_dependency_image.sh MODE=grpo
30+
# bash docker_build_dependency_image.sh MODE=post-training
3131

3232
# Enable "exit immediately if any command fails" option
3333
set -e
@@ -68,17 +68,17 @@ if [[ -z ${MODE} ]]; then
6868
export MODE=stable
6969
echo "Default MODE=${MODE}"
7070
export CUSTOM_JAX=0
71-
export INSTALL_GRPO=0
71+
export INSTALL_POST_TRAINING=0
7272
elif [[ ${MODE} == "custom_wheels" ]] ; then
7373
export MODE=nightly
7474
export CUSTOM_JAX=1
75-
export INSTALL_GRPO=0
76-
elif [[ ${MODE} == "grpo" || ${MODE} == "grpo-experimental" ]] ; then
77-
export INSTALL_GRPO=1
75+
export INSTALL_POST_TRAINING=0
76+
elif [[ ${MODE} == "post-training" || ${MODE} == "post-training-experimental" ]] ; then
77+
export INSTALL_POST_TRAINING=1
7878
export CUSTOM_JAX=0
7979
else
8080
export CUSTOM_JAX=0
81-
export INSTALL_GRPO=0
81+
export INSTALL_POST_TRAINING=0
8282
fi
8383

8484
if [[ -z ${DEVICE} ]]; then
@@ -124,8 +124,8 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
124124
elif [[ ${MANTARAY} == "true" ]]; then
125125
echo "Building with benchmark-db"
126126
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
127-
elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then
128-
echo "Installing MaxText stable mode dependencies for GRPO"
127+
elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then
128+
echo "Installing MaxText stable mode dependencies for Post-Training"
129129
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
130130
else
131131
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
@@ -136,9 +136,9 @@ else
136136
docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} .
137137
fi
138138

139-
if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
139+
if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
140140
if [[ ${DEVICE} != "tpu" ]] ; then
141-
echo "Error: MODE=grpo is only supported for DEVICE=tpu"
141+
echo "Error: MODE=post-training is only supported for DEVICE=tpu"
142142
exit 1
143143
fi
144144

@@ -158,7 +158,7 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
158158
--network host \
159159
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
160160
--build-arg MODE=${MODE} \
161-
-f ./maxtext_grpo_dependencies.Dockerfile \
161+
-f ./maxtext_post_training_dependencies.Dockerfile \
162162
-t ${LOCAL_IMAGE_NAME} .
163163
fi
164164

docs/tutorials/grpo_with_pathways.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ In addition to MaxText dependencies,
3939
We use the scheduler code from vLLM and the model runner code from `tpu_commons`.
4040

4141
```
42-
bash docker_build_dependency_image.sh MODE=grpo
42+
bash docker_build_dependency_image.sh MODE=post-training
4343
```
4444

45-
You can also use `bash docker_build_dependency_image.sh MODE=grpo-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API
45+
You can also use `bash docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies, such as the improved pathwaysutils resharding API.
4646

4747

4848

maxtext_grpo_dependencies.Dockerfile renamed to maxtext_post_training_dependencies.Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ ARG MODE
1818

1919
ENV MODE=$MODE
2020

21-
RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"
21+
RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"
2222

2323

2424
# Uninstall existing jax to avoid conflicts
@@ -52,7 +52,7 @@ RUN pip install --no-cache-dir --pre \
5252
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
5353
tpu-commons==0.1.2
5454

55-
RUN if [ "$MODE" = "grpo-experimental" ]; then \
55+
RUN if [ "$MODE" = "post-training-experimental" ]; then \
5656
pip uninstall -y jax jaxlib libtpu && \
5757
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
5858
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \

0 commit comments

Comments
 (0)