Skip to content

Commit a24cace

Browse files
committed
fix(pu): fix some data type bugs
1 parent 549b0b1 commit a24cace

File tree

6 files changed

+230
-32
lines changed

6 files changed

+230
-32
lines changed

lzero/model/common.py

Lines changed: 121 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,6 @@ def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
470470
def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
471471
return self.encode(x, no_grad=no_grad)
472472

473-
474473
class HFLanguageRepresentationNetwork(nn.Module):
475474
def __init__(self,
476475
model_path: str = 'google-bert/bert-base-uncased',
@@ -489,32 +488,26 @@ def __init__(self,
489488
super().__init__()
490489
from transformers import AutoModel, AutoTokenizer
491490

492-
# [FIX] Load tokenizer for ALL ranks, not just non-zero ranks
493491
if tokenizer is not None:
494492
self.tokenizer = tokenizer
495493
else:
496-
# Load tokenizer with same distributed logic as model
497494
if get_rank() == 0:
498495
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
499496
if get_world_size() > 1:
500497
torch.distributed.barrier()
501498
if get_rank() != 0:
502499
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
503500

504-
# In distributed settings, ensure only rank 0 downloads the model/tokenizer.
505501
if get_rank() == 0:
506502
self.pretrained_model = AutoModel.from_pretrained(model_path)
507-
508503
if get_world_size() > 1:
509-
# Wait for rank 0 to finish loading the model.
510504
torch.distributed.barrier()
511505
if get_rank() != 0:
512506
self.pretrained_model = AutoModel.from_pretrained(model_path)
513507

514508
self.embedding_size = embedding_size
515509
self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
516510

517-
# # Select the normalization method based on the final_norm_option_in_encoder parameter.
518511
if final_norm_option_in_encoder.lower() == "simnorm":
519512
self.norm = SimNorm(simnorm_dim=group_size)
520513
elif final_norm_option_in_encoder.lower() == "layernorm":
@@ -533,26 +526,140 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
533526
Returns:
534527
- (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size).
535528
"""
536-
529+
# Ensure the input tensor is of type long.
530+
x = x.long()
531+
537532
# Construct the attention mask to exclude padding tokens.
538-
attention_mask = x != self.tokenizer.pad_token_id
533+
attention_mask = (x != self.tokenizer.pad_token_id).long()
534+
535+
# ==================== Fix start ====================
536+
# 1. Explicitly create token_type_ids
537+
# For single-sentence inputs, token_type_ids is an all-zeros tensor with the same shape as input_ids.
538+
token_type_ids = torch.zeros_like(x, device=x.device)
539+
540+
# 2. Remove the dangerous internal-state modification
541+
# The code block below was the root cause of the error and must be deleted.
542+
# if hasattr(self.pretrained_model, 'embeddings') and hasattr(self.pretrained_model.embeddings, 'token_type_ids'):
543+
# self.pretrained_model.embeddings.token_type_ids = None
544+
# ==================== Fix end ====================
539545

540546
if no_grad:
541547
with torch.no_grad():
542-
x = x.long() # Ensure the input tensor is of type long.
543-
outputs = self.pretrained_model(x, attention_mask=attention_mask)
544-
# Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
548+
# 3. Pass token_type_ids explicitly when calling the model
549+
outputs = self.pretrained_model(x, attention_mask=attention_mask, token_type_ids=token_type_ids)
545550
cls_embedding = outputs.last_hidden_state[:, 0, :]
546551
else:
547-
x = x.long()
548-
outputs = self.pretrained_model(x, attention_mask=attention_mask)
552+
# 3. Pass token_type_ids explicitly when calling the model
553+
outputs = self.pretrained_model(x, attention_mask=attention_mask, token_type_ids=token_type_ids)
549554
cls_embedding = outputs.last_hidden_state[:, 0, :]
550555

551556
cls_embedding = self.embed_proj_head(cls_embedding)
552557
cls_embedding = self.norm(cls_embedding)
553558

554559
return cls_embedding
555560

561+
# class HFLanguageRepresentationNetwork(nn.Module):
562+
# def __init__(self,
563+
# model_path: str = 'google-bert/bert-base-uncased',
564+
# embedding_size: int = 768,
565+
# group_size: int = 8,
566+
# final_norm_option_in_encoder: str = "layernorm",
567+
# tokenizer=None):
568+
# """
569+
# Arguments:
570+
# - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'.
571+
# - embedding_size (int): The dimension of the output embeddings. Default is 768.
572+
# - group_size (int): The group size for SimNorm when using normalization.
573+
# - final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
574+
# - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model.
575+
# """
576+
# super().__init__()
577+
# from transformers import AutoModel, AutoTokenizer
578+
579+
# # [FIX] Load tokenizer for ALL ranks, not just non-zero ranks
580+
# if tokenizer is not None:
581+
# self.tokenizer = tokenizer
582+
# else:
583+
# # Load tokenizer with same distributed logic as model
584+
# if get_rank() == 0:
585+
# self.tokenizer = AutoTokenizer.from_pretrained(model_path)
586+
# if get_world_size() > 1:
587+
# torch.distributed.barrier()
588+
# if get_rank() != 0:
589+
# self.tokenizer = AutoTokenizer.from_pretrained(model_path)
590+
591+
# # In distributed settings, ensure only rank 0 downloads the model/tokenizer.
592+
# if get_rank() == 0:
593+
# self.pretrained_model = AutoModel.from_pretrained(model_path)
594+
595+
# if get_world_size() > 1:
596+
# # Wait for rank 0 to finish loading the model.
597+
# torch.distributed.barrier()
598+
# if get_rank() != 0:
599+
# self.pretrained_model = AutoModel.from_pretrained(model_path)
600+
601+
# self.embedding_size = embedding_size
602+
# self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
603+
604+
# # # Select the normalization method based on the final_norm_option_in_encoder parameter.
605+
# if final_norm_option_in_encoder.lower() == "simnorm":
606+
# self.norm = SimNorm(simnorm_dim=group_size)
607+
# elif final_norm_option_in_encoder.lower() == "layernorm":
608+
# self.norm = nn.LayerNorm(embedding_size)
609+
# else:
610+
# raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
611+
# f"Choose 'simnorm' or 'layernorm'.")
612+
613+
# def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
614+
# """
615+
# Overview:
616+
# Computes language representation from input token IDs.
617+
# Arguments:
618+
# - x (:obj:`torch.Tensor`): Input token sequence of shape (B, seq_len).
619+
# - no_grad (:obj:`bool`): If True, run the transformer model in `torch.no_grad()` context.
620+
# Returns:
621+
# - (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size).
622+
# """
623+
624+
# # Construct the attention mask to exclude padding tokens.
625+
# attention_mask = x != self.tokenizer.pad_token_id
626+
627+
# # [FIX] Clear buffered token_type_ids to prevent shape mismatch errors
628+
# # BERT models cache token_type_ids for efficiency, but this causes issues
629+
# # when batch sizes or sequence lengths vary across different forward passes.
630+
# # We delete the buffer entirely and let BERT recreate it with the correct shape.
631+
# if hasattr(self.pretrained_model, 'embeddings') and hasattr(self.pretrained_model.embeddings, 'token_type_ids'):
632+
# # Check if token_type_ids exists and has wrong shape
633+
# if self.pretrained_model.embeddings.token_type_ids is not None:
634+
# expected_seq_len = x.shape[1]
635+
# current_seq_len = self.pretrained_model.embeddings.token_type_ids.shape[1]
636+
# # Only delete if the cached buffer has wrong shape
637+
# if current_seq_len != expected_seq_len:
638+
# # Delete the registered buffer and let BERT recreate it
639+
# delattr(self.pretrained_model.embeddings, 'token_type_ids')
640+
# # Re-register with correct shape
641+
# self.pretrained_model.embeddings.register_buffer(
642+
# "token_type_ids",
643+
# torch.zeros((1, expected_seq_len), dtype=torch.long, device=x.device),
644+
# persistent=False
645+
# )
646+
647+
# if no_grad:
648+
# with torch.no_grad():
649+
# x = x.long() # Ensure the input tensor is of type long.
650+
# outputs = self.pretrained_model(x, attention_mask=attention_mask)
651+
# # Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
652+
# cls_embedding = outputs.last_hidden_state[:, 0, :]
653+
# else:
654+
# x = x.long()
655+
# outputs = self.pretrained_model(x, attention_mask=attention_mask)
656+
# cls_embedding = outputs.last_hidden_state[:, 0, :]
657+
658+
# cls_embedding = self.embed_proj_head(cls_embedding)
659+
# cls_embedding = self.norm(cls_embedding)
660+
661+
# return cls_embedding
662+
556663

557664
class RepresentationNetworkUniZero(nn.Module):
558665

lzero/model/unizero_world_models/tokenizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id: int = 0) -> torch.T
146146
x = x.contiguous().view(-1, original_shape[-1]) # Shape: (B*T, E)
147147
# Note: 2D (B, E) and 4D (B, C, H, W) inputs are processed directly without reshaping.
148148

149+
# [DEBUG] Log shape before encoder
150+
import logging
151+
logger = logging.getLogger(__name__)
152+
logger.info(f"[TOKENIZER_DEBUG] Before encoder: original_shape={original_shape}, x.shape={x.shape}, encoder_type={type(encoder_module).__name__}")
153+
149154
# Step 3: Pass the processed tensor through the encoder.
150155
obs_embeddings = encoder_module(x)
151156
if len(obs_embeddings.shape) != 2:

lzero/model/unizero_world_models/world_model.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -705,14 +705,16 @@ def forward(
705705
f"Returning dummy outputs with correct shapes."
706706
)
707707
# Return outputs with shape [batch, 1, ...] to allow squeeze(1) to work
708+
# Important: logits_value and logits_rewards need support_size dimension
708709
batch_size = obs_embeddings.shape[0]
710+
support_size = self.config.support_size
709711
return WorldModelOutput(
710712
torch.zeros(batch_size, 1, self.config.embed_dim, device=self.device),
711713
torch.zeros(batch_size, 1, self.num_observations_tokens, device=self.device),
712-
torch.zeros(batch_size, 1, device=self.device),
713-
None,
714-
torch.zeros(batch_size, 1, self.config.action_space_size, device=self.device),
715-
torch.zeros(batch_size, 1, device=self.device),
714+
torch.zeros(batch_size, 1, support_size, device=self.device), # logits_rewards
715+
None, # logits_ends
716+
torch.zeros(batch_size, 1, self.config.action_space_size, device=self.device), # logits_policy
717+
torch.zeros(batch_size, 1, support_size, device=self.device), # logits_value
716718
)
717719

718720
if not self.config.rotary_emb:
@@ -1650,11 +1652,19 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
16501652
if self.analysis_dormant_ratio_weight_rank:
16511653
# --- Dormant Ratio Calculation ---
16521654
# Calculate the dormant ratio of the encoder to monitor neuron activity.
1653-
shape = batch['observations'].shape # Original shape, e.g., (B, T, C, H, W)
1655+
shape = batch['observations'].shape # Original shape, e.g., (B, T, C, H, W) for images or (B, T, E) for text
16541656
# Reshape observations to create a single large batch for the encoder.
1655-
# E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64)
1656-
inputs = batch['observations'].contiguous().view(-1, *shape[-3:])
1657-
1657+
1658+
# [FIX] Handle both image and text observations
1659+
if len(shape) == 5: # Image: (B, T, C, H, W)
1660+
# E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64)
1661+
inputs = batch['observations'].contiguous().view(-1, *shape[-3:])
1662+
elif len(shape) == 3: # Text: (B, T, E)
1663+
# E.g., (2, 11, 512) -> (22, 512)
1664+
inputs = batch['observations'].contiguous().view(-1, shape[-1])
1665+
else: # Fall back to original behavior for 2D or 4D
1666+
inputs = batch['observations'].contiguous().view(-1, *shape[-3:])
1667+
16581668
dormant_ratio_encoder_dict = calculate_dormant_ratio(
16591669
self.tokenizer.encoder, inputs.detach(), dormant_threshold=self.dormant_threshold
16601670
)
@@ -1732,7 +1742,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
17321742
step_counter=global_step
17331743
)
17341744

