-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added unit tests for dataset_utils.py and gratings_test.py
- Loading branch information
1 parent
4b67dec
commit ffea063
Showing
4 changed files
with
148 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
""" | ||
Unit tests for dataset_utils.py | ||
""" | ||
|
||
import pytest | ||
import numpy as np | ||
|
||
from osculari.datasets import dataset_utils | ||
|
||
|
||
def test_background_uniform_achromatic(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('uniform_achromatic', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.allclose(bg_img, bg_img[0, 0, :]) | ||
|
||
|
||
def test_background_uniform_colour(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('uniform_colour', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.allclose(bg_img, bg_img[0, 0, :]) | ||
|
||
|
||
def test_background_random_achromatic(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('random_achromatic', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.unique(bg_img).shape[0] > 1 | ||
|
||
|
||
def test_background_random_achromatic_pixelwise(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('random_achromatic', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.unique(bg_img).shape[0] > 1 | ||
assert np.all(np.equal(bg_img[..., 0], bg_img[..., 1])) | ||
|
||
|
||
def test_background_random_colour(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('random_colour', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.unique(bg_img).shape[0] > 1 | ||
|
||
|
||
def test_background_patch_achromatic(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('patch_achromatic', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.unique(bg_img).shape[0] > 1 | ||
|
||
|
||
def test_background_patch_achromatic_pixelwise(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('patch_achromatic', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.unique(bg_img).shape[0] > 1 | ||
assert np.all(np.equal(bg_img[..., 0], bg_img[..., 1])) | ||
|
||
|
||
def test_background_patch_colour(): | ||
bg_size = (256, 256) | ||
bg_img = dataset_utils.background_img('patch_colour', bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.unique(bg_img).shape[0] > 1 | ||
|
||
|
||
def test_background_uniform_value(): | ||
bg_size = (256, 256) | ||
bg_value = 0.5 | ||
bg_img = dataset_utils.background_img(bg_value, bg_size) | ||
assert bg_img.shape == (*bg_size, 3) | ||
assert np.allclose(bg_img, bg_value) | ||
|
||
|
||
def test_background_invalid_type(): | ||
with pytest.raises(RuntimeError, match='Unsupported background type'): | ||
_ = dataset_utils.background_img('invalid_type', (256, 256)) | ||
|
||
|
||
def test_background_invalid_value_type(): | ||
with pytest.raises(RuntimeError, match='Unsupported background type'): | ||
_ = dataset_utils.background_img(None, (256, 256)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
Unit tests for gratings.py | ||
""" | ||
|
||
import pytest | ||
import numpy as np | ||
import torch | ||
import torchvision.transforms as torch_transforms | ||
|
||
from osculari.datasets import GratingsDataset | ||
|
||
|
||
def test_gratings_dataset_len(): | ||
# Test the __len__ method of GratingsDataset | ||
img_size = 64 | ||
dataset = GratingsDataset(img_size=img_size) | ||
expected_length = len(dataset.thetas) * len(dataset.sfs) | ||
assert len(dataset) == expected_length | ||
|
||
|
||
def test_gratings_dataset_make_grating(): | ||
# Test the make_grating method of GratingsDataset | ||
img_size = 64 | ||
dataset = GratingsDataset(img_size=img_size) | ||
idx = 0 | ||
amplitude = 1.0 | ||
channels = 3 | ||
grating = dataset.make_grating(idx, amplitude, channels) | ||
assert isinstance(grating, np.ndarray) | ||
assert grating.shape == (img_size, img_size, channels) | ||
|
||
|
||
def test_gratings_dataset_getitem(): | ||
# Test the __getitem__ method of GratingsDataset | ||
img_size = 64 | ||
dataset = GratingsDataset(img_size=img_size) | ||
|
||
# Test without transformation | ||
idx = 0 | ||
grating = dataset[idx] | ||
assert isinstance(grating, np.ndarray) | ||
assert grating.shape == (img_size, img_size, 3) | ||
|
||
# Test with transformation | ||
transform = torch_transforms.Compose([torch_transforms.ToTensor()]) | ||
dataset.transform = transform | ||
grating = dataset[idx] | ||
assert isinstance(grating, torch.Tensor) | ||
assert grating.shape == (3, img_size, img_size) | ||
|
||
|
||
def test_gratings_dataset_with_gaussian(): | ||
# Test the make_grating method of GratingsDataset | ||
img_size = 64 | ||
dataset = GratingsDataset(img_size=img_size, gaussian_sigma=0.5) | ||
idx = 0 | ||
amplitude = 1.0 | ||
channels = 3 | ||
grating = dataset.make_grating(idx, amplitude, channels) | ||
assert isinstance(grating, np.ndarray) | ||
assert grating.shape == (img_size, img_size, channels) |