@@ -197,6 +197,11 @@ def __init__(
197197 # For logging VLM outputs
198198 self .episode_output = []
199199
200+ # Log control: only log every N calls
201+ self .log_interval = 100 # Log every 100 calls
202+ self .call_count = 0
203+ self .batch_call_count = 0
204+
200205 def _default_prompt_template (self ) -> str :
201206 """Default prompt template for Atari games with Qwen-VL format."""
202207 if self .use_cot :
@@ -595,6 +600,8 @@ def generate_prior(
595600 Returns:
596601 Prior dictionary with action_probs, action_logits, raw_output, cot_prefix
597602 """
603+ self .call_count += 1
604+
598605 # Convert observation to PIL Image if needed
599606 if isinstance (observation , np .ndarray ):
600607 image = self ._convert_obs_to_pil_image (observation )
@@ -607,6 +614,16 @@ def generate_prior(
607614 else :
608615 prompt = self ._build_prompt (action_candidates , history )
609616
617+ # Log prompt preview at intervals
618+ if self .call_count % self .log_interval == 1 :
619+ import logging
620+ logger = logging .getLogger (__name__ )
621+ logger .info (
622+ f"[VLM Prior Generation] Call #{ self .call_count } | "
623+ f"Actions: { len (action_candidates )} | "
624+ f"Prompt preview: { prompt [:150 ]} ..."
625+ )
626+
610627 # Generate with VLM
611628 raw_output = self .vlm_engine .generate (
612629 image = image ,
@@ -624,6 +641,13 @@ def generate_prior(
624641 action_log_probs = self ._action_to_logprob (chosen_action , action_candidates , temperature )
625642 action_probs = np .exp (action_log_probs )
626643
644+ # Log output at intervals
645+ if self .call_count % self .log_interval == 1 :
646+ logger .info (
647+ f"[VLM Prior Output] Chosen: { chosen_action } | "
648+ f"CoT: { cot_prefix [:100 ] if cot_prefix else 'None' } ..."
649+ )
650+
627651 return {
628652 'action_probs' : action_probs ,
629653 'action_logits' : action_log_probs , # Store log probs for training
@@ -679,19 +703,30 @@ def batch_generate_prior(
679703 ) from e
680704
681705 # Build prompts
682- prompts = [
683- self ._build_prompt (actions , hist )
684- for actions , hist in zip (action_candidates_list , histories )
685- ]
706+ prompts = []
707+ for action_candidates , history in zip (action_candidates_list , histories ):
708+ if self .use_cot :
709+ prompt = self .get_user_prompt (action_candidates , history )
710+ else :
711+ prompt = self ._build_prompt (action_candidates , history )
712+ prompts .append (prompt )
713+
714+ # Increment batch call counter
715+ self .batch_call_count += 1
686716
687- # Debug: Log first prompt to verify vision tokens
688- if prompts and len ( prompts ) > 0 :
717+ # Log batch info at intervals (every 10 batch calls)
718+ if self . batch_call_count % 10 == 1 :
689719 import logging
690720 logger = logging .getLogger (__name__ )
691- logger .info (f"[VLM Debug] First prompt preview (first 200 chars): { prompts [0 ][:200 ]} " )
721+ logger .info (
722+ f"[VLM Batch Generation] Batch #{ self .batch_call_count } | "
723+ f"Batch size: { len (observations )} | "
724+ f"Avg actions: { sum (len (a ) for a in action_candidates_list ) / len (action_candidates_list ):.1f} "
725+ )
726+ # logger.debug(f"[VLM Debug] First prompt preview: {prompts[0][:200]}")
727+ logger .debug (f"[VLM Debug] First prompt preview: { prompts [0 ]} " )
692728 if "<|vision_start|>" not in prompts [0 ]:
693729 logger .error (f"[VLM Error] Missing <|vision_start|> token in prompt!" )
694- logger .error (f"[VLM Error] Full prompt: { prompts [0 ]} " )
695730
696731 # Batch generate with VLM
697732 raw_outputs = self .vlm_engine .batch_generate (
@@ -704,14 +739,29 @@ def batch_generate_prior(
704739 # Parse outputs
705740 results = []
706741 for raw_output , action_candidates in zip (raw_outputs , action_candidates_list ):
707- action_probs = self ._parse_vlm_output (raw_output , action_candidates )
708- action_logits = np .log (action_probs + 1e-10 ) * temperature
742+ if self .use_cot :
743+ # Parse CoT output
744+ chosen_action , cot_prefix = self ._parse_vlm_output_with_cot (raw_output , action_candidates )
745+ action_log_probs = self ._action_to_logprob (chosen_action , action_candidates , temperature )
746+ action_probs = np .exp (action_log_probs )
747+
748+ results .append ({
749+ 'action_probs' : action_probs ,
750+ 'action_logits' : action_log_probs ,
751+ 'raw_output' : raw_output ,
752+ 'cot_prefix' : cot_prefix ,
753+ 'chosen_action' : chosen_action ,
754+ })
755+ else :
756+ # Legacy: probability distribution
757+ action_probs = self ._parse_vlm_output (raw_output , action_candidates )
758+ action_logits = np .log (action_probs + 1e-10 ) * temperature
709759
710- results .append ({
711- 'action_probs' : action_probs ,
712- 'action_logits' : action_logits ,
713- 'raw_output' : raw_output ,
714- })
760+ results .append ({
761+ 'action_probs' : action_probs ,
762+ 'action_logits' : action_logits ,
763+ 'raw_output' : raw_output ,
764+ })
715765
716766 return results
717767
@@ -740,7 +790,13 @@ def build_vlm_train_samples(
740790 - advantage: Advantage value for PPO loss
741791 - cot_prefix: CoT reasoning (if use_cot=True)
742792 """
793+ import logging
794+ logger = logging .getLogger (__name__ )
795+
743796 train_samples = []
797+ total_steps = 0
798+
799+ logger .info (f"[VLM Training Samples] Building samples from { len (game_segments )} segments..." )
744800
745801 for seg_idx , segment in enumerate (game_segments ):
746802 # Extract segment data
@@ -799,6 +855,17 @@ def build_vlm_train_samples(
799855 }
800856
801857 train_samples .append (sample )
858+ total_steps += 1
859+
860+ # Log summary
861+ if len (train_samples ) > 0 :
862+ avg_advantage = np .mean ([s ['advantage' ] for s in train_samples ])
863+ avg_old_logprob = np .mean ([s ['old_log_prob' ] for s in train_samples ])
864+ logger .info (
865+ f"[VLM Training Samples] Built { len (train_samples )} samples | "
866+ f"Avg advantage: { avg_advantage :.4f} | "
867+ f"Avg old_logprob: { avg_old_logprob :.4f} "
868+ )
802869
803870 return train_samples
804871
@@ -850,6 +917,29 @@ def compute_action_log_prob(
850917 return - 10.0
851918
852919
def get_vlm_output_log(
    self,
    wm_train_iter: int,
    vlm_train_iter: int,
) -> None:
    """
    Log VLM output statistics (similar to LLM's get_llm_output_log).

    Emits an info-level summary of how many VLM outputs were collected in
    ``self.episode_output`` since the last call, then clears the buffer.
    Nothing is logged when the buffer is empty (but note the buffer is also
    NOT cleared in that case — it is already empty).

    Args:
        wm_train_iter: World model training iteration.
        vlm_train_iter: VLM training iteration.

    Returns:
        None.
    """
    # Local import keeps this consistent with the rest of the file, which
    # imports logging lazily inside methods.
    import logging
    logger = logging.getLogger(__name__)

    if len(self.episode_output) > 0:
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info(
            "[WM Iter %s | VLM Iter %s] Collected %d VLM outputs",
            wm_train_iter,
            vlm_train_iter,
            len(self.episode_output),
        )
        # Reset the buffer so the next interval starts fresh.
        self.episode_output = []
942+
853943def create_prior_generator (
854944 obs_type : str ,
855945 model_config : Dict [str , Any ],
0 commit comments