@@ -297,16 +297,65 @@ def build_chat_context(self, user_prompt: str) -> str:
297297 add_generation_prompt = True ,
298298 )
299299
300+ def _detect_padding_sample (self , raw_obs_list , cot_prefix_list , action_logprob_list , b , t ):
301+ """
302+ [LAYER 3 DEFENSE] Detect if a sample contains padding data.
303+
304+ Padding detection heuristics:
305+ 1. Check for consecutive duplicate observations (unlikely in real gameplay)
306+ 2. Check for None values in cot_prefix (initial state marker)
307+ 3. Check for duplicate cot_prefix with different obs (clear misalignment)
308+
309+ Args:
310+ raw_obs_list: List of raw observations
311+ cot_prefix_list: List of CoT prefixes
312+ action_logprob_list: List of action logprobs
313+ b: Batch index
314+ t: Time index
315+
316+ Returns:
317+ True if padding is detected, False otherwise
318+ """
319+ T = len (raw_obs_list [b ])
320+
321+ # Heuristic 1: Check for consecutive duplicates (strong padding indicator)
322+ # If both obs and cot_prefix are duplicated at the same time, very likely padding
323+ if t + 1 < T and raw_obs_list [b ][t ] == raw_obs_list [b ][t + 1 ]:
324+ # Check if cot_prefix is also duplicated
325+ if cot_prefix_list is not None and t + 2 < len (cot_prefix_list [b ]):
326+ if cot_prefix_list [b ][t + 1 ] == cot_prefix_list [b ][t + 2 ]:
327+ # Double duplication is a strong signal of padding
328+ return True
329+
330+ # Heuristic 2: Check for None cot_prefix (should only be at t=0)
331+ if t > 0 and cot_prefix_list is not None and t + 1 < len (cot_prefix_list [b ]):
332+ if cot_prefix_list [b ][t + 1 ] is None :
333+ return True
334+
335+ # Heuristic 3: Check if action_logprob is empty or None
336+ if action_logprob_list is not None and t + 1 < len (action_logprob_list [b ]):
337+ logprob = action_logprob_list [b ][t + 1 ]
338+ if logprob is None or (isinstance (logprob , dict ) and len (logprob ) == 0 ):
339+ return True
340+
341+ # Heuristic 4: Check for triple+ consecutive duplicates (very strong signal)
342+ if t + 2 < T :
343+ if (raw_obs_list [b ][t ] == raw_obs_list [b ][t + 1 ] == raw_obs_list [b ][t + 2 ]):
344+ return True
345+
346+ return False
347+
348+
300349 def build_llm_samples (self ,
301350 raw_obs_list : List [List [str ]],
302351 history_obs_list : List [List [List [Tuple [str , str , float ]]]],
303352 action_logprob_list : Optional [List [List [Any ]]] = None ,
304- pred_values : Optional [torch .Tensor ] = None , # [B, T-1]
305- target_values : Optional [torch .Tensor ] = None , # [B, T-1]
353+ pred_values : Optional [torch .Tensor ] = None , # [B, T-1]
354+ target_values : Optional [torch .Tensor ] = None , # [B, T-1]
306355 cot_prefix_list : Optional [List [List [str ]]] = None , # CoT reuse optimization
307356 ) -> List [Dict [str , Any ]]:
308357 """
309- Build training samples from collected data .
358+ [ENHANCED] Build training samples with padding detection and filtering .
310359
311360 Args:
312361 raw_obs_list: Raw observations
@@ -324,14 +373,48 @@ def build_llm_samples(self,
324373 return samples
325374 T = len (raw_obs_list [0 ])
326375
376+ # [LAYER 3] Statistics for monitoring
377+ total_samples = 0
378+ filtered_samples = 0
379+ filtered_reasons = {
380+ 'padding' : 0 ,
381+ 'empty_action' : 0 ,
382+ 'extreme_logprob' : 0 ,
383+ 'nan_value' : 0 ,
384+ }
385+
327386 for b in range (B ):
328387 for t in range (T - 1 ):
388+ total_samples += 1
389+
390+ # [LAYER 3 DEFENSE] Detect and skip padding samples
391+ if self ._detect_padding_sample (raw_obs_list , cot_prefix_list , action_logprob_list , b , t ):
392+ filtered_samples += 1
393+ filtered_reasons ['padding' ] += 1
394+ continue
395+
329396 current_obs = raw_obs_list [b ][t ]
330397 current_hist = history_obs_list [b ][t ]
331398 next_hist = history_obs_list [b ][t + 1 ]
332399
333- _ , true_action , reward_value = next_hist [- 1 ]
400+ # Validate history structure
401+ if not next_hist or len (next_hist ) == 0 :
402+ filtered_samples += 1
403+ filtered_reasons ['empty_action' ] += 1
404+ continue
405+
406+ try :
407+ _ , true_action , reward_value = next_hist [- 1 ]
408+ except (ValueError , IndexError ) as e :
409+ if self .rank == 0 :
410+ self ._logger .warning (f"Unexpected history structure at b={ b } , t={ t } : { next_hist [- 1 ]} " )
411+ filtered_samples += 1
412+ filtered_reasons ['empty_action' ] += 1
413+ continue
414+
334415 if not true_action :
416+ filtered_samples += 1
417+ filtered_reasons ['empty_action' ] += 1
335418 continue
336419
337420 instruction = self .build_llm_prompt (
@@ -341,18 +424,28 @@ def build_llm_samples(self,
341424 prompt = self .build_chat_context (instruction )
342425 old_logprob = None
343426 if action_logprob_list is not None :
344- old_logprob = action_logprob_list [b ][t + 1 ][true_action ]
427+ logprob_dict = action_logprob_list [b ][t + 1 ]
428+ if isinstance (logprob_dict , dict ):
429+ old_logprob = logprob_dict .get (true_action , None )
430+ else :
431+ old_logprob = None
345432
346433 # FIX: Filter samples with extreme logprobs to prevent ratio explosion
347434 if old_logprob is not None :
348435 # Skip if empty
349436 if len (old_logprob ) == 0 :
437+ filtered_samples += 1
438+ filtered_reasons ['extreme_logprob' ] += 1
350439 continue
351440 # Skip if contains extreme values (< -50 indicates very low probability)
352441 if min (old_logprob ) < - 50.0 :
442+ filtered_samples += 1
443+ filtered_reasons ['extreme_logprob' ] += 1
353444 continue
354445 # Skip if contains NaN/Inf
355446 if any (math .isnan (x ) or math .isinf (x ) for x in old_logprob ):
447+ filtered_samples += 1
448+ filtered_reasons ['nan_value' ] += 1
356449 continue
357450
358451 target_value = None
@@ -365,8 +458,12 @@ def build_llm_samples(self,
365458
366459 # FIX: Skip samples with NaN/Inf values
367460 if target_value is not None and (math .isnan (target_value ) or math .isinf (target_value )):
461+ filtered_samples += 1
462+ filtered_reasons ['nan_value' ] += 1
368463 continue
369464 if pred_value is not None and (math .isnan (pred_value ) or math .isinf (pred_value )):
465+ filtered_samples += 1
466+ filtered_reasons ['nan_value' ] += 1
370467 continue
371468
372469 # CoT reuse optimization: get CoT prefix from stored data
@@ -375,6 +472,9 @@ def build_llm_samples(self,
375472 prefix_cot = None
376473 if self .use_cot and cot_prefix_list is not None :
377474 prefix_cot = cot_prefix_list [b ][t + 1 ]
475+ # [FIX] Handle None prefix (initial state or padding)
476+ if prefix_cot is None :
477+ prefix_cot = ""
378478
379479 samples .append (
380480 {
@@ -388,6 +488,24 @@ def build_llm_samples(self,
388488 "prefix_cot" : prefix_cot , # CoT reuse optimization
389489 }
390490 )
491+
492+ # [LAYER 3 MONITORING] Log filtering statistics
493+ if self .rank == 0 and total_samples > 0 :
494+ filter_rate = (filtered_samples / total_samples ) * 100
495+ self ._logger .info (
496+ f"[Sample Filtering] Total: { total_samples } | Filtered: { filtered_samples } ({ filter_rate :.2f} %) | "
497+ f"padding={ filtered_reasons ['padding' ]} , empty={ filtered_reasons ['empty_action' ]} , "
498+ f"extreme_logprob={ filtered_reasons ['extreme_logprob' ]} , nan={ filtered_reasons ['nan_value' ]} "
499+ )
500+
501+ # WARNING: If too many samples are filtered due to padding, something is wrong
502+ if filtered_reasons ['padding' ] > total_samples * 0.1 : # >10% padding
503+ self ._logger .warning (
504+ f"⚠️ High padding rate detected ({ filtered_reasons ['padding' ]} /{ total_samples } = "
505+ f"{ filtered_reasons ['padding' ]/ total_samples * 100 :.1f} %)! "
506+ f"Check sampling strategy in game buffer."
507+ )
508+
391509 return samples
392510
393511 def make_llm_train_samples (self , priorzero_batch , ddp : bool = False ) -> List [Dict [str , Any ]]:
0 commit comments