@@ -114,10 +114,17 @@ def upipe_attn_gqa_forward(
114114 q_out = all_to_all_4D (q_proj , scatter_idx = 2 , gather_idx = 1 )
115115 k_out = all_to_all_4D (k_proj , scatter_idx = 2 , gather_idx = 1 )
116116 v_out = all_to_all_4D (v_proj , scatter_idx = 2 , gather_idx = 1 )
117+
118+ # deleting the inp to all_to_all to avoid memory leaks
119+ del q_proj , k_proj , v_proj
120+
117121 else :
118122 q_proj = apply_rotary_emb_flash (xq = q_proj , xk = None , freqs_cis = freqs_cis )
119123 q_out = all_to_all_4D (q_proj , scatter_idx = 2 , gather_idx = 1 )
120124
125+ # deleting the inp to all_to_all to avoid memory leaks
126+ del q_proj
127+
121128 attn_out , lse = zigzag_ring_flash_attn_forward (
122129 ring_group ,
123130 q_out ,
@@ -130,11 +137,23 @@ def upipe_attn_gqa_forward(
130137 )
131138 lse_list .append (lse )
132139
140+ # deleting the inp to attn_forward to avoid memory leaks
141+ del q_out
142+ if (stage + 1 )// gqa_ratio != stage // gqa_ratio :
143+ del k_out , v_out
144+
133145 out_local = all_to_all_4D (attn_out , scatter_idx = 1 , gather_idx = 2 )
146+
147+ # deleting the inp to all_to_all to avoid memory leaks
148+ del attn_out
149+
134150 head_start = stage * ulysses_degree
135151 head_end = head_start + ulysses_degree
136152 final_out [:, :, head_start :head_end , :] = out_local
137153
154+ # deleting the output of all_to_all to avoid memory leaks
155+ del out_local
156+
138157 return [final_out ] + lse_list
139158
140159
0 commit comments