diff --git a/osculari/datasets/dataset_utils.py b/osculari/datasets/dataset_utils.py index e4c1064..84ac82e 100644 --- a/osculari/datasets/dataset_utils.py +++ b/osculari/datasets/dataset_utils.py @@ -186,6 +186,8 @@ def background_img(bg_type: Any, bg_size: Union[int, Tuple], im2double=True) -> num_colours = np.random.randint(3, 25) num_patches = np.random.randint(2, bg_size[0] // 20) bg_img = _patch_img(bg_size, num_colours, num_patches, channels) + if 'achromatic' in bg_type: + bg_img = np.repeat(bg_img, 3, axis=2) else: raise RuntimeError('Unsupported background type %s.' % bg_type) # Handle user-specified background values diff --git a/osculari/datasets/gratings.py b/osculari/datasets/gratings.py index 8a103f9..1b4fa7d 100644 --- a/osculari/datasets/gratings.py +++ b/osculari/datasets/gratings.py @@ -91,7 +91,7 @@ class GratingsDataset(TorchDataset): """ def __init__(self, img_size: int, spatial_frequencies: Optional[Sequence[int]] = None, - thetas: Optional[Sequence[int]] = None, gaussian_sigma: Optional[float] = None, + thetas: Optional[Sequence[float]] = None, gaussian_sigma: Optional[float] = None, transform: Optional[Callable] = None) -> None: super(GratingsDataset, self).__init__() self.img_size = img_size diff --git a/tests/datasets/dataset_utils_test.py b/tests/datasets/dataset_utils_test.py new file mode 100644 index 0000000..4be016c --- /dev/null +++ b/tests/datasets/dataset_utils_test.py @@ -0,0 +1,84 @@ +""" +Unit tests for dataset_utils.py +""" + +import pytest +import numpy as np + +from osculari.datasets import dataset_utils + + +def test_background_uniform_achromatic(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('uniform_achromatic', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.allclose(bg_img, bg_img[0, 0, :]) + + +def test_background_uniform_colour(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('uniform_colour', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.allclose(bg_img, bg_img[0, 0, :]) + + +def test_background_random_achromatic(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('random_achromatic', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.unique(bg_img).shape[0] > 1 + + +def test_background_random_achromatic_pixelwise(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('random_achromatic', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.unique(bg_img).shape[0] > 1 + assert np.all(np.equal(bg_img[..., 0], bg_img[..., 1])) + + +def test_background_random_colour(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('random_colour', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.unique(bg_img).shape[0] > 1 + + +def test_background_patch_achromatic(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('patch_achromatic', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.unique(bg_img).shape[0] > 1 + + +def test_background_patch_achromatic_pixelwise(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('patch_achromatic', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.unique(bg_img).shape[0] > 1 + assert np.all(np.equal(bg_img[..., 0], bg_img[..., 1])) + + +def test_background_patch_colour(): + bg_size = (256, 256) + bg_img = dataset_utils.background_img('patch_colour', bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.unique(bg_img).shape[0] > 1 + + +def test_background_uniform_value(): + bg_size = (256, 256) + bg_value = 0.5 + bg_img = dataset_utils.background_img(bg_value, bg_size) + assert bg_img.shape == (*bg_size, 3) + assert np.allclose(bg_img, bg_value) + + +def test_background_invalid_type(): + with pytest.raises(RuntimeError, match='Unsupported background type'): + _ = dataset_utils.background_img('invalid_type', (256, 256)) + + +def test_background_invalid_value_type(): + with pytest.raises(RuntimeError, match='Unsupported background type'): + _ = dataset_utils.background_img(None, (256, 256)) diff --git a/tests/datasets/gratings_test.py b/tests/datasets/gratings_test.py new file mode 100644 index 0000000..bbd45dd --- /dev/null +++ b/tests/datasets/gratings_test.py @@ -0,0 +1,61 @@ +""" +Unit tests for gratings.py +""" + +import pytest +import numpy as np +import torch +import torchvision.transforms as torch_transforms + +from osculari.datasets import GratingsDataset + + +def test_gratings_dataset_len(): + # Test the __len__ method of GratingsDataset + img_size = 64 + dataset = GratingsDataset(img_size=img_size) + expected_length = len(dataset.thetas) * len(dataset.sfs) + assert len(dataset) == expected_length + + +def test_gratings_dataset_make_grating(): + # Test the make_grating method of GratingsDataset + img_size = 64 + dataset = GratingsDataset(img_size=img_size) + idx = 0 + amplitude = 1.0 + channels = 3 + grating = dataset.make_grating(idx, amplitude, channels) + assert isinstance(grating, np.ndarray) + assert grating.shape == (img_size, img_size, channels) + + +def test_gratings_dataset_getitem(): + # Test the __getitem__ method of GratingsDataset + img_size = 64 + dataset = GratingsDataset(img_size=img_size) + + # Test without transformation + idx = 0 + grating = dataset[idx] + assert isinstance(grating, np.ndarray) + assert grating.shape == (img_size, img_size, 3) + + # Test with transformation + transform = torch_transforms.Compose([torch_transforms.ToTensor()]) + dataset.transform = transform + grating = dataset[idx] + assert isinstance(grating, torch.Tensor) + assert grating.shape == (3, img_size, img_size) + + +def test_gratings_dataset_with_gaussian(): + # Test the make_grating method of GratingsDataset + img_size = 64 + dataset = GratingsDataset(img_size=img_size, gaussian_sigma=0.5) + idx = 0 + amplitude = 1.0 + channels = 3 + grating = dataset.make_grating(idx, amplitude, channels) + assert isinstance(grating, np.ndarray) + assert grating.shape == (img_size, img_size, channels)