1735-
if self.config.use_priority:
1745+
# [FIX] Add default value for use_priority if not present in config
1746+
use_priority = getattr(self.config, 'use_priority', False)
1747+
1748+
if use_priority:
17361749
# ==================== START MODIFICATION 5 ====================
17371750
# Calculate value_priority, similar to MuZero.
17381751
with torch.no_grad():

lzero/policy/scaling_transform.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,20 @@ def __init__(
8080

8181
def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
8282
if self.categorical_distribution:
83+
# [FIX] Handle edge case where logits might be 1D (batch_size=1 and squeezed)
84+
# Ensure logits is at least 2D for softmax operation
85+
if logits.dim() == 1:
86+
logits = logits.unsqueeze(0) # [support_size] -> [1, support_size]
87+
was_1d = True
88+
else:
89+
was_1d = False
90+
8391
value_probs = torch.softmax(logits, dim=1)
8492
value = value_probs.mul_(self.value_support).sum(1, keepdim=True)
93+
94+
# If input was 1D, squeeze back to maintain shape consistency
95+
if was_1d:
96+
value = value.squeeze(0) # [1, 1] -> [1]
8597
else:
8698
value = logits
8799
tmp = ((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon))

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def get_priorzero_config(
210210
# Analysis flags
211211
analysis_sim_norm=False,
212212
analysis_dormant_ratio_weight_rank=False,
213-
213+
# use_priority=False,
214+
use_priority=True,
215+
214216
# Position encoding
215217
rotary_emb=False, # Whether to use RoPE
216218
rope_theta=10000,
@@ -515,7 +517,7 @@ def get_priorzero_config_for_quick_test(env_id: str = 'zork1.z5', seed: int = 0,
515517

516518
main_config.policy.num_simulations = 2
517519
main_config.policy.batch_size = 2
518-
main_config.policy.game_segment_length = 50
520+
main_config.policy.game_segment_length = 20
519521
main_config.policy.num_segments = 2
520522
main_config.policy.replay_buffer_size = 1000
521523

@@ -525,11 +527,13 @@ def get_priorzero_config_for_quick_test(env_id: str = 'zork1.z5', seed: int = 0,
525527
main_config.env.collector_env_num,
526528
main_config.env.evaluator_env_num
527529
)
528-
main_config.policy.model.world_model_cfg.num_heads = 4
529-
main_config.policy.model.world_model_cfg.context_length = 3
530-
main_config.policy.model.world_model_cfg.num_unroll_steps = 5
531-
main_config.policy.model.world_model_cfg.max_blocks = 5
532-
main_config.policy.model.world_model_cfg.max_blocks = 10
530+
main_config.policy.model.world_model_cfg.num_heads = 2
531+
# [FIX] Set infer_context_length to match reduced num_unroll_steps
532+
main_config.policy.model.world_model_cfg.infer_context_length = 2 # Reduced from 4
533+
main_config.policy.model.world_model_cfg.context_length = 4 # 2 * infer_context_length
534+
main_config.policy.model.world_model_cfg.num_unroll_steps = 3
535+
main_config.policy.model.world_model_cfg.max_blocks = 3
536+
main_config.policy.model.world_model_cfg.max_tokens = 6 # 2 * max_blocks
533537

534538
main_config.policy.llm_policy_cfg.prompt_max_len = 1024
535539
main_config.policy.llm_policy_cfg.generate_max_len = 128

0 commit comments

Comments
 (0)