slime/backends/fsdp_utils
1 file changed: +2 -13 lines changed
@@ -246,20 +246,9 @@ def setup_context_parallelism(self):
         print(f"Ring attention rank: {ring_attn_rank}")
 
     def _update_cp_cu_seqlens(self, packed_batch):
-        cu_seqlens = packed_batch["cu_seqlens"].to(device=torch.cuda.current_device(), dtype=torch.int32)
+        cu_seqlens = packed_batch["cu_seqlens"]
 
-        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = self.model(
-                input_ids=packed_batch["tokens"].unsqueeze(0),
-                attention_mask=None,
-                position_ids=packed_batch["position_ids"].unsqueeze(0),
-            ).logits
-        # Handle packed sequences
-        log_probs = gather_log_probs_packed(logits, packed_batch["tokens"], packed_batch["cu_seqlens"])
-        packed_batch["cur_log_probs"] = log_probs
-        unpacked_batches = unpack_sequences(packed_batch)
-
-        # Sync ring flash attention parameters
+        # Update the ring attention parameters
         update_ring_flash_attn_params(cu_seqlens, self.cp_group)
 
 
You can’t perform that action at this time.
0 commit comments