Skip to content

Commit

Permalink
Added unit tests for dataset_utils.py and gratings_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ArashAkbarinia committed Dec 15, 2023
1 parent 4b67dec commit ffea063
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 1 deletion.
2 changes: 2 additions & 0 deletions osculari/datasets/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def background_img(bg_type: Any, bg_size: Union[int, Tuple], im2double=True) ->
num_colours = np.random.randint(3, 25)
num_patches = np.random.randint(2, bg_size[0] // 20)
bg_img = _patch_img(bg_size, num_colours, num_patches, channels)
if 'achromatic' in bg_type:
bg_img = np.repeat(bg_img, 3, axis=2)
else:
raise RuntimeError('Unsupported background type %s.' % bg_type)
# Handle user-specified background values
Expand Down
2 changes: 1 addition & 1 deletion osculari/datasets/gratings.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class GratingsDataset(TorchDataset):
"""

def __init__(self, img_size: int, spatial_frequencies: Optional[Sequence[int]] = None,
thetas: Optional[Sequence[int]] = None, gaussian_sigma: Optional[float] = None,
thetas: Optional[Sequence[float]] = None, gaussian_sigma: Optional[float] = None,
transform: Optional[Callable] = None) -> None:
super(GratingsDataset, self).__init__()
self.img_size = img_size
Expand Down
84 changes: 84 additions & 0 deletions tests/datasets/dataset_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
Unit tests for dataset_utils.py
"""

import pytest
import numpy as np

from osculari.datasets import dataset_utils


def test_background_uniform_achromatic():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('uniform_achromatic', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.allclose(bg_img, bg_img[0, 0, :])


def test_background_uniform_colour():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('uniform_colour', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.allclose(bg_img, bg_img[0, 0, :])


def test_background_random_achromatic():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('random_achromatic', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.unique(bg_img).shape[0] > 1


def test_background_random_achromatic_pixelwise():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('random_achromatic', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.unique(bg_img).shape[0] > 1
assert np.all(np.equal(bg_img[..., 0], bg_img[..., 1]))


def test_background_random_colour():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('random_colour', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.unique(bg_img).shape[0] > 1


def test_background_patch_achromatic():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('patch_achromatic', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.unique(bg_img).shape[0] > 1


def test_background_patch_achromatic_pixelwise():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('patch_achromatic', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.unique(bg_img).shape[0] > 1
assert np.all(np.equal(bg_img[..., 0], bg_img[..., 1]))


def test_background_patch_colour():
bg_size = (256, 256)
bg_img = dataset_utils.background_img('patch_colour', bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.unique(bg_img).shape[0] > 1


def test_background_uniform_value():
bg_size = (256, 256)
bg_value = 0.5
bg_img = dataset_utils.background_img(bg_value, bg_size)
assert bg_img.shape == (*bg_size, 3)
assert np.allclose(bg_img, bg_value)


def test_background_invalid_type():
with pytest.raises(RuntimeError, match='Unsupported background type'):
_ = dataset_utils.background_img('invalid_type', (256, 256))


def test_background_invalid_value_type():
with pytest.raises(RuntimeError, match='Unsupported background type'):
_ = dataset_utils.background_img(None, (256, 256))
61 changes: 61 additions & 0 deletions tests/datasets/gratings_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Unit tests for gratings.py
"""

import pytest
import numpy as np
import torch
import torchvision.transforms as torch_transforms

from osculari.datasets import GratingsDataset


def test_gratings_dataset_len():
# Test the __len__ method of GratingsDataset
img_size = 64
dataset = GratingsDataset(img_size=img_size)
expected_length = len(dataset.thetas) * len(dataset.sfs)
assert len(dataset) == expected_length


def test_gratings_dataset_make_grating():
# Test the make_grating method of GratingsDataset
img_size = 64
dataset = GratingsDataset(img_size=img_size)
idx = 0
amplitude = 1.0
channels = 3
grating = dataset.make_grating(idx, amplitude, channels)
assert isinstance(grating, np.ndarray)
assert grating.shape == (img_size, img_size, channels)


def test_gratings_dataset_getitem():
# Test the __getitem__ method of GratingsDataset
img_size = 64
dataset = GratingsDataset(img_size=img_size)

# Test without transformation
idx = 0
grating = dataset[idx]
assert isinstance(grating, np.ndarray)
assert grating.shape == (img_size, img_size, 3)

# Test with transformation
transform = torch_transforms.Compose([torch_transforms.ToTensor()])
dataset.transform = transform
grating = dataset[idx]
assert isinstance(grating, torch.Tensor)
assert grating.shape == (3, img_size, img_size)


def test_gratings_dataset_with_gaussian():
# Test the make_grating method of GratingsDataset
img_size = 64
dataset = GratingsDataset(img_size=img_size, gaussian_sigma=0.5)
idx = 0
amplitude = 1.0
channels = 3
grating = dataset.make_grating(idx, amplitude, channels)
assert isinstance(grating, np.ndarray)
assert grating.shape == (img_size, img_size, channels)

0 comments on commit ffea063

Please sign in to comment.