Skip to content

Commit f70d6d7

Browse files
committed
polish(pu): polish logs
1 parent d89a3c5 commit f70d6d7

File tree

1 file changed

+105
-15
lines changed

1 file changed

+105
-15
lines changed

zoo/jericho/priorzero/prior_generator.py

Lines changed: 105 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ def __init__(
197197
# For logging VLM outputs
198198
self.episode_output = []
199199

200+
# Log control: only log every N calls
201+
self.log_interval = 100 # Log every 100 calls
202+
self.call_count = 0
203+
self.batch_call_count = 0
204+
200205
def _default_prompt_template(self) -> str:
201206
"""Default prompt template for Atari games with Qwen-VL format."""
202207
if self.use_cot:
@@ -595,6 +600,8 @@ def generate_prior(
595600
Returns:
596601
Prior dictionary with action_probs, action_logits, raw_output, cot_prefix
597602
"""
603+
self.call_count += 1
604+
598605
# Convert observation to PIL Image if needed
599606
if isinstance(observation, np.ndarray):
600607
image = self._convert_obs_to_pil_image(observation)
@@ -607,6 +614,16 @@ def generate_prior(
607614
else:
608615
prompt = self._build_prompt(action_candidates, history)
609616

617+
# Log prompt preview at intervals
618+
if self.call_count % self.log_interval == 1:
619+
import logging
620+
logger = logging.getLogger(__name__)
621+
logger.info(
622+
f"[VLM Prior Generation] Call #{self.call_count} | "
623+
f"Actions: {len(action_candidates)} | "
624+
f"Prompt preview: {prompt[:150]}..."
625+
)
626+
610627
# Generate with VLM
611628
raw_output = self.vlm_engine.generate(
612629
image=image,
@@ -624,6 +641,13 @@ def generate_prior(
624641
action_log_probs = self._action_to_logprob(chosen_action, action_candidates, temperature)
625642
action_probs = np.exp(action_log_probs)
626643

644+
# Log output at intervals
645+
if self.call_count % self.log_interval == 1:
646+
logger.info(
647+
f"[VLM Prior Output] Chosen: {chosen_action} | "
648+
f"CoT: {cot_prefix[:100] if cot_prefix else 'None'}..."
649+
)
650+
627651
return {
628652
'action_probs': action_probs,
629653
'action_logits': action_log_probs, # Store log probs for training
@@ -679,19 +703,30 @@ def batch_generate_prior(
679703
) from e
680704

681705
# Build prompts
682-
prompts = [
683-
self._build_prompt(actions, hist)
684-
for actions, hist in zip(action_candidates_list, histories)
685-
]
706+
prompts = []
707+
for action_candidates, history in zip(action_candidates_list, histories):
708+
if self.use_cot:
709+
prompt = self.get_user_prompt(action_candidates, history)
710+
else:
711+
prompt = self._build_prompt(action_candidates, history)
712+
prompts.append(prompt)
713+
714+
# Increment batch call counter
715+
self.batch_call_count += 1
686716

687-
# Debug: Log first prompt to verify vision tokens
688-
if prompts and len(prompts) > 0:
717+
# Log batch info at intervals (every 10 batch calls)
718+
if self.batch_call_count % 10 == 1:
689719
import logging
690720
logger = logging.getLogger(__name__)
691-
logger.info(f"[VLM Debug] First prompt preview (first 200 chars): {prompts[0][:200]}")
721+
logger.info(
722+
f"[VLM Batch Generation] Batch #{self.batch_call_count} | "
723+
f"Batch size: {len(observations)} | "
724+
f"Avg actions: {sum(len(a) for a in action_candidates_list) / len(action_candidates_list):.1f}"
725+
)
726+
# logger.debug(f"[VLM Debug] First prompt preview: {prompts[0][:200]}")
727+
logger.debug(f"[VLM Debug] First prompt preview: {prompts[0]}")
692728
if "<|vision_start|>" not in prompts[0]:
693729
logger.error(f"[VLM Error] Missing <|vision_start|> token in prompt!")
694-
logger.error(f"[VLM Error] Full prompt: {prompts[0]}")
695730

696731
# Batch generate with VLM
697732
raw_outputs = self.vlm_engine.batch_generate(
@@ -704,14 +739,29 @@ def batch_generate_prior(
704739
# Parse outputs
705740
results = []
706741
for raw_output, action_candidates in zip(raw_outputs, action_candidates_list):
707-
action_probs = self._parse_vlm_output(raw_output, action_candidates)
708-
action_logits = np.log(action_probs + 1e-10) * temperature
742+
if self.use_cot:
743+
# Parse CoT output
744+
chosen_action, cot_prefix = self._parse_vlm_output_with_cot(raw_output, action_candidates)
745+
action_log_probs = self._action_to_logprob(chosen_action, action_candidates, temperature)
746+
action_probs = np.exp(action_log_probs)
747+
748+
results.append({
749+
'action_probs': action_probs,
750+
'action_logits': action_log_probs,
751+
'raw_output': raw_output,
752+
'cot_prefix': cot_prefix,
753+
'chosen_action': chosen_action,
754+
})
755+
else:
756+
# Legacy: probability distribution
757+
action_probs = self._parse_vlm_output(raw_output, action_candidates)
758+
action_logits = np.log(action_probs + 1e-10) * temperature
709759

710-
results.append({
711-
'action_probs': action_probs,
712-
'action_logits': action_logits,
713-
'raw_output': raw_output,
714-
})
760+
results.append({
761+
'action_probs': action_probs,
762+
'action_logits': action_logits,
763+
'raw_output': raw_output,
764+
})
715765

716766
return results
717767

@@ -740,7 +790,13 @@ def build_vlm_train_samples(
740790
- advantage: Advantage value for PPO loss
741791
- cot_prefix: CoT reasoning (if use_cot=True)
742792
"""
793+
import logging
794+
logger = logging.getLogger(__name__)
795+
743796
train_samples = []
797+
total_steps = 0
798+
799+
logger.info(f"[VLM Training Samples] Building samples from {len(game_segments)} segments...")
744800

745801
for seg_idx, segment in enumerate(game_segments):
746802
# Extract segment data
@@ -799,6 +855,17 @@ def build_vlm_train_samples(
799855
}
800856

801857
train_samples.append(sample)
858+
total_steps += 1
859+
860+
# Log summary
861+
if len(train_samples) > 0:
862+
avg_advantage = np.mean([s['advantage'] for s in train_samples])
863+
avg_old_logprob = np.mean([s['old_log_prob'] for s in train_samples])
864+
logger.info(
865+
f"[VLM Training Samples] Built {len(train_samples)} samples | "
866+
f"Avg advantage: {avg_advantage:.4f} | "
867+
f"Avg old_logprob: {avg_old_logprob:.4f}"
868+
)
802869

803870
return train_samples
804871

@@ -850,6 +917,29 @@ def compute_action_log_prob(
850917
return -10.0
851918

852919

920+
def get_vlm_output_log(
    self,
    wm_train_iter: int,
    vlm_train_iter: int,
) -> None:
    """Log how many VLM outputs were collected, then reset the buffer.

    Mirrors the LLM-side ``get_llm_output_log`` helper: emits a single
    summary line tagged with both training iterations and clears
    ``self.episode_output`` so the next collection window starts empty.

    Args:
        wm_train_iter: World model training iteration.
        vlm_train_iter: VLM training iteration.
    """
    import logging

    # Guard clause: nothing was collected since the last call, so there
    # is nothing to report and the (already empty) buffer is left as-is.
    if not self.episode_output:
        return

    log = logging.getLogger(__name__)
    log.info(
        f"[WM Iter {wm_train_iter} | VLM Iter {vlm_train_iter}] "
        f"Collected {len(self.episode_output)} VLM outputs"
    )
    # Reset the per-episode buffer after reporting.
    self.episode_output = []
941+
942+
853943
def create_prior_generator(
854944
obs_type: str,
855945
model_config: Dict[str, Any],

0 commit comments

Comments
 (0)