Skip to content

Commit 00c406d

Browse files
author
maxtext authors
committed
Merge pull request #1595 from AI-Hypercomputer:bvandermoon-xpk-path
PiperOrigin-RevId: 748341979
2 parents 5c4090b + 27b6581 commit 00c406d

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

maxtext_jax_stable_stack.Dockerfile

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,14 @@ RUN if [ "$DEVICE" = "tpu" ] && ([ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pk
3333

3434
# Install Maxtext requirements with Jax Stable Stack
3535
RUN apt-get update && apt-get install --yes google-cloud-cli && apt-get install --yes dnsutils
36-
RUN python3 -m pip install -r /deps/requirements_with_jax_stable_stack.txt
36+
37+
# Install requirements file generated with pipreqs for JSS 0.5.2.
38+
# Othewise use general requirements_with_jax_stable_stack.txt
39+
RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1" ]; then \
40+
python3 -m pip install -r /deps/requirements_with_jax_stable_stack_0_5_2_pipreqs.txt; \
41+
else \
42+
python3 -m pip install -r /deps/requirements_with_jax_stable_stack.txt; \
43+
fi
3744

3845
# Run the script available in JAX Stable Stack base image to generate the manifest file
3946
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
absl_py==2.1.0
2+
# Replacing with aqtp due to conflicts during build
3+
# aqt==25.2
4+
aqtp==0.8.2 # Added manually
5+
cloud_accelerator_diagnostics==0.1.1
6+
cloud_tpu_diagnostics==0.1.5
7+
datasets==3.5.0
8+
etils==1.12.2
9+
evaluate==0.4.3
10+
flax==0.10.4
11+
# Replacing with specific git commit due to conflicts during build
12+
# google_jetstream==0.3.0
13+
# Adding manually
14+
google-jetstream @ git+https://github.com/AI-Hypercomputer/JetStream.git@082c0ac526e50d8f732a083ed43920590d7ffd22
15+
grain_nightly==0.0.10
16+
jax==0.5.2
17+
jaxlib==0.5.1 # Manually adding to ensure consistency in future
18+
jaxtyping==0.3.1
19+
jsonlines==4.0.0
20+
libtpu==0.0.10.1 # Manually adding to ensure consistency in future
21+
ml_collections==1.0.0
22+
ml_goodput_measurement==0.0.8
23+
nltk==3.9.1
24+
# Removing due to conflicts during build
25+
# numpy==2.2.4
26+
omegaconf==2.3.0
27+
optax==0.2.4
28+
orbax==0.1.9
29+
pandas==2.2.3
30+
pathwaysutils==0.1.0
31+
# Removing due to conflicts during build
32+
# protobuf==3.20.3
33+
protobuf
34+
psutil==7.0.0
35+
pytest==8.3.5
36+
PyYAML==6.0.2
37+
PyYAML==6.0.2
38+
Requests==2.32.3
39+
safetensors==0.5.3
40+
sentencepiece==0.1.97
41+
tensorboard_plugin_profile==2.17.0
42+
tensorboardX==2.6.2.2
43+
tensorboardX==2.6.2.2
44+
tensorflow==2.19.0
45+
tensorflow_datasets==4.9.8
46+
tensorflow_text==2.19.0
47+
tensorstore==0.1.72
48+
tfds_nightly==4.9.2.dev202308090034
49+
tiktoken==0.9.0
50+
torch==2.6.0
51+
tqdm==4.67.1
52+
transformer_engine==2.1.0
53+
transformers==4.51.3
54+
trl==0.16.1
55+
urllib3==2.4.0

0 commit comments

Comments
 (0)