Skip to content

Commit c04c9fc

Browse files
authored
Add BIOSCAN-CLIP project (#8)
1 parent 7228880 commit c04c9fc

File tree

13 files changed

+1554
-130
lines changed

13 files changed

+1554
-130
lines changed

mmlearn/datasets/core/modalities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Modality(str):
3838

3939
_default_properties = {
4040
"target": "{}_target",
41+
"attention_mask": "{}_attention_mask",
4142
"mask": "{}_mask",
4243
"embedding": "{}_embedding",
4344
"masked_embedding": "{}_masked_embedding",

mmlearn/modules/encoders/clip_encoders.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
group="modules/encoders",
2525
provider="mmlearn",
2626
model_name_or_path="openai/clip-vit-base-patch16",
27+
hydra_convert="object", # required for `peft_config` to be converted to a `PeftConfig` object
2728
)
2829
class HFCLIPTextEncoder(nn.Module):
2930
"""Wrapper around the `CLIPTextModel` from HuggingFace.
@@ -103,7 +104,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
103104
"""
104105
outputs = self.model(
105106
input_ids=inputs[Modalities.TEXT],
106-
attention_mask=inputs.get("attention_mask"),
107+
attention_mask=inputs.get("attention_mask")
108+
or inputs.get(Modalities.TEXT.attention_mask),
107109
position_ids=inputs.get("position_ids"),
108110
output_attentions=inputs.get("output_attentions"),
109111
return_dict=True,
@@ -123,6 +125,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
123125
group="modules/encoders",
124126
provider="mmlearn",
125127
model_name_or_path="openai/clip-vit-base-patch16",
128+
hydra_convert="object",
126129
)
127130
class HFCLIPVisionEncoder(nn.Module):
128131
"""Wrapper around the `CLIPVisionModel` from HuggingFace.
@@ -247,6 +250,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
247250
group="modules/encoders",
248251
provider="mmlearn",
249252
model_name_or_path="openai/clip-vit-base-patch16",
253+
hydra_convert="object",
250254
)
251255
class HFCLIPTextEncoderWithProjection(nn.Module):
252256
"""Wrapper around the `CLIPTextModelWithProjection` from HuggingFace.
@@ -323,7 +327,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
323327
The text embeddings. Will be a tuple with a single element.
324328
"""
325329
input_ids = inputs[Modalities.TEXT]
326-
attention_mask = inputs.get("attention_mask")
330+
attention_mask = inputs.get("attention_mask") or inputs.get(
331+
Modalities.TEXT.attention_mask
332+
)
327333
position_ids = inputs.get("position_ids")
328334

329335
if self.use_all_token_embeddings:
@@ -350,6 +356,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
350356
group="modules/encoders",
351357
provider="mmlearn",
352358
model_name_or_path="openai/clip-vit-base-patch16",
359+
hydra_convert="object",
353360
)
354361
class HFCLIPVisionEncoderWithProjection(nn.Module):
355362
"""Wrapper around the `CLIPVisionModelWithProjection` class from HuggingFace.
@@ -463,7 +470,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
463470
return (self.model.visual_projection(pooled_output),)
464471

465472

466-
@store(group="modules/encoders", provider="mmlearn")
473+
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
467474
class PubMedBERTForCLIPTextEncoding(nn.Module):
468475
"""BiomedNLP's PubMedBERT model for CLIP text encoding.
469476
@@ -561,7 +568,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
561568
"""
562569
output = self.model(
563570
input_ids=inputs[Modalities.TEXT],
564-
attention_mask=inputs.get("attention_mask"),
571+
attention_mask=inputs.get("attention_mask")
572+
or inputs.get(Modalities.TEXT.attention_mask),
565573
inputs_embeds=inputs.get("inputs_embeds"),
566574
output_attentions=inputs.get("output_attentions"),
567575
output_hidden_states=True,

mmlearn/modules/encoders/hf_text_encoders.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from peft import PeftConfig
1818

1919

20-
@store(group="modules/encoders", provider="mmlearn")
20+
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
2121
class HFTextEncoder(nn.Module):
2222
"""Wrapper around huggingface models in the `AutoModelForTextEncoding` class.
2323
@@ -66,7 +66,6 @@ def __init__( # noqa: PLR0912
6666
super().__init__()
6767
if model_config_kwargs is None:
6868
model_config_kwargs = {}
69-
model_config_kwargs["use_return_dict"] = True
7069
model_config_kwargs["output_hidden_states"] = True
7170
model_config_kwargs["add_pooling_layer"] = False
7271
model = hf_utils.load_huggingface_model(
@@ -157,7 +156,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
157156
"""
158157
outputs = self.model(
159158
input_ids=inputs[Modalities.TEXT],
160-
attention_mask=inputs.get("attention_mask"),
159+
attention_mask=inputs.get("attention_mask")
160+
or inputs.get(Modalities.TEXT.attention_mask),
161161
position_ids=inputs.get("position_ids"),
162162
output_attentions=inputs.get("output_attentions"),
163163
return_dict=True,

0 commit comments

Comments
 (0)