@@ -25,6 +25,64 @@ RUN pip3 uninstall aiter primus_turbo -y && \
2525 pip3 install -r requirements.txt && \
2626 GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation .
2727
28+ RUN pip uninstall flash_attn transformer_engine grouped_gemm -y
29+
30+ # FA v3.0.0.r1
31+ ARG FA_REPO=https://github.com/ROCm/flash-attention.git
32+ ARG FA_BRANCH=v3.0.0.r1-cktile
33+ ARG GPU_ARCHS=gfx942
34+
35+ RUN git clone --recursive ${FA_REPO} -b ${FA_BRANCH} \
36+ && cd flash-attention \
37+ && python setup.py install \
38+ && cd .. \
39+ && rm -rf flash-attention
40+
41+ # Latest TransformerEngine
42+ ARG TE_REPO=https://github.com/ROCm/TransformerEngine.git
43+ ARG TE_BRANCH=ec21e1f35f5a36c1def9c89c4865eba3947f0665
44+
45+ ENV NVTE_USE_HIPBLASLT=1
46+ ENV NVTE_FRAMEWORK=pytorch
47+ ENV NVTE_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
48+ ENV NVTE_USE_CAST_TRANSPOSE_TRITON=0
49+ ENV NVTE_CK_USES_BWD_V3=1
50+ ENV NVTE_CK_HOW_V3_BF16_CVT=2
51+ ENV GPU_ARCHS=gfx942
52+ ENV CU_NUM=304
53+
54+ RUN git clone --recursive ${TE_REPO} \
55+ && cd TransformerEngine \
56+ && git checkout ${TE_BRANCH} \
57+ && git submodule update --init --recursive \
58+ && MAX_JOBS=${MAX_JOBS} pip install . \
59+ && cd ..
60+
61+ RUN apt --fix-broken install -y
62+ RUN pip install datasets numpy==1.26.4 transformers
63+ RUN pip install --upgrade 'optree>=0.13.0'
64+
65+ WORKDIR /workspace/
66+
67+ # Groupped GEMM
68+ ARG GROUPED_GEMM_REPO=https://github.com/caaatch22/grouped_gemm.git
69+ ARG GROUPED_GEMM_BRANCH=rocm
70+ RUN git clone ${GROUPED_GEMM_REPO} &&\
71+ cd grouped_gemm &&\
72+ git checkout ${GROUPED_GEMM_BRANCH} &&\
73+ git submodule update --init --recursive &&\
74+ pip install . && cd .. && rm -rf grouped_gemm
75+
76+ # APEX 1.7.0
77+ ARG APEX_REPO=https://github.com/rocm/apex
78+ ARG APEX_BRANCH=release/1.7.0
79+
80+ RUN pip uninstall -y apex \
81+ && git clone ${APEX_REPO} -b ${APEX_BRANCH} \
82+ && cd apex \
83+ && pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \
84+ && cd .. && rm -r apex
85+
2886
2987# Set the default working directory
3088WORKDIR /opt
0 commit comments