From 3dea00f433d6b494df83a5ae5f86480b794c1ddb Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:18:59 -0500 Subject: [PATCH] Refactor contrastive loss (#35) --- mmlearn/cli/run.py | 2 +- mmlearn/conf/__init__.py | 2 +- mmlearn/hf_utils.py | 2 +- mmlearn/modules/encoders/__init__.py | 4 +- mmlearn/modules/encoders/clip.py | 117 ---- mmlearn/modules/encoders/vision.py | 2 +- mmlearn/modules/losses/__init__.py | 4 +- mmlearn/modules/losses/contrastive.py | 527 +++++++++++++++--- .../lr_schedulers/linear_warmup_cosine_lr.py | 2 +- mmlearn/modules/metrics/retrieval_recall.py | 10 +- mmlearn/tasks/contrastive_pretraining.py | 81 +-- mmlearn/tasks/zero_shot_retrieval.py | 20 +- 12 files changed, 508 insertions(+), 265 deletions(-) diff --git a/mmlearn/cli/run.py b/mmlearn/cli/run.py index f342458..eee7a3a 100644 --- a/mmlearn/cli/run.py +++ b/mmlearn/cli/run.py @@ -45,7 +45,7 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912 if is_torch_tf32_available(): torch.backends.cuda.matmul.allow_tf32 = True - if "16-mixed" in cfg.trainer.precision: + if "16-mixed" in str(cfg.trainer.precision): cfg.trainer.precision = "bf16-mixed" # setup trainer first so that we can get some variables for distributed training diff --git a/mmlearn/conf/__init__.py b/mmlearn/conf/__init__.py index df9b505..8a112c6 100644 --- a/mmlearn/conf/__init__.py +++ b/mmlearn/conf/__init__.py @@ -168,7 +168,7 @@ class MMLearnConf: job=JobConf( name=II("experiment_name"), env_set={ - "TORCH_NCCL_ASYNC_ERROR_HANDLING": "3", + "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1", "HYDRA_FULL_ERROR": "1", }, ), diff --git a/mmlearn/hf_utils.py b/mmlearn/hf_utils.py index 3d140d7..3cfff67 100644 --- a/mmlearn/hf_utils.py +++ b/mmlearn/hf_utils.py @@ -67,7 +67,7 @@ def load_huggingface_model( return_unused_kwargs=True, **model_config_kwargs, ) - model = model_type._from_config(config, **kwargs) + model = model_type.from_config(config, **kwargs) if get_model_attr is not None and hasattr(model, get_model_attr): model = getattr(model, get_model_attr) diff --git a/mmlearn/modules/encoders/__init__.py b/mmlearn/modules/encoders/__init__.py index a41fa6e..c29bb86 100644 --- a/mmlearn/modules/encoders/__init__.py +++ b/mmlearn/modules/encoders/__init__.py @@ -5,9 +5,9 @@ HFCLIPTextEncoderWithProjection, HFCLIPVisionEncoder, HFCLIPVisionEncoderWithProjection, - PubMedBERTForCLIPTextEncoding, ) from mmlearn.modules.encoders.text import HFTextEncoder +from mmlearn.modules.encoders.vision import TimmViT __all__ = [ @@ -16,5 +16,5 @@ "HFCLIPTextEncoderWithProjection", "HFCLIPVisionEncoder", "HFCLIPVisionEncoderWithProjection", - "PubMedBERTForCLIPTextEncoding", + "TimmViT", ] diff --git a/mmlearn/modules/encoders/clip.py b/mmlearn/modules/encoders/clip.py index 354bbfb..708dabb 100644 --- a/mmlearn/modules/encoders/clip.py +++ b/mmlearn/modules/encoders/clip.py @@ -474,123 +474,6 @@ def forward(self, inputs: Dict[str, Any]) -> Tuple[torch.Tensor]: return (self.model.visual_projection(pooled_output),) -@store(group="modules/encoders", provider="mmlearn", hydra_convert="object") -class PubMedBERTForCLIPTextEncoding(nn.Module): - """BiomedNLP's PubMedBERT model for CLIP text encoding. - - This module is wrapper around the PubMedBERT model from huggingface. - - Parameters - ---------- - pretrained : bool, default=False - Whether to load the pretrained weights or not. - pooling_layer : nn.Module, optional, default=None - Pooling layer to apply to the last hidden state of the model. 
- freeze_layers : int | float | List[int] | bool, default=False - Whether to freeze layers of the model and which layers to freeze. If `True`, - all model layers are frozen. If it is an integer, the first `N` layers of - the model are frozen. If it is a float, the first `N` percent of the layers - are frozen. If it is a list of integers, the layers at the indices in the - list are frozen. - freeze_layer_norm : bool, default=True - Whether to freeze the layer normalization layers of the model. - peft_config : PeftConfig, optional, default=None - The configuration from the `peft` library to use to wrap the model - for parameter-efficient finetuning. - model_config_kwargs : Dict[str, Any], optional, default=None - Additional keyword arguments to pass to the model configuration. - - Warns - ----- - UserWarning - If both `peft_config` and `freeze_layers` are set. The `peft_config` will - override the `freeze_layers` setting. - - """ - - def __init__( - self, - pretrained: bool = True, - pooling_layer: Optional[nn.Module] = None, - freeze_layers: Union[int, float, List[int], bool] = False, - freeze_layer_norm: bool = True, - peft_config: Optional["PeftConfig"] = None, - model_config_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - """Initialize the model.""" - super().__init__() - _warn_freeze_with_peft(peft_config, freeze_layers) - - model = hf_utils.load_huggingface_model( - transformers.AutoModelForMaskedLM, - "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext", - load_pretrained_weights=pretrained, - get_model_attr="bert", - model_config_kwargs=model_config_kwargs, - ) - - if isinstance(freeze_layers, bool) and freeze_layers: - for name, param in model.named_parameters(): - param.requires_grad = ( - (not freeze_layer_norm) if "LayerNorm" in name else False - ) - - layers = [model.embeddings, *model.encoder.layer] - if isinstance(freeze_layers, float): - freeze_layers = int(freeze_layers * len(layers)) - if isinstance(freeze_layers, int): - freeze_layers = list(range(freeze_layers)) - - if isinstance(freeze_layers, list): - for idx, layer in enumerate(layers): - if idx in freeze_layers: - for name, param in layer.named_parameters(): - param.requires_grad = ( - (not freeze_layer_norm) if "LayerNorm" in name else False - ) - - if peft_config is not None: - model = hf_utils._wrap_peft_model(model, peft_config) - - self.model = model - self.pooling_layer = pooling_layer - - def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput: - """Run the forward pass. - - Parameters - ---------- - inputs : Dict[str, Any] - The input data. The `input_ids` will be expected under the `Modalities.TEXT` - key. - - Returns - ------- - BaseModelOutput - The output of the model, including the last hidden state, all hidden states, - and the attention weights, if `output_attentions` is set to `True`. 
- """ - output = self.model( - input_ids=inputs[Modalities.TEXT.name], - attention_mask=inputs.get( - "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None) - ), - inputs_embeds=inputs.get("inputs_embeds"), - output_attentions=inputs.get("output_attentions"), - output_hidden_states=True, - return_dict=True, - ) - last_hidden_state = output.last_hidden_state - if self.pooling_layer is not None: - last_hidden_state = self.pooling_layer(last_hidden_state) - - return BaseModelOutput( - last_hidden_state=last_hidden_state, - hidden_states=output.hidden_states, - attentions=output.attentions, - ) - - #### Utility methods #### diff --git a/mmlearn/modules/encoders/vision.py b/mmlearn/modules/encoders/vision.py index 669d7f0..1f2ee7c 100644 --- a/mmlearn/modules/encoders/vision.py +++ b/mmlearn/modules/encoders/vision.py @@ -26,7 +26,7 @@ @store( group="modules/encoders", provider="mmlearn", - model_name_or_path="vit_base_patch16_224", + model_name="vit_base_patch16_224", hydra_convert="object", ) class TimmViT(nn.Module): diff --git a/mmlearn/modules/losses/__init__.py b/mmlearn/modules/losses/__init__.py index 164b1e9..c18868b 100644 --- a/mmlearn/modules/losses/__init__.py +++ b/mmlearn/modules/losses/__init__.py @@ -1,7 +1,7 @@ """Loss functions.""" -from mmlearn.modules.losses.contrastive import CLIPLoss +from mmlearn.modules.losses.contrastive import ContrastiveLoss from mmlearn.modules.losses.data2vec import Data2VecLoss -__all__ = ["CLIPLoss", "Data2VecLoss"] +__all__ = ["ContrastiveLoss", "Data2VecLoss"] diff --git a/mmlearn/modules/losses/contrastive.py b/mmlearn/modules/losses/contrastive.py index a2adafc..7d491ed 100644 --- a/mmlearn/modules/losses/contrastive.py +++ b/mmlearn/modules/losses/contrastive.py @@ -1,19 +1,24 @@ """Implementations of the contrastive loss and its variants.""" -from typing import Dict, Tuple +import itertools +from typing import Any, Dict, Optional import torch import torch.distributed as dist +import torch.distributed.nn as dist_nn from hydra_zen import store from torch import nn from torch.nn import functional as F # noqa: N812 from torchmetrics.utilities.compute import _safe_matmul -from torchmetrics.utilities.distributed import gather_all_tensors + +from mmlearn.datasets.core import find_matching_indices +from mmlearn.datasets.core.modalities import Modalities +from mmlearn.tasks.contrastive_pretraining import LossPairSpec @store(group="modules/losses", provider="mmlearn") -class CLIPLoss(nn.Module): - """CLIP Loss module. +class ContrastiveLoss(nn.Module): + """Contrastive Loss module. Parameters ---------- @@ -23,6 +28,10 @@ class CLIPLoss(nn.Module): Whether to calculate the loss locally i.e. `local_features@global_features`. gather_with_grad : bool, default=False Whether to gather tensors with gradients. + modality_alignment : bool, default=False + Whether to include modality alignment loss. This loss considers all features + from the same modality as positive pairs and all features from different + modalities as negative pairs. cache_labels : bool, default=False Whether to cache the labels. 
@@ -33,6 +42,7 @@ def __init__( l2_normalize: bool = False, local_loss: bool = False, gather_with_grad: bool = False, + modality_alignment: bool = False, cache_labels: bool = False, ): """Initialize the loss.""" @@ -41,37 +51,158 @@ def __init__( self.gather_with_grad = gather_with_grad self.cache_labels = cache_labels self.l2_normalize = l2_normalize + self.modality_alignment = modality_alignment # cache state self._prev_num_logits = 0 self._labels: Dict[torch.device, torch.Tensor] = {} + def forward( + self, + embeddings: dict[str, torch.Tensor], + example_ids: dict[str, torch.Tensor], + logit_scale: torch.Tensor, + modality_loss_pairs: list[LossPairSpec], + ) -> torch.Tensor: + """Calculate the contrastive loss. + + Parameters + ---------- + embeddings : dict[str, torch.Tensor] + Dictionary of embeddings, where the key is the modality name and the value + is the corresponding embedding tensor. + example_ids : dict[str, torch.Tensor] + Dictionary of example IDs, where the key is the modality name and the value + is a tensor tuple of the dataset index and the example index. + logit_scale : torch.Tensor + Scale factor for the logits. + modality_loss_pairs : List[LossPairSpec] + Specification of the modality pairs for which the loss should be calculated. + + Returns + ------- + torch.Tensor + Contrastive loss. + """ + world_size = dist.get_world_size() if dist.is_initialized() else 1 + rank = dist.get_rank() if world_size > 1 else 0 + + if self.l2_normalize: + embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()} + + if world_size > 1: # gather embeddings and example_ids across all processes + # NOTE: gathering dictionaries of tensors across all processes + # (keys + values, as opposed to just values) is especially important + # for the modality_alignment loss, which requires all embeddings + all_embeddings = _gather_dicts( + embeddings, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=rank, + ) + all_example_ids = _gather_dicts( + example_ids, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=rank, + ) + else: + all_embeddings = embeddings + all_example_ids = example_ids + + losses = [] + for loss_pairs in modality_loss_pairs: + logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits( + loss_pairs.modalities, + per_device_embeddings=embeddings, + all_embeddings=all_embeddings, + per_device_example_ids=example_ids, + all_example_ids=all_example_ids, + logit_scale=logit_scale, + world_size=world_size, + ) + if logits_per_feature_a is None or logits_per_feature_b is None: + continue + + labels = self._get_ground_truth( + logits_per_feature_a.shape, + device=logits_per_feature_a.device, + rank=rank, + world_size=world_size, + skipped_process=skip_flag, + ) + + if labels.numel() != 0: + losses.append( + ( + ( + F.cross_entropy(logits_per_feature_a, labels) + + F.cross_entropy(logits_per_feature_b, labels) + ) + / 2 + ) + * loss_pairs.weight + ) + + if self.modality_alignment: + losses.append( + self._compute_modality_alignment_loss(all_embeddings, logit_scale) + ) + + return torch.stack(losses).sum() + def _get_ground_truth( - self, device: torch.device, num_logits: int, rank: int, world_size: int + self, + logits_shape: tuple[int, int], + device: torch.device, + rank: int, + world_size: int, + skipped_process: bool, ) -> torch.Tensor: """Return the ground-truth labels. Parameters ---------- + logits_shape : tuple[int, int] + Shape of the logits tensor. 
device : torch.device - Device to store the labels. - num_logits : int - Number of logits. + Device on which the labels should be created. rank : int Rank of the current process. world_size : int Number of processes. + skipped_process : bool + Whether the current process skipped the computation due to lack of data. Returns ------- torch.Tensor Ground-truth labels. """ + num_logits = logits_shape[-1] + # calculate ground-truth and cache if enabled if self._prev_num_logits != num_logits or device not in self._labels: labels = torch.arange(num_logits, device=device, dtype=torch.long) + if world_size > 1 and self.local_loss: - labels = labels + num_logits * rank + local_size = torch.tensor( + 0 if skipped_process else logits_shape[0], device=device + ) + # NOTE: all processes must participate in the all_gather operation + # even if they have no data to contribute. + sizes = torch.stack( + _simple_gather_all_tensors( + local_size, group=dist.group.WORLD, world_size=world_size + ) + ) + sizes = torch.cat( + [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)] + ) + labels = labels[ + sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None + ] + if self.cache_labels: self._labels[device] = labels self._prev_num_logits = num_logits @@ -79,118 +210,228 @@ def _get_ground_truth( labels = self._labels[device] return labels - def _get_logits( + def _get_logits( # noqa: PLR0912 self, - features_1: torch.Tensor, - features_2: torch.Tensor, + modalities: tuple[str, str], + per_device_embeddings: dict[str, torch.Tensor], + all_embeddings: dict[str, torch.Tensor], + per_device_example_ids: dict[str, torch.Tensor], + all_example_ids: dict[str, torch.Tensor], logit_scale: torch.Tensor, - rank: int, world_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Return the logits. + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]: + """Calculate the logits for the given modalities. Parameters ---------- - features_1 : torch.Tensor - First feature tensor. - features_2 : torch.Tensor - Second feature tensor. + modalities : tuple[str, str] + Tuple of modality names. + per_device_embeddings : dict[str, torch.Tensor] + Dictionary of embeddings, where the key is the modality name and the value + is the corresponding embedding tensor. + all_embeddings : dict[str, torch.Tensor] + Dictionary of embeddings, where the key is the modality name and the value + is the corresponding embedding tensor. In distributed mode, this contains + embeddings from all processes. + per_device_example_ids : dict[str, torch.Tensor] + Dictionary of example IDs, where the key is the modality name and the value + is a tensor tuple of the dataset index and the example index. + all_example_ids : dict[str, torch.Tensor] + Dictionary of example IDs, where the key is the modality name and the value + is a tensor tuple of the dataset index and the example index. In distributed + mode, this contains example IDs from all processes. logit_scale : torch.Tensor - Logit scale. - rank : int - Rank of the current process. + Scale factor for the logits. world_size : int Number of processes. Returns ------- - Tuple[torch.Tensor, torch.Tensor] - Logits per feature_1 and feature_2, respectively. - + tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool] + Tuple of logits for the given modalities. If embeddings for the given + modalities are not available, returns `None` for the logits. The last + element is a flag indicating whether the process skipped the computation + due to lack of data. 
""" + modality_a = Modalities.get_modality(modalities[0]) + modality_b = Modalities.get_modality(modalities[1]) + skip_flag = False + + if self.local_loss or world_size == 1: + if not ( + modality_a.embedding in per_device_embeddings + and modality_b.embedding in per_device_embeddings + ): + if world_size > 1: # NOTE: not all processes exit here, hence skip_flag + skip_flag = True + else: + return None, None, skip_flag + + if not skip_flag: + indices_a, indices_b = find_matching_indices( + per_device_example_ids[modality_a.name], + per_device_example_ids[modality_b.name], + ) + if indices_a.numel() == 0 or indices_b.numel() == 0: + if world_size > 1: # not all processes exit here + skip_flag = True + else: + return None, None, skip_flag + + if not skip_flag: + features_a = per_device_embeddings[modality_a.embedding][indices_a] + features_b = per_device_embeddings[modality_b.embedding][indices_b] + else: + # all processes must participate in the all_gather operation + # that follows, even if they have no data to contribute. So, + # we create empty tensors here. + features_a = torch.empty( + 0, device=list(per_device_embeddings.values())[0].device + ) + features_b = torch.empty( + 0, device=list(per_device_embeddings.values())[0].device + ) + if world_size > 1: - all_features_1 = gather_features( - features_1, self.local_loss, self.gather_with_grad, rank - ) - all_features_2 = gather_features( - features_2, self.local_loss, self.gather_with_grad, rank + if not ( + modality_a.embedding in all_embeddings + and modality_b.embedding in all_embeddings + ): # all processes exit here + return None, None, skip_flag + + indices_a, indices_b = find_matching_indices( + all_example_ids[modality_a.name], + all_example_ids[modality_b.name], ) + if indices_a.numel() == 0 or indices_b.numel() == 0: + # all processes exit here + return None, None, skip_flag + + all_features_a = all_embeddings[modality_a.embedding][indices_a] + all_features_b = all_embeddings[modality_b.embedding][indices_b] if self.local_loss: - logits_per_feature_1 = logit_scale * _safe_matmul( - features_1, all_features_2 + if features_a.numel() == 0: + features_a = all_features_a + if features_b.numel() == 0: + features_b = all_features_b + + logits_per_feature_a = logit_scale * _safe_matmul( + features_a, all_features_b ) - logits_per_feature_2 = logit_scale * _safe_matmul( - features_2, all_features_1 + logits_per_feature_b = logit_scale * _safe_matmul( + features_b, all_features_a ) else: - logits_per_feature_1 = logit_scale * _safe_matmul( - all_features_1, all_features_2 + logits_per_feature_a = logit_scale * _safe_matmul( + all_features_a, all_features_b ) - logits_per_feature_2 = logits_per_feature_1.T + logits_per_feature_b = logits_per_feature_a.T else: - logits_per_feature_1 = logit_scale * _safe_matmul(features_1, features_2) - logits_per_feature_2 = logit_scale * _safe_matmul(features_2, features_1) + logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b) + logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a) - return logits_per_feature_1, logits_per_feature_2 + return logits_per_feature_a, logits_per_feature_b, skip_flag - def forward( - self, - features_1: torch.Tensor, - features_2: torch.Tensor, - logit_scale: torch.Tensor, + def _compute_modality_alignment_loss( + self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor ) -> torch.Tensor: - """Calculate the CLIP-style loss between two sets of features. + """Compute the modality alignment loss. 
+ + This loss considers all features from the same modality as positive pairs + and all features from different modalities as negative pairs. Parameters ---------- - features_1 : torch.Tensor - First set of features. - features_2 : torch.Tensor - Second set of features. + all_embeddings : dict[str, torch.Tensor] + Dictionary of embeddings, where the key is the modality name and the value + is the corresponding embedding tensor. logit_scale : torch.Tensor - Logit scale. + Scale factor for the logits. Returns ------- torch.Tensor - Loss value. - """ - world_size = dist.get_world_size() if dist.is_initialized() else 1 - rank = dist.get_rank() if world_size > 1 else 0 + Modality alignment loss. - if self.l2_normalize: - features_1 = F.normalize(features_1, p=2, dim=-1) - features_2 = F.normalize(features_2, p=2, dim=-1) + Notes + ----- + This loss does not support `local_loss=True`. + """ + available_modalities = list(all_embeddings.keys()) + # TODO: support local_loss for modality_alignment? + # if world_size == 1, all_embeddings == embeddings + all_features = torch.cat(list(all_embeddings.values()), dim=0) - logits_per_feat1, logits_per_feat2 = self._get_logits( - features_1, features_2, logit_scale, rank=rank, world_size=world_size + positive_indices = torch.tensor( + [ + (i, j) + if idx == 0 + else ( + i + all_embeddings[available_modalities[idx - 1]].size(0), + j + all_embeddings[available_modalities[idx - 1]].size(0), + ) + for idx, k in enumerate(all_embeddings) + for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2) + ], + device=all_features.device, ) - labels = self._get_ground_truth( - features_1.device, - logits_per_feat1.shape[0], - rank=rank, - world_size=world_size, + logits = logit_scale * _safe_matmul(all_features, all_features) + + target = torch.eye(all_features.size(0), device=all_features.device) + target[positive_indices[:, 0], positive_indices[:, 1]] = 1 + + modality_loss = torch.nn.functional.binary_cross_entropy_with_logits( + logits, target, reduction="none" ) - return ( - F.cross_entropy(logits_per_feat1, labels) - + F.cross_entropy(logits_per_feat2, labels) - ) / 2 + target_pos = target.bool() + target_neg = ~target_pos + # loss_pos and loss_neg below contain non-zero values only for those + # elements that are positive pairs and negative pairs respectively. + loss_pos = torch.zeros( + logits.size(0), logits.size(0), device=target.device + ).masked_scatter(target_pos, modality_loss[target_pos]) + loss_neg = torch.zeros( + logits.size(0), logits.size(0), device=target.device + ).masked_scatter(target_neg, modality_loss[target_neg]) -def gather_features( - features: torch.Tensor, - local_loss: bool = False, + loss_pos = loss_pos.sum(dim=1) + loss_neg = loss_neg.sum(dim=1) + num_pos = target.sum(dim=1) + num_neg = logits.size(0) - num_pos + + return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean() + + +def _get_dtype_max(tensor: torch.Tensor) -> torch.Tensor: + if tensor.is_floating_point(): + return torch.finfo(tensor.dtype).max + if not tensor.is_complex(): + return torch.iinfo(tensor.dtype).max + raise ValueError( + f"Unsupported dtype {tensor.dtype}. Only floating point and integer types are supported." 
+ ) + + +def _is_all_dtype_max(tensor: torch.Tensor) -> torch.BoolTensor: + dtype_max = _get_dtype_max(tensor) + return torch.all(tensor == dtype_max) + + +def _gather_dicts( + dicts: dict[str, torch.Tensor], + local_loss: bool, + rank: int, gather_with_grad: bool = False, - rank: int = 0, -) -> torch.Tensor: - """Gather features across all processes. +) -> dict[str, torch.Tensor]: + """Gather dictionaries of tensors across all processes. Parameters ---------- - features : torch.Tensor - First feature tensor to gather. + dicts : dict[str, torch.Tensor] + Dictionary of tensors to gather. local_loss : bool, default=False Whether to calculate the loss locally i.e. `matmul(local_features, global_features)`. If False, this method ensures @@ -202,16 +443,128 @@ def gather_features( Returns ------- - torch.Tensor - Gathered features. + dict[str, torch.Tensor] + Gathered dictionary of tensors. """ + group = dist.group.WORLD + world_size = dist.get_world_size(group) + current_device = next(iter(dicts.values())).device + dist.barrier(group=group) + + # gather keys + local_keys = list(dicts.keys()) + all_keys: list[str] = [None] * world_size # type: ignore[list-item] + dist.all_gather_object(all_keys, local_keys, group=group) + all_keys = sorted(set(itertools.chain.from_iterable(all_keys))) + + # gather tensors + gathered_dict: dict[str, torch.Tensor] = {} + for key in all_keys: + if key not in dicts: # use dummy tensor for missing key in current process + placeholder_tensor = dicts[local_keys[0]] + tensor = torch.full_like( + placeholder_tensor, + fill_value=_get_dtype_max(placeholder_tensor), + device=current_device, + memory_format=torch.contiguous_format, + requires_grad=gather_with_grad + and placeholder_tensor.is_floating_point(), # only floating point tensors can have gradients + ) + else: + tensor = dicts[key].contiguous() + + gathered_tensors: list[torch.Tensor] = _gather_all_tensors( + tensor, + world_size=world_size, + group=group, + gather_with_grad=gather_with_grad, + ) + + if not gather_with_grad and not local_loss: + gathered_tensors[rank] = tensor + + # filter out placeholder tensors + gathered_tensors = [t for t in gathered_tensors if not _is_all_dtype_max(t)] + + gathered_dict[key] = torch.cat(gathered_tensors, dim=0) + + return gathered_dict + + +def _simple_gather_all_tensors( + result: torch.Tensor, group: Any, world_size: int, gather_with_grad: bool = False +) -> list[torch.Tensor]: if gather_with_grad: - all_features = torch.cat(torch.distributed.nn.all_gather(features), dim=0) - else: - gathered_features = gather_all_tensors(features) - if not local_loss: - # ensure grads for local rank when all_* features don't have a gradient - gathered_features[rank] = features - all_features = torch.cat(gathered_features, dim=0) + return list(dist_nn.all_gather(result, group)) - return all_features + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + dist.all_gather(gathered_result, result, group) + return gathered_result + + +def _gather_all_tensors( + a_tensor: torch.Tensor, + world_size: Optional[int] = None, + group: Optional[Any] = None, + gather_with_grad: bool = False, +) -> list[torch.Tensor]: + """Gather tensor(s) from all devices onto a list and broadcast to all devices. + + Parameters + ---------- + a_tensor : torch.Tensor + The tensor to gather. + world_size : int, default=None + Number of processes in the group. + group : Any, default=None + The process group to work on. + gather_with_grad : bool, default=False + Whether to gather tensors with gradients. 
+ + Returns + ------- + list[torch.Tensor] + List of gathered tensors. + """ + if group is None: + group = torch.distributed.group.WORLD + + # convert tensors to contiguous format + a_tensor = a_tensor.contiguous() + + if world_size is None: + world_size = dist.get_world_size(group) + dist.barrier(group=group) + + # if the tensor is scalar, things are easy + if a_tensor.ndim == 0: + return _simple_gather_all_tensors(a_tensor, group, world_size, gather_with_grad) + + # 1. Gather sizes of all tensors + local_size = torch.tensor(a_tensor.shape, device=a_tensor.device) + local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] + dist.all_gather(local_sizes, local_size, group=group) + max_size = torch.stack(local_sizes).max(dim=0).values + all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) + + # 2. If shapes are all the same, then do a simple gather: + if all_sizes_equal: + return _simple_gather_all_tensors(a_tensor, group, world_size, gather_with_grad) + + # 3. If not, we need to pad each local tensor to maximum size, gather and + # then truncate + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + result_padded = F.pad(a_tensor, pad_dims) + if gather_with_grad: + gathered_result = list(dist_nn.all_gather(result_padded, group)) + else: + gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] + dist.all_gather(gathered_result, result_padded, group) + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_result[idx] = gathered_result[idx][slice_param] + return gathered_result diff --git a/mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py b/mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py index 8fe4217..46e2fb1 100644 --- a/mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py +++ b/mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py @@ -73,7 +73,7 @@ def linear_warmup_cosine_annealing_lr( ) cosine_lr = CosineAnnealingLR( optimizer, - T_max=max_steps - warmup_steps - 1, + T_max=max_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch, ) diff --git a/mmlearn/modules/metrics/retrieval_recall.py b/mmlearn/modules/metrics/retrieval_recall.py index 35c0064..4ff9751 100644 --- a/mmlearn/modules/metrics/retrieval_recall.py +++ b/mmlearn/modules/metrics/retrieval_recall.py @@ -6,6 +6,7 @@ import torch import torch.distributed from hydra_zen import store +from torch.nn import functional as F # noqa: N812 from torchmetrics import Metric from torchmetrics.retrieval.base import _retrieval_aggregate from torchmetrics.utilities.checks import _check_same_shape @@ -52,7 +53,7 @@ class RetrievalRecallAtK(Metric): def __init__( self, top_k: int, - reduction: Literal["mean", "sum", "none", None] = "sum", + reduction: Literal["mean", "sum", "none", None] = None, aggregation: Union[ Literal["mean", "median", "min", "max"], Callable[[torch.Tensor, int], torch.Tensor], @@ -166,12 +167,9 @@ def compute(self) -> torch.Tensor: torch.Tensor The computed metric. 
""" - x = dim_zero_cat(self.x) - y = dim_zero_cat(self.y) - # compute the cosine similarity - x_norm = x / x.norm(dim=-1, p=2, keepdim=True) - y_norm = y / y.norm(dim=-1, p=2, keepdim=True) + x_norm = F.normalize(dim_zero_cat(self.x), p=2, dim=-1) + y_norm = F.normalize(dim_zero_cat(self.y), p=2, dim=-1) similarity = _safe_matmul(x_norm, y_norm) reduction_mapping: Dict[ Optional[str], Callable[[torch.Tensor], torch.Tensor] diff --git a/mmlearn/tasks/contrastive_pretraining.py b/mmlearn/tasks/contrastive_pretraining.py index 04146c7..c48a724 100644 --- a/mmlearn/tasks/contrastive_pretraining.py +++ b/mmlearn/tasks/contrastive_pretraining.py @@ -17,9 +17,8 @@ from lightning_utilities.core.rank_zero import rank_zero_warn from torch import nn -from mmlearn.datasets.core import Modalities, find_matching_indices +from mmlearn.datasets.core import Modalities from mmlearn.datasets.core.modalities import Modality -from mmlearn.modules.losses import CLIPLoss from mmlearn.tasks.hooks import EvaluationHooks @@ -165,7 +164,7 @@ def __init__( # noqa: PLR0912, PLR0915 init_logit_scale: float = 1 / 0.07, max_logit_scale: float = 100, learnable_logit_scale: bool = True, - loss: Optional[CLIPLoss] = None, + loss: Optional[nn.Module] = None, modality_loss_pairs: Optional[List[LossPairSpec]] = None, auxiliary_tasks: Optional[Dict[str, AuxiliaryTaskSpec]] = None, log_auxiliary_tasks_loss: bool = False, @@ -184,6 +183,7 @@ def __init__( # noqa: PLR0912, PLR0915 "loss", "auxiliary_tasks", "evaluation_tasks", + "modality_loss_pairs", ] ) @@ -255,7 +255,7 @@ def __init__( # noqa: PLR0912, PLR0915 if isinstance(heads[head_key], nn.Module) else nn.Sequential(*heads[head_key].values()) for modality_key, head_key in modality_head_mapping.items() - if head_key is not None + if head_key is not None and head_key in heads } ) @@ -270,6 +270,7 @@ def __init__( # noqa: PLR0912, PLR0915 else nn.Sequential(*postprocessors[postprocessor_key].values()) for modality_key, postprocessor_key in modality_postprocessor_mapping.items() if postprocessor_key is not None + and postprocessor_key in postprocessors } ) @@ -368,12 +369,12 @@ def encode( if self.heads and modality.name in self.heads: output = self.heads[modality.name](output) - if self.postprocessors and modality.name in self.postprocessors: - output = self.postprocessors[modality.name](output) - if normalize: output = torch.nn.functional.normalize(output, p=2, dim=-1) + if self.postprocessors and modality.name in self.postprocessors: + output = self.postprocessors[modality.name](output) + return output def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: @@ -392,6 +393,7 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: outputs = { modality.embedding: self.encode(inputs, modality, normalize=True) for modality in self._available_modalities + if modality.name in inputs } if not all( @@ -408,37 +410,13 @@ def _compute_loss( if self.loss_fn is None: return None - with torch.no_grad(): - self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale)) - self.log( - "train/logit_scale", + contrastive_loss = self.loss_fn( + outputs, + batch["example_ids"], self.log_logit_scale.exp(), - prog_bar=True, - on_step=True, - on_epoch=False, + self.modality_loss_pairs, ) - contrastive_losses: list[torch.Tensor] = [] - for loss_pair in self.modality_loss_pairs: - modality_a = Modalities.get_modality(loss_pair.modalities[0]) - modality_b = Modalities.get_modality(loss_pair.modalities[1]) - - indices_a, indices_b = find_matching_indices( - 
batch["example_ids"][modality_a.name], - batch["example_ids"][modality_b.name], - ) - if indices_a.numel() == 0 or indices_b.numel() == 0: - continue - - contrastive_losses.append( - self.loss_fn( - outputs[modality_a.embedding][indices_a], - outputs[modality_b.embedding][indices_b], - self.log_logit_scale.exp(), - ) - * loss_pair.weight - ) - auxiliary_losses: list[torch.Tensor] = [] if self.auxiliary_tasks: for task_name, task_spec in self.aux_task_specs.items(): @@ -457,9 +435,22 @@ def _compute_loss( auxiliary_losses.append(task_spec.loss_weight * auxiliary_task_loss) if self.log_auxiliary_tasks_loss: - self.log(f"train/{task_name}_loss", auxiliary_task_loss) + self.log( + f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True + ) + + if not auxiliary_losses: + return contrastive_loss - return torch.stack(contrastive_losses + auxiliary_losses).sum() + return torch.stack(auxiliary_losses).sum() + contrastive_loss + + def on_train_epoch_start(self) -> None: + """Prepare for the training epoch.""" + self.encoders.train() + if self.heads: + self.heads.train() + if self.postprocessors: + self.postprocessors.train() def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: """Compute the loss for the batch. @@ -477,12 +468,24 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: The loss for the batch. """ outputs = self(batch) + + with torch.no_grad(): + self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale)) + loss = self._compute_loss(batch, batch_idx, outputs) + print("loss: ", loss) if loss is None: raise ValueError("The loss function must be provided for training.") - self.log("train/loss", loss, prog_bar=True) + self.log("train/loss", loss, prog_bar=True, sync_dist=True) + self.log( + "train/logit_scale", + self.log_logit_scale.exp(), + prog_bar=True, + on_step=True, + on_epoch=False, + ) return loss @@ -661,7 +664,7 @@ def _shared_eval_step( outputs = self(batch) loss = self._compute_loss(batch, batch_idx, outputs) if loss is not None and not self.trainer.sanity_checking: - self.log(f"{eval_type}/loss", loss, prog_bar=True) + self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True) if self.evaluation_tasks: for task_spec in self.evaluation_tasks.values(): diff --git a/mmlearn/tasks/zero_shot_retrieval.py b/mmlearn/tasks/zero_shot_retrieval.py index 36f6deb..f2a8e37 100644 --- a/mmlearn/tasks/zero_shot_retrieval.py +++ b/mmlearn/tasks/zero_shot_retrieval.py @@ -48,6 +48,7 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]): self.task_specs = task_specs self.metrics: Dict[Tuple[str, str], MetricCollection] = {} + self._available_modalities = set() for spec in self.task_specs: query_modality = spec.query_modality @@ -63,6 +64,8 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]): for k in spec.top_k } ) + self._available_modalities.add(query_modality) + self._available_modalities.add(target_modality) def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None: """Move the metrics to the device of the Lightning module.""" @@ -90,14 +93,17 @@ def evaluation_step( if pl_module.trainer.sanity_checking: return - outputs: Dict[str, Any] = pl_module(batch) + outputs: Dict[str, Any] = {} + for modality_name in self._available_modalities: + if modality_name in batch: + outputs[modality_name] = pl_module.encode( + batch, Modalities.get_modality(modality_name), normalize=False + ) for (query_modality, target_modality), metric in self.metrics.items(): - query_embeddings: torch.Tensor = outputs[ 
- Modalities.get_modality(query_modality).embedding - ] - target_embeddings: torch.Tensor = outputs[ - Modalities.get_modality(target_modality).embedding - ] + if query_modality not in outputs or target_modality not in outputs: + continue + query_embeddings: torch.Tensor = outputs[query_modality] + target_embeddings: torch.Tensor = outputs[target_modality] indexes = torch.arange(query_embeddings.size(0), device=pl_module.device) metric.update(query_embeddings, target_embeddings, indexes)
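
Usage sketch (single process): the refactored loss can be called directly as below. This is a minimal sketch, not taken from the patch; it assumes the "text" and "rgb" modalities are registered in mmlearn's Modalities registry, that LossPairSpec accepts `modalities` and `weight` fields, and that per-modality example IDs are (dataset_index, example_index) pairs stacked into a tensor, as the docstrings above describe. Adjust names to your configuration.

import torch

from mmlearn.datasets.core.modalities import Modalities
from mmlearn.modules.losses import ContrastiveLoss
from mmlearn.tasks.contrastive_pretraining import LossPairSpec

loss_fn = ContrastiveLoss(l2_normalize=True, modality_alignment=False)

batch_size, embed_dim = 8, 512
text = Modalities.get_modality("text")   # assumes "text" is registered
rgb = Modalities.get_modality("rgb")     # assumes "rgb" is registered

# Embeddings are keyed by `<modality>.embedding`; example IDs are keyed by
# `<modality>.name`. The loss matches positive pairs across modalities by
# comparing the (dataset_index, example_index) rows of the example IDs.
embeddings = {
    text.embedding: torch.randn(batch_size, embed_dim),
    rgb.embedding: torch.randn(batch_size, embed_dim),
}
ids = torch.stack(
    [torch.zeros(batch_size, dtype=torch.long), torch.arange(batch_size)], dim=1
)
example_ids = {text.name: ids, rgb.name: ids.clone()}

# The forward pass expects the already-exponentiated scale factor
# (the task passes `self.log_logit_scale.exp()`).
logit_scale = torch.tensor(1 / 0.07)

modality_loss_pairs = [LossPairSpec(modalities=("text", "rgb"), weight=1.0)]

loss = loss_fn(embeddings, example_ids, logit_scale, modality_loss_pairs)
print(loss)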