@@ -25,31 +25,36 @@ RUN pip3 uninstall aiter primus_turbo -y && \
2525 pip3 install -r requirements.txt && \
2626 GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation .
2727
28+ ENV MAX_JOBS=128
29+ ENV PYTORCH_ROCM_ARCH="gfx942;gfx950"
2830RUN pip uninstall flash_attn transformer_engine grouped_gemm -y
2931
30- # FA v3.0.0.r1
32+ # FA
3133ARG FA_REPO=https://github.com/ROCm/flash-attention.git
32- ARG FA_BRANCH=v3.0.0.r1-cktile
33- ARG GPU_ARCHS=gfx942
34+ ARG FA_BRANCH=6387433156558135a998d5568a9d74c1778666d8
35+ ENV GPU_ARCHS="${PYTORCH_ROCM_ARCH}"
3436
35- RUN git clone --recursive ${FA_REPO} -b ${FA_BRANCH} \
37+ RUN git clone --recursive ${FA_REPO} \
3638 && cd flash-attention \
39+ && git checkout ${FA_BRANCH} \
3740 && python setup.py install \
3841 && cd .. \
3942 && rm -rf flash-attention
4043
4144# Latest TransformerEngine
4245ARG TE_REPO=https://github.com/ROCm/TransformerEngine.git
43- ARG TE_BRANCH=ec21e1f35f5a36c1def9c89c4865eba3947f0665
46+ ARG TE_BRANCH=1834247827f47f5eab76ea790841168ac778592e
4447
4548ENV NVTE_USE_HIPBLASLT=1
4649ENV NVTE_FRAMEWORK=pytorch
4750ENV NVTE_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
4851ENV NVTE_USE_CAST_TRANSPOSE_TRITON=0
52+ ENV NVTE_CK_IS_V3_ATOMIC_FP32=0
4953ENV NVTE_CK_USES_BWD_V3=1
54+ ENV NVTE_CK_USES_FWD_V3=1
55+ ENV CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=2
5056ENV NVTE_CK_HOW_V3_BF16_CVT=2
51- ENV GPU_ARCHS=gfx942
52- ENV CU_NUM=304
57+ ENV NVTE_USE_ROCM=1
5358
5459RUN git clone --recursive ${TE_REPO} \
5560 && cd TransformerEngine \
@@ -58,20 +63,15 @@ RUN git clone --recursive ${TE_REPO} \
5863 && MAX_JOBS=${MAX_JOBS} pip install . \
5964 && cd ..
6065
61- RUN apt --fix-broken install -y
62- RUN pip install datasets numpy==1.26.4 transformers
63- RUN pip install --upgrade 'optree>=0.13.0'
64-
65- WORKDIR /workspace/
66-
6766# Groupped GEMM
6867ARG GROUPED_GEMM_REPO=https://github.com/caaatch22/grouped_gemm.git
6968ARG GROUPED_GEMM_BRANCH=rocm
69+
7070RUN git clone ${GROUPED_GEMM_REPO} &&\
7171 cd grouped_gemm &&\
7272 git checkout ${GROUPED_GEMM_BRANCH} &&\
7373 git submodule update --init --recursive &&\
74- pip install . && cd .. && rm -rf grouped_gemm
74+ pip install . && cd ../ && rm -rf grouped_gemm
7575
7676# APEX 1.7.0
7777ARG APEX_REPO=https://github.com/rocm/apex
@@ -83,7 +83,6 @@ RUN pip uninstall -y apex \
8383 && pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \
8484 && cd .. && rm -r apex
8585
86-
8786# Set the default working directory
8887WORKDIR /opt
8988
0 commit comments