Skip to content

Commit 07cfa23

Browse files
committed
reinstall fa, te, apex, gg
1 parent 0de11df commit 07cfa23

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

.github/workflows/docker/Dockerfile

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,36 @@ RUN pip3 uninstall aiter primus_turbo -y && \
2525
pip3 install -r requirements.txt && \
2626
GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation .
2727

28+
ENV MAX_JOBS=128
29+
ENV PYTORCH_ROCM_ARCH="gfx942;gfx950"
2830
RUN pip uninstall flash_attn transformer_engine grouped_gemm -y
2931

30-
# FA v3.0.0.r1
32+
# FA
3133
ARG FA_REPO=https://github.com/ROCm/flash-attention.git
32-
ARG FA_BRANCH=v3.0.0.r1-cktile
33-
ARG GPU_ARCHS=gfx942
34+
ARG FA_BRANCH=6387433156558135a998d5568a9d74c1778666d8
35+
ENV GPU_ARCHS="${PYTORCH_ROCM_ARCH}"
3436

35-
RUN git clone --recursive ${FA_REPO} -b ${FA_BRANCH} \
37+
RUN git clone --recursive ${FA_REPO} \
3638
&& cd flash-attention \
39+
&& git checkout ${FA_BRANCH} \
3740
&& python setup.py install \
3841
&& cd .. \
3942
&& rm -rf flash-attention
4043

4144
# Latest TransformerEngine
4245
ARG TE_REPO=https://github.com/ROCm/TransformerEngine.git
43-
ARG TE_BRANCH=ec21e1f35f5a36c1def9c89c4865eba3947f0665
46+
ARG TE_BRANCH=1834247827f47f5eab76ea790841168ac778592e
4447

4548
ENV NVTE_USE_HIPBLASLT=1
4649
ENV NVTE_FRAMEWORK=pytorch
4750
ENV NVTE_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
4851
ENV NVTE_USE_CAST_TRANSPOSE_TRITON=0
52+
ENV NVTE_CK_IS_V3_ATOMIC_FP32=0
4953
ENV NVTE_CK_USES_BWD_V3=1
54+
ENV NVTE_CK_USES_FWD_V3=1
55+
ENV CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=2
5056
ENV NVTE_CK_HOW_V3_BF16_CVT=2
51-
ENV GPU_ARCHS=gfx942
52-
ENV CU_NUM=304
57+
ENV NVTE_USE_ROCM=1
5358

5459
RUN git clone --recursive ${TE_REPO} \
5560
&& cd TransformerEngine \
@@ -58,20 +63,15 @@ RUN git clone --recursive ${TE_REPO} \
5863
&& MAX_JOBS=${MAX_JOBS} pip install . \
5964
&& cd ..
6065

61-
RUN apt --fix-broken install -y
62-
RUN pip install datasets numpy==1.26.4 transformers
63-
RUN pip install --upgrade 'optree>=0.13.0'
64-
65-
WORKDIR /workspace/
66-
6766
# Groupped GEMM
6867
ARG GROUPED_GEMM_REPO=https://github.com/caaatch22/grouped_gemm.git
6968
ARG GROUPED_GEMM_BRANCH=rocm
69+
7070
RUN git clone ${GROUPED_GEMM_REPO} &&\
7171
cd grouped_gemm &&\
7272
git checkout ${GROUPED_GEMM_BRANCH} &&\
7373
git submodule update --init --recursive &&\
74-
pip install . && cd .. && rm -rf grouped_gemm
74+
pip install . && cd ../ && rm -rf grouped_gemm
7575

7676
# APEX 1.7.0
7777
ARG APEX_REPO=https://github.com/rocm/apex
@@ -83,7 +83,6 @@ RUN pip uninstall -y apex \
8383
&& pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \
8484
&& cd .. && rm -r apex
8585

86-
8786
# Set the default working directory
8887
WORKDIR /opt
8988

0 commit comments

Comments
 (0)