File tree Expand file tree Collapse file tree 2 files changed +63
-1
lines changed Expand file tree Collapse file tree 2 files changed +63
-1
lines changed Original file line number Diff line number Diff line change @@ -33,7 +33,14 @@ RUN if [ "$DEVICE" = "tpu" ] && ([ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pk
3333
3434# Install Maxtext requirements with Jax Stable Stack
3535RUN apt-get update && apt-get install --yes google-cloud-cli && apt-get install --yes dnsutils
36- RUN python3 -m pip install -r /deps/requirements_with_jax_stable_stack.txt
36+
37+ # Install requirements file generated with pipreqs for JSS 0.5.2.
38+ # Othewise use general requirements_with_jax_stable_stack.txt
39+ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1" ]; then \
40+ python3 -m pip install -r /deps/requirements_with_jax_stable_stack_0_5_2_pipreqs.txt; \
41+ else \
42+ python3 -m pip install -r /deps/requirements_with_jax_stable_stack.txt; \
43+ fi
3744
3845# Run the script available in JAX Stable Stack base image to generate the manifest file
3946RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH
Original file line number Diff line number Diff line change 1+ absl_py==2.1.0
2+ # Replacing with aqtp due to conflicts during build
3+ # aqt==25.2
4+ aqtp==0.8.2 # Added manually
5+ cloud_accelerator_diagnostics==0.1.1
6+ cloud_tpu_diagnostics==0.1.5
7+ datasets==3.5.0
8+ etils==1.12.2
9+ evaluate==0.4.3
10+ flax==0.10.4
11+ # Replacing with specific git commit due to conflicts during build
12+ # google_jetstream==0.3.0
13+ # Adding manually
14+ google-jetstream @ git+https://github.com/AI-Hypercomputer/JetStream.git@082c0ac526e50d8f732a083ed43920590d7ffd22
15+ grain_nightly==0.0.10
16+ jax==0.5.2
17+ jaxlib==0.5.1 # Manually adding to ensure consistency in future
18+ jaxtyping==0.3.1
19+ jsonlines==4.0.0
20+ libtpu==0.0.10.1 # Manually adding to ensure consistency in future
21+ ml_collections==1.0.0
22+ ml_goodput_measurement==0.0.8
23+ nltk==3.9.1
24+ # Removing due to conflicts during build
25+ # numpy==2.2.4
26+ omegaconf==2.3.0
27+ optax==0.2.4
28+ orbax==0.1.9
29+ pandas==2.2.3
30+ pathwaysutils==0.1.0
31+ # Removing due to conflicts during build
32+ # protobuf==3.20.3
33+ protobuf
34+ psutil==7.0.0
35+ pytest==8.3.5
36+ PyYAML==6.0.2
37+ PyYAML==6.0.2
38+ Requests==2.32.3
39+ safetensors==0.5.3
40+ sentencepiece==0.1.97
41+ tensorboard_plugin_profile==2.17.0
42+ tensorboardX==2.6.2.2
43+ tensorboardX==2.6.2.2
44+ tensorflow==2.19.0
45+ tensorflow_datasets==4.9.8
46+ tensorflow_text==2.19.0
47+ tensorstore==0.1.72
48+ tfds_nightly==4.9.2.dev202308090034
49+ tiktoken==0.9.0
50+ torch==2.6.0
51+ tqdm==4.67.1
52+ transformer_engine==2.1.0
53+ transformers==4.51.3
54+ trl==0.16.1
55+ urllib3==2.4.0
You can’t perform that action at this time.
0 commit comments