Skip to content

Commit

Permalink
Add BIOSCAN-CLIP project (#8)
Browse files: browse the repository at this point in the history
  • Loading branch information
fcogidi authored Sep 4, 2024
1 parent 7228880 commit c04c9fc
Show file tree
Hide file tree
Showing 13 changed files with 1,554 additions and 130 deletions.
1 change: 1 addition & 0 deletions mmlearn/datasets/core/modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Modality(str):

_default_properties = {
"target": "{}_target",
"attention_mask": "{}_attention_mask",
"mask": "{}_mask",
"embedding": "{}_embedding",
"masked_embedding": "{}_masked_embedding",
Expand Down
16 changes: 12 additions & 4 deletions mmlearn/modules/encoders/clip_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
group="modules/encoders",
provider="mmlearn",
model_name_or_path="openai/clip-vit-base-patch16",
hydra_convert="object", # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
"""Wrapper around the `CLIPTextModel` from HuggingFace.
Expand Down Expand Up @@ -103,7 +104,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
"""
outputs = self.model(
input_ids=inputs[Modalities.TEXT],
- attention_mask=inputs.get("attention_mask"),
+ attention_mask=inputs.get("attention_mask")
+ or inputs.get(Modalities.TEXT.attention_mask),
position_ids=inputs.get("position_ids"),
output_attentions=inputs.get("output_attentions"),
return_dict=True,
Expand All @@ -123,6 +125,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
group="modules/encoders",
provider="mmlearn",
model_name_or_path="openai/clip-vit-base-patch16",
hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
"""Wrapper around the `CLIPVisionModel` from HuggingFace.
Expand Down Expand Up @@ -247,6 +250,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
group="modules/encoders",
provider="mmlearn",
model_name_or_path="openai/clip-vit-base-patch16",
hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
"""Wrapper around the `CLIPTextModelWithProjection` from HuggingFace.
Expand Down Expand Up @@ -323,7 +327,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
The text embeddings. Will be a tuple with a single element.
"""
input_ids = inputs[Modalities.TEXT]
- attention_mask = inputs.get("attention_mask")
+ attention_mask = inputs.get("attention_mask") or inputs.get(
+ Modalities.TEXT.attention_mask
+ )
position_ids = inputs.get("position_ids")

if self.use_all_token_embeddings:
Expand All @@ -350,6 +356,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
group="modules/encoders",
provider="mmlearn",
model_name_or_path="openai/clip-vit-base-patch16",
hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
"""Wrapper around the `CLIPVisionModelWithProjection` class from HuggingFace.
Expand Down Expand Up @@ -463,7 +470,7 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
return (self.model.visual_projection(pooled_output),)


- @store(group="modules/encoders", provider="mmlearn")
+ @store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class PubMedBERTForCLIPTextEncoding(nn.Module):
"""BiomedNLP's PubMedBERT model for CLIP text encoding.
Expand Down Expand Up @@ -561,7 +568,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
"""
output = self.model(
input_ids=inputs[Modalities.TEXT],
- attention_mask=inputs.get("attention_mask"),
+ attention_mask=inputs.get("attention_mask")
+ or inputs.get(Modalities.TEXT.attention_mask),
inputs_embeds=inputs.get("inputs_embeds"),
output_attentions=inputs.get("output_attentions"),
output_hidden_states=True,
Expand Down
6 changes: 3 additions & 3 deletions mmlearn/modules/encoders/hf_text_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from peft import PeftConfig


- @store(group="modules/encoders", provider="mmlearn")
+ @store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
"""Wrapper around huggingface models in the `AutoModelForTextEncoding` class.
Expand Down Expand Up @@ -66,7 +66,6 @@ def __init__( # noqa: PLR0912
super().__init__()
if model_config_kwargs is None:
model_config_kwargs = {}
model_config_kwargs["use_return_dict"] = True
model_config_kwargs["output_hidden_states"] = True
model_config_kwargs["add_pooling_layer"] = False
model = hf_utils.load_huggingface_model(
Expand Down Expand Up @@ -157,7 +156,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
"""
outputs = self.model(
input_ids=inputs[Modalities.TEXT],
- attention_mask=inputs.get("attention_mask"),
+ attention_mask=inputs.get("attention_mask")
+ or inputs.get(Modalities.TEXT.attention_mask),
position_ids=inputs.get("position_ids"),
output_attentions=inputs.get("output_attentions"),
return_dict=True,
Expand Down
Loading

0 comments on commit c04c9fc

Please sign in to comment.