Skip to content

Commit 5ad64ee

Browse files
committed
Delete mem buffers proactively during UPipe forward pass
1 parent 947619e commit 5ad64ee

1 file changed

Lines changed: 19 additions & 0 deletions

File tree

untied_ulysses/fully_fused_attn.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,17 @@ def upipe_attn_gqa_forward(
114114
q_out = all_to_all_4D(q_proj, scatter_idx=2, gather_idx=1)
115115
k_out = all_to_all_4D(k_proj, scatter_idx=2, gather_idx=1)
116116
v_out = all_to_all_4D(v_proj, scatter_idx=2, gather_idx=1)
117+
118+
# deleting the inp to all_to_all to avoid memory leaks
119+
del q_proj, k_proj, v_proj
120+
117121
else:
118122
q_proj = apply_rotary_emb_flash(xq=q_proj, xk=None, freqs_cis=freqs_cis)
119123
q_out = all_to_all_4D(q_proj, scatter_idx=2, gather_idx=1)
120124

125+
# deleting the inp to all_to_all to avoid memory leaks
126+
del q_proj
127+
121128
attn_out, lse = zigzag_ring_flash_attn_forward(
122129
ring_group,
123130
q_out,
@@ -130,11 +137,23 @@ def upipe_attn_gqa_forward(
130137
)
131138
lse_list.append(lse)
132139

140+
# deleting the inp to attn_forward to avoid memory leaks
141+
del q_out
142+
if (stage+1)//gqa_ratio != stage//gqa_ratio:
143+
del k_out, v_out
144+
133145
out_local = all_to_all_4D(attn_out, scatter_idx=1, gather_idx=2)
146+
147+
# deleting the inp to all_to_all to avoid memory leaks
148+
del attn_out
149+
134150
head_start = stage * ulysses_degree
135151
head_end = head_start + ulysses_degree
136152
final_out[:, :, head_start:head_end, :] = out_local
137153

154+
# deleting the output of all_to_all to avoid memory leaks
155+
del out_local
156+
138157
return [final_out] + lse_list
139158

140159

0 commit comments

Comments
 (0)