     group="modules/encoders",
     provider="mmlearn",
     model_name_or_path="openai/clip-vit-base-patch16",
+    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
 )
 class HFCLIPTextEncoder(nn.Module):
     """Wrapper around the `CLIPTextModel` from HuggingFace.
@@ -103,7 +104,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
         """
         outputs = self.model(
             input_ids=inputs[Modalities.TEXT],
-            attention_mask=inputs.get("attention_mask"),
+            attention_mask=inputs.get("attention_mask")
+            or inputs.get(Modalities.TEXT.attention_mask),
             position_ids=inputs.get("position_ids"),
             output_attentions=inputs.get("output_attentions"),
             return_dict=True,
@@ -123,6 +125,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
     group="modules/encoders",
     provider="mmlearn",
     model_name_or_path="openai/clip-vit-base-patch16",
+    hydra_convert="object",
 )
 class HFCLIPVisionEncoder(nn.Module):
     """Wrapper around the `CLIPVisionModel` from HuggingFace.
@@ -247,6 +250,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
     group="modules/encoders",
     provider="mmlearn",
     model_name_or_path="openai/clip-vit-base-patch16",
+    hydra_convert="object",
 )
 class HFCLIPTextEncoderWithProjection(nn.Module):
     """Wrapper around the `CLIPTextModelWithProjection` from HuggingFace.
@@ -323,7 +327,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
         The text embeddings. Will be a tuple with a single element.
         """
         input_ids = inputs[Modalities.TEXT]
-        attention_mask = inputs.get("attention_mask")
+        attention_mask = inputs.get("attention_mask") or inputs.get(
+            Modalities.TEXT.attention_mask
+        )
         position_ids = inputs.get("position_ids")
 
         if self.use_all_token_embeddings:
@@ -350,6 +356,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
     group="modules/encoders",
     provider="mmlearn",
     model_name_or_path="openai/clip-vit-base-patch16",
+    hydra_convert="object",
 )
 class HFCLIPVisionEncoderWithProjection(nn.Module):
     """Wrapper around the `CLIPVisionModelWithProjection` class from HuggingFace.
@@ -463,7 +470,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
         return (self.model.visual_projection(pooled_output),)
 
 
-@store(group="modules/encoders", provider="mmlearn")
+@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
 class PubMedBERTForCLIPTextEncoding(nn.Module):
     """BiomedNLP's PubMedBERT model for CLIP text encoding.
 
@@ -561,7 +568,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
         """
         output = self.model(
             input_ids=inputs[Modalities.TEXT],
-            attention_mask=inputs.get("attention_mask"),
+            attention_mask=inputs.get("attention_mask")
+            or inputs.get(Modalities.TEXT.attention_mask),
             inputs_embeds=inputs.get("inputs_embeds"),
             output_attentions=inputs.get("output_attentions"),
             output_hidden_states=True,
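
Note: every hunk above applies the same lookup pattern — prefer a generic "attention_mask" entry in the inputs dict and fall back to the modality-scoped key exposed by `Modalities.TEXT.attention_mask`. A minimal, self-contained sketch of that fallback logic, using a plain dict and a hypothetical "text_attention_mask" key in place of the mmlearn `Modalities` registry:

from typing import Any, Dict, Optional

def resolve_attention_mask(inputs: Dict[str, Any]) -> Optional[Any]:
    # Prefer the generic key; otherwise fall back to the modality-scoped key.
    # "text_attention_mask" is a stand-in for Modalities.TEXT.attention_mask.
    return inputs.get("attention_mask") or inputs.get("text_attention_mask")

# Only the modality-scoped key is present, so the fallback value is returned.
batch = {"text_attention_mask": [[1, 1, 1, 0]]}
print(resolve_attention_mask(batch))  # [[1, 1, 1, 0]]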