@@ -470,7 +470,6 @@ def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
    """
    Overview:
        Forward pass: a thin alias for :meth:`encode`, forwarding the
        ``no_grad`` flag unchanged.
    Arguments:
        - x (:obj:`torch.Tensor`): Input tensor to encode.
        - no_grad (:obj:`bool`): Whether the encoder should run without gradient tracking.
    Returns:
        - (:obj:`torch.Tensor`): The encoded representation produced by ``encode``.
    """
    encoded = self.encode(x, no_grad=no_grad)
    return encoded
473-
class HFLanguageRepresentationNetwork(nn.Module):
    """
    Overview:
        Language encoder built on a pretrained Hugging Face transformer.
        The [CLS] token embedding from the last hidden layer is projected to
        ``embedding_size`` and passed through a final normalization layer.
    """

    def __init__(self,
                 model_path: str = 'google-bert/bert-base-uncased',
                 embedding_size: int = 768,
                 group_size: int = 8,
                 final_norm_option_in_encoder: str = "layernorm",
                 tokenizer=None):
        """
        Arguments:
            - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'.
            - embedding_size (int): The dimension of the output embeddings. Default is 768.
            - group_size (int): The group size for SimNorm when using normalization.
            - final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
            - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model.
        """
        super().__init__()
        from transformers import AutoModel, AutoTokenizer

        # Load the tokenizer on ALL ranks. In distributed runs rank 0 downloads
        # first, the barrier ensures the files are cached, then the other ranks
        # load from the local cache.
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            if get_rank() == 0:
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            if get_world_size() > 1:
                torch.distributed.barrier()
            if get_rank() != 0:
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Same download-once pattern for the model weights: only rank 0 hits the
        # network; the remaining ranks wait, then read from the cache.
        if get_rank() == 0:
            self.pretrained_model = AutoModel.from_pretrained(model_path)
        if get_world_size() > 1:
            torch.distributed.barrier()
        if get_rank() != 0:
            self.pretrained_model = AutoModel.from_pretrained(model_path)

        self.embedding_size = embedding_size
        self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)

        # Select the final normalization layer based on final_norm_option_in_encoder.
        if final_norm_option_in_encoder.lower() == "simnorm":
            self.norm = SimNorm(simnorm_dim=group_size)
        elif final_norm_option_in_encoder.lower() == "layernorm":
            self.norm = nn.LayerNorm(embedding_size)
        else:
            raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
                                      f"Choose 'simnorm' or 'layernorm'.")

    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
        """
        Overview:
            Computes language representation from input token IDs.
        Arguments:
            - x (:obj:`torch.Tensor`): Input token sequence of shape (B, seq_len).
            - no_grad (:obj:`bool`): If True, run the transformer model in `torch.no_grad()` context.
        Returns:
            - (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size).
        """
        # Ensure the input tensor is of type long (token IDs).
        x = x.long()

        # Attention mask that excludes padding tokens.
        attention_mask = (x != self.tokenizer.pad_token_id).long()

        # Pass explicit all-zeros token_type_ids (single-segment input) so the
        # backbone never falls back to its internally buffered token_type_ids,
        # whose cached sequence length may not match the current batch.
        token_type_ids = torch.zeros_like(x)

        # Only the grad context differs between the two branches; the model call
        # and the [CLS] extraction are shared below.
        if no_grad:
            with torch.no_grad():
                outputs = self.pretrained_model(x, attention_mask=attention_mask, token_type_ids=token_type_ids)
        else:
            outputs = self.pretrained_model(x, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # Last-layer hidden state of the [CLS] token, projected and normalized.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        cls_embedding = self.embed_proj_head(cls_embedding)
        cls_embedding = self.norm(cls_embedding)

        return cls_embedding
555560
561+ # class HFLanguageRepresentationNetwork(nn.Module):
562+ # def __init__(self,
563+ # model_path: str = 'google-bert/bert-base-uncased',
564+ # embedding_size: int = 768,
565+ # group_size: int = 8,
566+ # final_norm_option_in_encoder: str = "layernorm",
567+ # tokenizer=None):
568+ # """
569+ # Arguments:
570+ # - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'.
571+ # - embedding_size (int): The dimension of the output embeddings. Default is 768.
572+ # - group_size (int): The group size for SimNorm when using normalization.
573+ # - final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
574+ # - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model.
575+ # """
576+ # super().__init__()
577+ # from transformers import AutoModel, AutoTokenizer
578+
579+ # # [FIX] Load tokenizer for ALL ranks, not just non-zero ranks
580+ # if tokenizer is not None:
581+ # self.tokenizer = tokenizer
582+ # else:
583+ # # Load tokenizer with same distributed logic as model
584+ # if get_rank() == 0:
585+ # self.tokenizer = AutoTokenizer.from_pretrained(model_path)
586+ # if get_world_size() > 1:
587+ # torch.distributed.barrier()
588+ # if get_rank() != 0:
589+ # self.tokenizer = AutoTokenizer.from_pretrained(model_path)
590+
591+ # # In distributed settings, ensure only rank 0 downloads the model/tokenizer.
592+ # if get_rank() == 0:
593+ # self.pretrained_model = AutoModel.from_pretrained(model_path)
594+
595+ # if get_world_size() > 1:
596+ # # Wait for rank 0 to finish loading the model.
597+ # torch.distributed.barrier()
598+ # if get_rank() != 0:
599+ # self.pretrained_model = AutoModel.from_pretrained(model_path)
600+
601+ # self.embedding_size = embedding_size
602+ # self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
603+
604+ # # # Select the normalization method based on the final_norm_option_in_encoder parameter.
605+ # if final_norm_option_in_encoder.lower() == "simnorm":
606+ # self.norm = SimNorm(simnorm_dim=group_size)
607+ # elif final_norm_option_in_encoder.lower() == "layernorm":
608+ # self.norm = nn.LayerNorm(embedding_size)
609+ # else:
610+ # raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
611+ # f"Choose 'simnorm' or 'layernorm'.")
612+
613+ # def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
614+ # """
615+ # Overview:
616+ # Computes language representation from input token IDs.
617+ # Arguments:
618+ # - x (:obj:`torch.Tensor`): Input token sequence of shape (B, seq_len).
619+ # - no_grad (:obj:`bool`): If True, run the transformer model in `torch.no_grad()` context.
620+ # Returns:
621+ # - (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size).
622+ # """
623+
624+ # # Construct the attention mask to exclude padding tokens.
625+ # attention_mask = x != self.tokenizer.pad_token_id
626+
627+ # # [FIX] Clear buffered token_type_ids to prevent shape mismatch errors
628+ # # BERT models cache token_type_ids for efficiency, but this causes issues
629+ # # when batch sizes or sequence lengths vary across different forward passes.
630+ # # We delete the buffer entirely and let BERT recreate it with the correct shape.
631+ # if hasattr(self.pretrained_model, 'embeddings') and hasattr(self.pretrained_model.embeddings, 'token_type_ids'):
632+ # # Check if token_type_ids exists and has wrong shape
633+ # if self.pretrained_model.embeddings.token_type_ids is not None:
634+ # expected_seq_len = x.shape[1]
635+ # current_seq_len = self.pretrained_model.embeddings.token_type_ids.shape[1]
636+ # # Only delete if the cached buffer has wrong shape
637+ # if current_seq_len != expected_seq_len:
638+ # # Delete the registered buffer and let BERT recreate it
639+ # delattr(self.pretrained_model.embeddings, 'token_type_ids')
640+ # # Re-register with correct shape
641+ # self.pretrained_model.embeddings.register_buffer(
642+ # "token_type_ids",
643+ # torch.zeros((1, expected_seq_len), dtype=torch.long, device=x.device),
644+ # persistent=False
645+ # )
646+
647+ # if no_grad:
648+ # with torch.no_grad():
649+ # x = x.long() # Ensure the input tensor is of type long.
650+ # outputs = self.pretrained_model(x, attention_mask=attention_mask)
651+ # # Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
652+ # cls_embedding = outputs.last_hidden_state[:, 0, :]
653+ # else:
654+ # x = x.long()
655+ # outputs = self.pretrained_model(x, attention_mask=attention_mask)
656+ # cls_embedding = outputs.last_hidden_state[:, 0, :]
657+
658+ # cls_embedding = self.embed_proj_head(cls_embedding)
659+ # cls_embedding = self.norm(cls_embedding)
660+
661+ # return cls_embedding
662+
556663
557664class RepresentationNetworkUniZero (nn .Module ):
558665
0 commit comments