diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..ca47eb6 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,20 @@ +This is a template for making a pull-request. You can remove the text and sections and write your own thing if you wish, just make sure you give enough information about how and why. If you have any issues or difficulties, don't hesitate to open an issue. + + +# Description + +The aim is to add this feature ... + +# Proposed Changes + +I changed the `foo()` function so that ... + + +# Checklist + +Here are some things to check before creating the pull request. If you encounter any issues, don't hesitate to ask for help :) + +- [ ] I have read the [contributor's guide](https://github.com/arnab39/equiadapt/blob/main/CONTRIBUTING.md). +- [ ] The base branch of my pull request is the `dev` branch, not the `main` branch. +- [ ] I ran the [code checks](https://github.com/arnab39/equiadapt/blob/main/CONTRIBUTING.md#implement-your-changes) on the files I added or modified and fixed the errors. +- [ ] I updated the [changelog](https://github.com/arnab39/equiadapt/blob/main/CHANGELOG.md). diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b95f5ab..2b77530 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,14 +58,16 @@ jobs: test: needs: prepare strategy: + fail-fast: false matrix: - python: - - "3.7" # oldest Python supported by PSF - - "3.10" # newest Python that is stable - platform: - - ubuntu-latest - # - macos-latest - # - windows-latest + python: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + platform: [ubuntu-latest, macos-latest, windows-latest] + exclude: # Python < v3.8 does not support Apple Silicon ARM64. + - python: "3.7" + platform: macos-latest + include: # So run those legacy versions on Intel CPUs. + - python: "3.7" + platform: macos-13 runs-on: ${{ matrix.platform }} steps: - uses: actions/checkout@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f783d0..9a01735 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ exclude: '^docs/conf.py' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: check-added-large-files @@ -40,7 +40,7 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 24.2.0 + rev: 24.4.2 hooks: - id: black language_version: python3 @@ -66,7 +66,7 @@ repos: # Check for type errors with mypy: - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.9.0' + rev: 'v1.10.0' hooks: - id: mypy args: [--disallow-untyped-defs, --ignore-missing-imports] diff --git a/CHANGELOG.md b/CHANGELOG.md index a443b94..ba5a48f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,15 +5,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [0.1.2] - 2024-05-29 ### Added +- Added canonicalization with optimization approach. +- Added evaluating transfer learning capabilities of canonicalizer. +- Added pull request template. +- Added test for discrete invert canonicalization. ### Fixed +- Fixed segmentation evaluation for non-identity canonicalizers. +- Fixed minor bugs in inverse canonicalization for discrete groups. ### Changed - -### Removed +- Updated `README.md` with [Improved Canonicalization for Model Agnostic Equivariance](https://arxiv.org/abs/2405.14089) ([EquiVision](https://equivision.github.io/), CVPR 2024 workshop) paper details. +- Updated `CONTRIBUTING.md` with more information on how to run the code checks. +- Changed the OS used to test Python 3.7 on GitHub actions (macos-latest -> macos-13). ## [0.1.1] - 2024-03-15 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9afc304..245d2c9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -155,17 +155,29 @@ This can easily be done via [Anaconda] or [Miniconda] and detailed [here](https: `git log --graph --decorate --pretty=oneline --abbrev-commit --all` to look for recurring communication patterns. +#### Run code checks -5. Please check that your changes don't break any unit tests with: +Please make sure to see the validation messages from pre-commit and fix any +eventual issues. This should automatically use [flake8]/[black] to check/fix +the code style in a way that is compatible with the project. - ``` - tox - ``` +To run pre-commit manually, you can use: + +``` +pre-commit run --all-files +``` + +Please also check that your changes don't break any unit tests with: + +``` +tox +``` + +(after having installed [tox] with `pip install tox` or `pipx`). - (after having installed [tox] with `pip install tox` or `pipx`). +You can also use [tox] to run several other pre-configured tasks in the +repository. Try `tox -av` to see a list of the available checks. - You can also use [tox] to run several other pre-configured tasks in the - repository. Try `tox -av` to see a list of the available checks. ### Submit your contribution diff --git a/README.md b/README.md index 324cd7a..5fdf567 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,8 @@ You can clone this repository and manually install it with: ## Setup Conda environment for examples +The recommended way is to manually create an environment and install the dependencies from the `min_conda_env.yaml` file. + To create a conda environment with the necessary packages: ``` @@ -168,7 +170,7 @@ You can also find [tutorials](https://github.com/arnab39/equiadapt/blob/main/tut # Related papers and Citations -For more insights on this library refer to our original paper on the idea: [Equivariance with Learned Canonicalization Function (ICML 2023)](https://proceedings.mlr.press/v202/kaba23a.html) and how to extend it to make any existing large pre-trained model equivariant: [Equivariant Adaptation of Large Pretrained Models (NeurIPS 2023)](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9d5856318032ef3630cb580f4e24f823-Abstract-Conference.html). +For more insights on this library refer to our original paper on the idea: [Equivariance with Learned Canonicalization Function (ICML 2023)](https://proceedings.mlr.press/v202/kaba23a.html) and how to extend it to make any existing large pre-trained model equivariant: [Equivariant Adaptation of Large Pretrained Models (NeurIPS 2023)](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9d5856318032ef3630cb580f4e24f823-Abstract-Conference.html). An improved approach for designing canonicalization network, which allows non-equivariant and expressive models as equivariant networks is presented in [Improved Canonicalization for Model Agnostic Equivariance (CVPR 2024: EquiVision Workshop)](https://arxiv.org/abs/2405.14089). If you find this library or the associated papers useful, please cite the following papers: @@ -197,6 +199,17 @@ If you find this library or the associated papers useful, please cite the follow } ``` +``` +@inproceedings{ + panigrahi2024improved, + title={Improved Canonicalization for Model Agnostic Equivariance}, + author={Siba Smarak Panigrahi and Arnab Kumar Mondal}, + booktitle={CVPR 2024 Workshop on Equivariant Vision: From Theory to Practice}, + year={2024}, + url={https://arxiv.org/abs/2405.14089} +} +``` + # Contact For questions related to this code, please raise an issue and you can mail us at: @@ -206,7 +219,7 @@ For questions related to this code, please raise an issue and you can mail us at # Contributing -You can check out the [contributor's guide](https://github.com/arnab39/equiadapt/blob/main/CHANGELOG.md). +You can check out the [contributor's guide](https://github.com/arnab39/equiadapt/blob/main/CONTRIBUTING.md). This project uses `pre-commit`, you can install it before making any changes:: diff --git a/equiadapt/images/__init__.py b/equiadapt/images/__init__.py index b670541..bdc5971 100644 --- a/equiadapt/images/__init__.py +++ b/equiadapt/images/__init__.py @@ -22,6 +22,8 @@ RotationEquivariantConvLift, RotoReflectionEquivariantConv, RotoReflectionEquivariantConvLift, + WideResNet50Network, + WideResNet101Network, custom_equivariant_networks, custom_group_equivariant_layers, custom_nonequivariant_networks, @@ -51,6 +53,8 @@ "OptimizedGroupEquivariantImageCanonicalization", "OptimizedSteerableImageCanonicalization", "ResNet18Network", + "WideResNet50Network", + "WideResNet101Network", "RotationEquivariantConv", "RotationEquivariantConvLift", "RotoReflectionEquivariantConv", diff --git a/equiadapt/images/canonicalization_networks/__init__.py b/equiadapt/images/canonicalization_networks/__init__.py index 13b33d0..f48b390 100644 --- a/equiadapt/images/canonicalization_networks/__init__.py +++ b/equiadapt/images/canonicalization_networks/__init__.py @@ -16,6 +16,8 @@ from equiadapt.images.canonicalization_networks.custom_nonequivariant_networks import ( ConvNetwork, ResNet18Network, + WideResNet50Network, + WideResNet101Network, ) from equiadapt.images.canonicalization_networks.escnn_networks import ( ESCNNEquivariantNetwork, @@ -34,6 +36,8 @@ "ESCNNWideBasic", "ESCNNWideBottleneck", "ResNet18Network", + "WideResNet101Network", + "WideResNet50Network", "RotationEquivariantConv", "RotationEquivariantConvLift", "RotoReflectionEquivariantConv", diff --git a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py index 73b5508..6c1d25d 100644 --- a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py +++ b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py @@ -110,7 +110,7 @@ def __init__( out_vector_size (int, optional): The size of the output vector of the network. Defaults to 128. """ super().__init__() - self.resnet18 = torchvision.models.resnet18(weights=None) + self.resnet18 = torchvision.models.resnet18(weights="DEFAULT") self.resnet18.fc = nn.Sequential( nn.Linear(512, out_vector_size), ) @@ -128,3 +128,103 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: The output of the network. It has the shape (batch_size, 1). """ return self.resnet18(x) + + +class WideResNet101Network(nn.Module): + """ + This class represents a neural network based on the WideResNetNetwork architecture. + + The network uses a pre-trained WideResNet model. The final fully connected layer of the WideResNet101 model is replaced with a new fully connected layer. + + Attributes: + resnet18 (torchvision.models.ResNet): The ResNet-18 model. + out_vector_size (int): The size of the output vector of the network. + """ + + def __init__( + self, + in_shape: tuple, + out_channels: int, + kernel_size: int, + num_layers: int = 2, + out_vector_size: int = 128, + ): + """ + Initializes the ResNet18Network instance. + + Args: + in_shape (tuple): The shape of the input data. It should be a tuple of the form (in_channels, height, width). + out_channels (int): The number of output channels of the first convolutional layer. + kernel_size (int): The size of the kernel of the convolutional layers. + num_layers (int, optional): The number of convolutional layers. Defaults to 2. + out_vector_size (int, optional): The size of the output vector of the network. Defaults to 128. + """ + super().__init__() + self.wideresnet = torchvision.models.wide_resnet101_2(weights="DEFAULT") + self.wideresnet.fc = nn.Sequential( + nn.Linear(2048, out_vector_size), + ) + + self.out_vector_size = out_vector_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs a forward pass through the network. + + Args: + x (torch.Tensor): The input data. It should have the shape (batch_size, in_channels, height, width). + + Returns: + torch.Tensor: The output of the network. It has the shape (batch_size, 1). + """ + return self.wideresnet(x) + + +class WideResNet50Network(nn.Module): + """ + This class represents a neural network based on the WideResNetNetwork architecture. + + The network uses a pre-trained WideResNet model. The final fully connected layer of the WideResNet50 model is replaced with a new fully connected layer. + + Attributes: + resnet18 (torchvision.models.ResNet): The ResNet-18 model. + out_vector_size (int): The size of the output vector of the network. + """ + + def __init__( + self, + in_shape: tuple, + out_channels: int, + kernel_size: int, + num_layers: int = 2, + out_vector_size: int = 128, + ): + """ + Initializes the ResNet18Network instance. + + Args: + in_shape (tuple): The shape of the input data. It should be a tuple of the form (in_channels, height, width). + out_channels (int): The number of output channels of the first convolutional layer. + kernel_size (int): The size of the kernel of the convolutional layers. + num_layers (int, optional): The number of convolutional layers. Defaults to 2. + out_vector_size (int, optional): The size of the output vector of the network. Defaults to 128. + """ + super().__init__() + self.wideresnet = torchvision.models.wide_resnet50_2(weights="DEFAULT") + self.wideresnet.fc = nn.Sequential( + nn.Linear(2048, out_vector_size), + ) + + self.out_vector_size = out_vector_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs a forward pass through the network. + + Args: + x (torch.Tensor): The input data. It should have the shape (batch_size, in_channels, height, width). + + Returns: + torch.Tensor: The output of the network. It has the shape (batch_size, 1). + """ + return self.wideresnet(x) diff --git a/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml b/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml index 986eae3..8d5d1e1 100644 --- a/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml +++ b/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml @@ -1,5 +1,5 @@ canonicalization_type: opt_group_equivariant -network_type: cnn # Options for canonization method 1) cnn 2) wideresnet +network_type: cnn # Options for canonization method 1) cnn 2) non_equivariant_wrn_50 3) non_equivariant_wrn_101 4) non_equivariant_resnet18 network_hyperparams: kernel_size: 7 # Kernel size for the canonization network out_channels: 16 # Number of output channels for the canonization network diff --git a/examples/images/classification/configs/checkpoint/default.yaml b/examples/images/classification/configs/checkpoint/default.yaml index b479fe7..57075ff 100644 --- a/examples/images/classification/configs/checkpoint/default.yaml +++ b/examples/images/classification/configs/checkpoint/default.yaml @@ -2,3 +2,5 @@ checkpoint_path: ${oc.env:CHECKPOINT_PATH} # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty for training and dynamically allocated later save_canonized_images: 0 # Whether to save canonized images (1) or not (0) strict_loading: 1 # Whether to strictly load the model (1) or not (0) +prediction_network_checkpoint_path: null # Path to load prediction network checkpoints +prediction_network_checkpoint_name: null # Path to load prediction network checkpoint name diff --git a/examples/images/classification/inference_utils.py b/examples/images/classification/inference_utils.py index 9095df2..617399e 100644 --- a/examples/images/classification/inference_utils.py +++ b/examples/images/classification/inference_utils.py @@ -61,7 +61,9 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): ] # check if the accuracy per class is nan - acc_per_class = [0.0 if math.isnan(acc) else acc for acc in acc_per_class] + acc_per_class = [ + torch.tensor(0.0) if math.isnan(acc) else acc for acc in acc_per_class + ] # Update metrics with accuracy per class metrics.update( @@ -151,7 +153,9 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): ] # check if the accuracy per class is nan - acc_per_class = [0.0 if math.isnan(acc) else acc for acc in acc_per_class] + acc_per_class = [ + torch.tensor(0.0) if math.isnan(acc) else acc for acc in acc_per_class + ] # Update metrics with accuracy per class metrics.update( diff --git a/examples/images/classification/model.py b/examples/images/classification/model.py index 1586713..6fd9a49 100644 --- a/examples/images/classification/model.py +++ b/examples/images/classification/model.py @@ -64,7 +64,7 @@ def training_step(self, batch: torch.Tensor): assert (num_channels, height, width) == self.image_shape training_metrics = {} - loss = 0.0 + loss, acc = 0.0, 0.0 # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation @@ -101,7 +101,6 @@ def training_step(self, batch: torch.Tensor): acc = (preds == y).float().mean() training_metrics.update({"train/task_loss": task_loss, "train/acc": acc}) - training_metrics.update({"train/task_loss": task_loss, "train/acc": acc}) # Add prior regularization loss if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: diff --git a/examples/images/classification/train.py b/examples/images/classification/train.py index 657a212..9d7a731 100644 --- a/examples/images/classification/train.py +++ b/examples/images/classification/train.py @@ -74,13 +74,13 @@ def train_images(hyperparams: DictConfig) -> None: if not hyperparams["experiment"]["run_mode"] == "test": hyperparams["checkpoint"]["checkpoint_name"] = ( - wandb_run.id + str(wandb_run.id) + "_" - + wandb_run.name + + str(wandb_run.name) + "_" - + wandb_run.sweep_id + + str(wandb_run.sweep_id) + "_" - + wandb_run.group + + str(wandb_run.group) ) # set seed diff --git a/examples/images/classification/train_utils.py b/examples/images/classification/train_utils.py index 6e9fc5b..f4f7996 100644 --- a/examples/images/classification/train_utils.py +++ b/examples/images/classification/train_utils.py @@ -2,6 +2,7 @@ import dotenv import pytorch_lightning as pl +import torch from model import ImageClassifierPipeline from omegaconf import DictConfig from prepare import ( @@ -39,6 +40,23 @@ def get_model_pipeline(hyperparams: DictConfig) -> pl.LightningModule: hyperparams=hyperparams, strict=hyperparams.checkpoint.strict_loading, ) + + # load a different (finetuned) prediction network + # to test the transfer learning capabilities of the canonicalizer + if hyperparams.checkpoint.prediction_network_checkpoint_path: + model.load_state_dict( + torch.load( + open( + hyperparams.checkpoint.prediction_network_checkpoint_path + + "/" + + hyperparams.checkpoint.prediction_network_checkpoint_name + + ".ckpt", + mode="rb", + ) + )["state_dict"], + strict=False, + ) + model.freeze() model.eval() else: diff --git a/examples/images/common/utils.py b/examples/images/common/utils.py index a1ce933..59dde53 100644 --- a/examples/images/common/utils.py +++ b/examples/images/common/utils.py @@ -17,6 +17,8 @@ ESCNNSteerableNetwork, ESCNNWRNEquivariantNetwork, ResNet18Network, + WideResNet50Network, + WideResNet101Network, ) @@ -46,7 +48,9 @@ def get_canonicalization_network( }, "opt_group_equivariant": { "cnn": ConvNetwork, - "resnet18": ResNet18Network, + "non_equivariant_resnet18": ResNet18Network, + "non_equivariant_wrn_101": WideResNet101Network, + "non_equivariant_wrn_50": WideResNet50Network, }, "opt_steerable": { "cnn": ConvNetwork, diff --git a/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml b/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml index 986eae3..b0b4f5d 100644 --- a/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml +++ b/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml @@ -1,5 +1,5 @@ canonicalization_type: opt_group_equivariant -network_type: cnn # Options for canonization method 1) cnn 2) wideresnet +network_type: cnn # Options for canonization method 1) cnn 2) non_equivariant_wrn network_hyperparams: kernel_size: 7 # Kernel size for the canonization network out_channels: 16 # Number of output channels for the canonization network @@ -9,6 +9,6 @@ group_type: rotation # Type of group for the canonization network num_rotations: 4 # Number of rotations for the canonization network beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization -resize_shape: 96 # Resize shape for the input +resize_shape: 128 # Resize shape for the input learn_ref_vec: False # Whether to learn the reference vector artifact_err_wt: 0 # Weight for rotation artifact error (specific to image data, for non C4 rotation, for non-equivariant canonicalization networks) diff --git a/examples/images/segmentation/configs/checkpoint/default.yaml b/examples/images/segmentation/configs/checkpoint/default.yaml index a3e9fbd..8470d12 100644 --- a/examples/images/segmentation/configs/checkpoint/default.yaml +++ b/examples/images/segmentation/configs/checkpoint/default.yaml @@ -2,3 +2,5 @@ checkpoint_path: ${oc.env:CHECKPOINT_PATH} # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later save_canonized_images: 0 # Whether to save canonized images (1) or not (0) strict_loading: 1 # Whether to strictly load the model (1) or not (0) +prediction_network_checkpoint_path: null # Path to load prediction network checkpoints +prediction_network_checkpoint_name: null # Path to load prediction network checkpoint name diff --git a/examples/images/segmentation/inference_utils.py b/examples/images/segmentation/inference_utils.py index a60770e..97ce758 100644 --- a/examples/images/segmentation/inference_utils.py +++ b/examples/images/segmentation/inference_utils.py @@ -76,6 +76,46 @@ def __init__( self.pad = transforms.Pad(math.ceil(in_shape[-2] * 0.4), padding_mode="edge") self.crop = transforms.CenterCrop((in_shape[-2], in_shape[-1])) + def forward( + self, x: torch.Tensor, targets: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # canonicalize the input data + # For the vanilla model, the canonicalization is the identity transformation + _, _, _, outputs = super().forward(x, targets) + + if self.canonicalizer.canonicalization_info_dict: + # if non-identity canonicalization is used, then the outputs will be transformed + rotation_angles = self.canonicalizer.canonicalization_info_dict[ + "group_element" + ]["rotation"] + reflections = ( + self.canonicalizer.canonicalization_info_dict["group_element"][ + "reflection" + ] + if "reflection" + in self.canonicalizer.canonicalization_info_dict["group_element"] + else [None] * len(rotation_angles) + ) + + image_width = x[0].shape[1] + outputs = [ + dict( + boxes=( + flip_boxes(rotate_boxes(output["boxes"], degree, image_width)) + if reflection + else rotate_boxes(output["boxes"], degree, image_width) + ), + labels=output["labels"], + scores=output["scores"], + masks=output["masks"], + ) + for degree, reflection, output in zip( + rotation_angles, reflections, outputs + ) + ] + + return None, None, None, outputs + def get_group_element_wise_maps( self, images: torch.Tensor, targets: torch.Tensor ) -> dict: @@ -105,13 +145,13 @@ def get_group_element_wise_maps( _, _, _, outputs = self.forward(images_rot, targets_transformed) Map = MeanAveragePrecision(iou_type="segm") - targets = [ + _targets = [ dict( boxes=target["boxes"], labels=target["labels"], masks=target["masks"], ) - for target in targets + for target in targets_transformed ] outputs = [ dict( @@ -122,7 +162,7 @@ def get_group_element_wise_maps( ) for output in outputs ] - Map.update(outputs, targets) + Map.update(outputs, _targets) map_dict[rot] = Map.compute() @@ -191,30 +231,18 @@ def get_inference_metrics( for i in range(self.num_group_elements): metrics.update( { - f"test/map_group_element_{i}": max(map_dict[i]["map"], 0.0), - f"test/map_small_group_element_{i}": max( - map_dict[i]["map_small"], 0.0 - ), - f"test/map_medium_group_element_{i}": max( - map_dict[i]["map_medium"], 0.0 - ), - f"test/map_large_group_element_{i}": max( - map_dict[i]["map_large"], 0.0 - ), - f"test/map_50_group_element_{i}": max(map_dict[i]["map_50"], 0.0), - f"test/map_75_group_element_{i}": max(map_dict[i]["map_75"], 0.0), - f"test/mar_1_group_element_{i}": max(map_dict[i]["mar_1"], 0.0), - f"test/mar_10_group_element_{i}": max(map_dict[i]["mar_10"], 0.0), - f"test/mar_100_group_element_{i}": max(map_dict[i]["mar_100"], 0.0), - f"test/mar_small_group_element_{i}": max( - map_dict[i]["mar_small"], 0.0 - ), - f"test/mar_medium_group_element_{i}": max( - map_dict[i]["mar_medium"], 0.0 - ), - f"test/mar_large_group_element_{i}": max( - map_dict[i]["mar_large"], 0.0 - ), + f"test/map_group_element_{i}": map_dict[i]["map"], + f"test/map_small_group_element_{i}": map_dict[i]["map_small"], + f"test/map_medium_group_element_{i}": map_dict[i]["map_medium"], + f"test/map_large_group_element_{i}": map_dict[i]["map_large"], + f"test/map_50_group_element_{i}": map_dict[i]["map_50"], + f"test/map_75_group_element_{i}": map_dict[i]["map_75"], + f"test/mar_1_group_element_{i}": map_dict[i]["mar_1"], + f"test/mar_10_group_element_{i}": map_dict[i]["mar_10"], + f"test/mar_100_group_element_{i}": map_dict[i]["mar_100"], + f"test/mar_small_group_element_{i}": map_dict[i]["mar_small"], + f"test/mar_medium_group_element_{i}": map_dict[i]["mar_medium"], + f"test/mar_large_group_element_{i}": map_dict[i]["mar_large"], } ) diff --git a/examples/images/segmentation/train.py b/examples/images/segmentation/train.py index b87dd31..b70cf52 100644 --- a/examples/images/segmentation/train.py +++ b/examples/images/segmentation/train.py @@ -86,13 +86,13 @@ def train_images(hyperparams: DictConfig) -> None: if not hyperparams["experiment"]["run_mode"] == "test": hyperparams["checkpoint"]["checkpoint_name"] = ( - wandb_run.id + str(wandb_run.id) + "_" - + wandb_run.name + + str(wandb_run.name) + "_" - + wandb_run.sweep_id + + str(wandb_run.sweep_id) + "_" - + wandb_run.group + + str(wandb_run.group) ) # set seed diff --git a/examples/images/segmentation/train_utils.py b/examples/images/segmentation/train_utils.py index dc282e9..48e6d10 100644 --- a/examples/images/segmentation/train_utils.py +++ b/examples/images/segmentation/train_utils.py @@ -2,6 +2,7 @@ import dotenv import pytorch_lightning as pl +import torch from model import ImageSegmentationPipeline from omegaconf import DictConfig from prepare import COCODataModule @@ -33,6 +34,22 @@ def get_model_pipeline(hyperparams: DictConfig) -> pl.LightningModule: hyperparams=hyperparams, strict=hyperparams.checkpoint.strict_loading, ) + + # load a different (finetuned) prediction network + # to test the transfer learning capabilities of the canonicalizer + if hyperparams.checkpoint.prediction_network_checkpoint_path: + model.load_state_dict( + torch.load( + open( + hyperparams.checkpoint.prediction_network_checkpoint_path + + "/" + + hyperparams.checkpoint.prediction_network_checkpoint_name + + ".ckpt", + mode="rb", + ) + )["state_dict"], + strict=False, + ) model.freeze() model.eval() else: diff --git a/setup.cfg b/setup.cfg index 88303f6..8b1ddc0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ packages = find_namespace: include_package_data = True # Require a min/specific Python version (comma-separated conditions) -python_requires = >=3.7, <3.11 +python_requires = >=3.7 # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in