Skip to content

Commit 53a2571

Browse files
authored
Merge branch 'main' into feat/4b_mega
2 parents 44f009c + aee4d69 commit 53a2571

File tree

40 files changed

+2952
-1228
lines changed

40 files changed

+2952
-1228
lines changed

docker/Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ RUN apt update
99
RUN apt install -y nvtop rsync
1010

1111
# TODO: change to pip install sglang-router after it has a new release
12-
RUN pip install sglang-router --force-reinstall
12+
RUN pip install sglang-router==0.2.1 --force-reinstall
1313
RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git --no-cache-dir --force-reinstall
1414
RUN pip install ray[default]
1515
RUN pip install httpx[http2] wandb pylatexenc blobfile accelerate "mcp[cli]"
1616

1717
# mbridge
1818
RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps
1919

20-
RUN TORCH_CUDA_ARCH_LIST="8.0;8.9;9.0;9.0a" pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.4 --no-build-isolation
2120
# apex
2221
RUN NVCC_APPEND_FLAGS="--threads 4" \
2322
pip -v install --disable-pip-version-check --no-cache-dir \
@@ -31,13 +30,14 @@ RUN MAX_JOBS=64 pip -v install flash-attn==2.7.4.post1 --no-build-isolation
3130
RUN pip install flash-linear-attention
3231
RUN pip -v install --no-build-isolation transformer_engine[pytorch]
3332

33+
WORKDIR /root/
3434
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
3535
cd flash-attention/ && git submodule update --init && cd hopper/ && python setup.py install && \
3636
export python_path=`python -c "import site; print(site.getsitepackages()[0])"` && \
3737
mkdir -p $python_path/flash_attn_3 && \
3838
cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py
39+
RUN rm -rf flash-attention/
3940

40-
WORKDIR /root/
4141
RUN git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \
4242
cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \
4343
pip install -e .

docker/patch/latest/megatron.patch

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ index 860ee64a9..80944b702 100755
9494
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
9595
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
9696
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
97-
index 6aec66e6d..7aa4b2f7d 100644
97+
index 6aec66e6d..b660a2002 100644
9898
--- a/megatron/core/models/gpt/gpt_model.py
9999
+++ b/megatron/core/models/gpt/gpt_model.py
100100
@@ -355,6 +355,7 @@ class GPTModel(LanguageModule):
@@ -143,8 +143,24 @@ index 6aec66e6d..7aa4b2f7d 100644
143143
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
144144
hidden_states = hidden_states_list[0]
145145
if loss_mask is None:
146-
@@ -480,9 +485,9 @@ class GPTModel(LanguageModule):
147-
runtime_gather_output=runtime_gather_output,
146+
@@ -474,15 +479,21 @@ class GPTModel(LanguageModule):
147+
loss_mask = torch.ones_like(mtp_labels)
148+
for mtp_layer_number in range(self.config.mtp_num_layers):
149+
# output
150+
- mtp_logits, _ = self.output_layer(
151+
- hidden_states_list[mtp_layer_number + 1],
152+
- weight=output_weight,
153+
- runtime_gather_output=runtime_gather_output,
154+
+ output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()}
155+
+ output_layer_buffers = dict(self.output_layer.named_buffers())
156+
+ mtp_logits, _ = torch.func.functional_call(
157+
+ self.output_layer,
158+
+ {**output_layer_params, **output_layer_buffers},
159+
+ (hidden_states_list[mtp_layer_number + 1],),
160+
+ {
161+
+ "weight": output_weight.detach() if output_weight else None,
162+
+ "runtime_gather_output": runtime_gather_output,
163+
+ },
148164
)
149165
# Calc loss for the current Multi-Token Prediction (MTP) layers.
150166
- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)

0 commit comments

Comments (0)