From f342ca0a9ac49a28603942dddd7368ee3e6c2274 Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Wed, 5 Jun 2024 02:13:11 -0700 Subject: [PATCH 1/2] Add unit tests for transforms --- piqture/data_loader/minmax_normalization.py | 14 ----- piqture/data_loader/mnist_data_loader.py | 10 +-- piqture/transforms/__init__.py | 19 ++++++ piqture/transforms/transforms.py | 47 ++++++++++++++ tests/transforms/__init__.py | 0 tests/transforms/test_transforms.py | 70 +++++++++++++++++++++ 6 files changed, 137 insertions(+), 23 deletions(-) delete mode 100644 piqture/data_loader/minmax_normalization.py create mode 100644 piqture/transforms/__init__.py create mode 100644 piqture/transforms/transforms.py create mode 100644 tests/transforms/__init__.py create mode 100644 tests/transforms/test_transforms.py diff --git a/piqture/data_loader/minmax_normalization.py b/piqture/data_loader/minmax_normalization.py deleted file mode 100644 index a0037ca..0000000 --- a/piqture/data_loader/minmax_normalization.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Class for Min-Max Normalization of datasets.""" - - -# pylint: disable=too-few-public-methods -class MinMaxNormalization: - """Normalizes input values in range [min, max].""" - - def __init__(self, normalize_min, normalize_max): - self.min = normalize_min - self.max = normalize_max - - def __call__(self, x): - """Normalizes data to a range [min, max].""" - return (x - self.min) / (self.max - self.min) diff --git a/piqture/data_loader/mnist_data_loader.py b/piqture/data_loader/mnist_data_loader.py index 1e61398..44d4673 100644 --- a/piqture/data_loader/mnist_data_loader.py +++ b/piqture/data_loader/mnist_data_loader.py @@ -16,7 +16,7 @@ import torch.utils.data import torchvision from torchvision import datasets -from piqture.data_loader.minmax_normalization import MinMaxNormalization +from piqture.transforms import MinMaxNormalization def load_mnist_dataset( @@ -65,14 +65,6 @@ def load_mnist_dataset( raise TypeError("The input labels must be of the type list.") if normalize_max and normalize_min: - # Check if normalize_min and max are int or float. - if not isinstance(normalize_max, (int, float)) and not isinstance( - normalize_min, (int, float) - ): - raise TypeError( - "The inputs normalize_min and normlaize_max must be of the type int or float." - ) - # Define a custom mnist transforms. mnist_transform = torchvision.transforms.Compose( [ diff --git a/piqture/transforms/__init__.py b/piqture/transforms/__init__.py new file mode 100644 index 0000000..59d9ea7 --- /dev/null +++ b/piqture/transforms/__init__.py @@ -0,0 +1,19 @@ +# (C) Copyright SaashaJoshi 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Transforms (module: piqture.data_loader) +""" + +from .transforms import MinMaxNormalization + +__all__ = [ + "MinMaxNormalization", +] diff --git a/piqture/transforms/transforms.py b/piqture/transforms/transforms.py new file mode 100644 index 0000000..cf4fd52 --- /dev/null +++ b/piqture/transforms/transforms.py @@ -0,0 +1,47 @@ +# (C) Copyright SaashaJoshi 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Data transforms for Pytorch datasets.""" + +from typing import Union +import torch +from torch import Tensor + + +# pylint: disable=too-few-public-methods +class MinMaxNormalization: + """Normalizes input values in range [min, max].""" + + def __init__( + self, normalize_min: Union[int, float], normalize_max: Union[int, float] + ): + # Check if normalize_min and max are int or float. + if not isinstance(normalize_max, (int, float)) or isinstance( + normalize_max, bool + ): + raise TypeError("The input normalize_max must be of the type int or float.") + if not isinstance(normalize_min, (int, float)) or isinstance( + normalize_min, bool + ): + raise TypeError("The input normalize_min must be of the type int or float.") + self.min = normalize_min + self.max = normalize_max + + def __repr__(self): + """MinMaxNormalization transform representation.""" + return ( + f"{__class__.__name__}(normalize_min={self.min}, normalize_max={self.max})" + ) + + def __call__(self, x: Tensor) -> Tensor: + """Normalizes data to a range [min, max].""" + return self.min + ( + (x - torch.min(x)) * (self.max - self.min) / (torch.max(x) - torch.min(x)) + ) diff --git a/tests/transforms/__init__.py b/tests/transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py new file mode 100644 index 0000000..8bd4750 --- /dev/null +++ b/tests/transforms/test_transforms.py @@ -0,0 +1,70 @@ +# (C) Copyright SaashaJoshi 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Unit test for transforms""" + +from __future__ import annotations +import numpy as np +import pytest +import torch +from pytest import raises +from piqture.transforms.transforms import MinMaxNormalization + + +class TestMinMaxNormalization: + """Test class for MinMaxNormalization transform.""" + + @pytest.mark.parametrize( + "normalize_min, normalize_max", [(0, 1), (-np.pi, np.pi), (0, np.pi / 2)] + ) + def test_repr(self, normalize_min, normalize_max): + """Tests MinMaxNormalization class representation.""" + result = f"MinMaxNormalization(normalize_min={normalize_min}, normalize_max={normalize_max})" + assert result == repr(MinMaxNormalization(normalize_min, normalize_max)) + + @pytest.mark.parametrize( + "normalize_min, normalize_max", + [(None, None), ({}, []), ("12", "abc"), (True, False)], + ) + def test_min(self, normalize_min, normalize_max): + """Tests the normalize_min inputs""" + with raises( + TypeError, match="The input normalize_max must be of the type int or float." + ): + _ = MinMaxNormalization(1, normalize_max) + + with raises( + TypeError, match="The input normalize_min must be of the type int or float." + ): + _ = MinMaxNormalization(normalize_min, 2.3) + + @pytest.mark.parametrize( + "normalize_min, normalize_max, x, output", + [ + (0, 1, torch.Tensor([1, 2, 3, 4]), torch.Tensor([0, 0.3333, 0.6667, 1])), + ( + -np.pi, + np.pi, + torch.Tensor([251, 252, 253, 254]), + torch.Tensor([-np.pi, -np.pi / 3, np.pi / 3, np.pi]), + ), + ( + 0, + np.pi / 2, + torch.Tensor([1.8, 2.1, 3.2, 4.5]), + torch.Tensor([0, np.pi / 18, (7 * np.pi) / 27, np.pi / 2]), + ), + ], + ) + def test_minmax_transform(self, normalize_min, normalize_max, x, output): + """Tests the transform output.""" + transform = MinMaxNormalization(normalize_min, normalize_max) + result = transform(x) + assert torch.allclose(result, output, atol=1e-5, rtol=1e-4) From 24dcf3972f369aac86ecefce21e31493008272ca Mon Sep 17 00:00:00 2001 From: SaashaJoshi Date: Wed, 5 Jun 2024 02:25:19 -0700 Subject: [PATCH 2/2] linting --- tests/transforms/test_transforms.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 8bd4750..e330bb7 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -26,7 +26,10 @@ class TestMinMaxNormalization: ) def test_repr(self, normalize_min, normalize_max): """Tests MinMaxNormalization class representation.""" - result = f"MinMaxNormalization(normalize_min={normalize_min}, normalize_max={normalize_max})" + result = ( + f"MinMaxNormalization(normalize_min={normalize_min}, " + f"normalize_max={normalize_max})" + ) assert result == repr(MinMaxNormalization(normalize_min, normalize_max)) @pytest.mark.parametrize(