Skip to content

Commit 0de11df

Browse files
committed
reinstall fa, te, apex, gg
1 parent 2703ab8 commit 0de11df

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

.github/workflows/docker/Dockerfile

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,64 @@ RUN pip3 uninstall aiter primus_turbo -y && \
2525
pip3 install -r requirements.txt && \
2626
GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation .
2727

28+
RUN pip uninstall flash_attn transformer_engine grouped_gemm -y
29+
30+
# FA v3.0.0.r1
31+
ARG FA_REPO=https://github.com/ROCm/flash-attention.git
32+
ARG FA_BRANCH=v3.0.0.r1-cktile
33+
ARG GPU_ARCHS=gfx942
34+
35+
RUN git clone --recursive ${FA_REPO} -b ${FA_BRANCH} \
36+
&& cd flash-attention \
37+
&& python setup.py install \
38+
&& cd .. \
39+
&& rm -rf flash-attention
40+
41+
# Latest TransformerEngine
42+
ARG TE_REPO=https://github.com/ROCm/TransformerEngine.git
43+
ARG TE_BRANCH=ec21e1f35f5a36c1def9c89c4865eba3947f0665
44+
45+
ENV NVTE_USE_HIPBLASLT=1
46+
ENV NVTE_FRAMEWORK=pytorch
47+
ENV NVTE_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
48+
ENV NVTE_USE_CAST_TRANSPOSE_TRITON=0
49+
ENV NVTE_CK_USES_BWD_V3=1
50+
ENV NVTE_CK_HOW_V3_BF16_CVT=2
51+
ENV GPU_ARCHS=gfx942
52+
ENV CU_NUM=304
53+
54+
RUN git clone --recursive ${TE_REPO} \
55+
&& cd TransformerEngine \
56+
&& git checkout ${TE_BRANCH} \
57+
&& git submodule update --init --recursive \
58+
&& MAX_JOBS=${MAX_JOBS} pip install . \
59+
&& cd ..
60+
61+
RUN apt --fix-broken install -y
62+
RUN pip install datasets numpy==1.26.4 transformers
63+
RUN pip install --upgrade 'optree>=0.13.0'
64+
65+
WORKDIR /workspace/
66+
67+
# Groupped GEMM
68+
ARG GROUPED_GEMM_REPO=https://github.com/caaatch22/grouped_gemm.git
69+
ARG GROUPED_GEMM_BRANCH=rocm
70+
RUN git clone ${GROUPED_GEMM_REPO} &&\
71+
cd grouped_gemm &&\
72+
git checkout ${GROUPED_GEMM_BRANCH} &&\
73+
git submodule update --init --recursive &&\
74+
pip install . && cd .. && rm -rf grouped_gemm
75+
76+
# APEX 1.7.0
77+
ARG APEX_REPO=https://github.com/rocm/apex
78+
ARG APEX_BRANCH=release/1.7.0
79+
80+
RUN pip uninstall -y apex \
81+
&& git clone ${APEX_REPO} -b ${APEX_BRANCH} \
82+
&& cd apex \
83+
&& pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \
84+
&& cd .. && rm -r apex
85+
2886

2987
# Set the default working directory
3088
WORKDIR /opt

0 commit comments

Comments
 (0)