diff --git a/bids/analysis/tests/test_transformations.py b/bids/analysis/tests/test_transformations.py
index 892721a9e..3179cb3a5 100644
--- a/bids/analysis/tests/test_transformations.py
+++ b/bids/analysis/tests/test_transformations.py
@@ -4,6 +4,7 @@
 from bids.variables.entities import RunInfo
 from bids.variables.kollekshuns import BIDSRunVariableCollection
 from bids.layout import BIDSLayout
+import math
 import pytest
 from os.path import join, sep
 from bids.tests import get_test_data_path
@@ -16,14 +17,30 @@
 import mock
 
 
+# Sub-select collection for faster testing, without sacrificing anything
+# in the tests
+SUBJECTS = ['01', '02']
+NRUNS = 3
+SCAN_LENGTH = 480
+
+cached_collections = {}
+
+
 @pytest.fixture
 def collection():
-    layout_path = join(get_test_data_path(), 'ds005')
-    layout = BIDSLayout(layout_path)
-    collection = layout.get_collections('run', types=['events'],
-                                        scan_length=480, merge=True,
-                                        sampling_rate=10)
-    return collection
+    if 'ds005' not in cached_collections:
+        layout_path = join(get_test_data_path(), 'ds005')
+        layout = BIDSLayout(layout_path)
+        cached_collections['ds005'] = layout.get_collections(
+            'run',
+            types=['events'],
+            scan_length=SCAN_LENGTH,
+            merge=True,
+            sampling_rate=10,
+            subject=SUBJECTS
+        )
+    # Always return a clone!
+    return cached_collections['ds005'].clone()
 
 
 @pytest.fixture
@@ -77,8 +94,8 @@ def test_convolve(collection):
 
 
 def test_rename(collection):
-    dense_rt = collection.variables['RT'].to_dense(10)
-    assert len(dense_rt.values) == 230400
+    dense_rt = collection.variables['RT'].to_dense(collection.sampling_rate)
+    assert len(dense_rt.values) == math.ceil(len(SUBJECTS) * NRUNS * SCAN_LENGTH * collection.sampling_rate)
     transform.Rename(collection, 'RT', output='reaction_time')
     assert 'reaction_time' in collection.variables
     assert 'RT' not in collection.variables
@@ -132,9 +149,10 @@ def test_demean(collection):
 
 def test_orthogonalize_dense(collection):
     transform.Factor(collection, 'trial_type', sep=sep)
+    sampling_rate = collection.sampling_rate
     # Store pre-orth variables needed for tests
-    pg_pre = collection['trial_type/parametric gain'].to_dense(10)
-    rt = collection['RT'].to_dense(10)
+    pg_pre = collection['trial_type/parametric gain'].to_dense(sampling_rate)
+    rt = collection['RT'].to_dense(sampling_rate)
 
     # Orthogonalize and store result
     transform.Orthogonalize(collection, variables='trial_type/parametric gain',
@@ -174,7 +192,8 @@ def test_split(collection):
 
     orig = collection['RT'].clone(name='RT_2')
     collection['RT_2'] = orig.clone()
-    collection['RT_3'] = collection['RT'].clone(name='RT_3').to_dense(10)
+    collection['RT_3'] = collection['RT']\
+        .clone(name='RT_3').to_dense(collection.sampling_rate)
 
     rt_pre_onsets = collection['RT'].onset
 
@@ -208,14 +227,22 @@ def test_split(collection):
 
 
 def test_resample_dense(collection):
-    collection['RT'] = collection['RT'].to_dense(10)
+    new_sampling_rate = 50
+    old_sampling_rate = collection.sampling_rate
+    upsampling = float(new_sampling_rate) / old_sampling_rate
+
+    collection['RT'] = collection['RT'].to_dense(old_sampling_rate)
     old_rt = collection['RT'].clone()
-    collection.resample(50, in_place=True)
-    assert len(old_rt.values) * 5 == len(collection['RT'].values)
+    collection.resample(new_sampling_rate, in_place=True)
+    assert math.floor(len(old_rt.values) * upsampling) == len(collection['RT'].values)
     # Should work after explicitly converting categoricals
     transform.Factor(collection, 'trial_type')
-    collection.resample(5, force_dense=True, in_place=True)
-    assert len(old_rt.values) == len(collection['parametric gain'].values) * 2
+
+    new_sampling_rate2 = 5
+    upsampling2 = float(new_sampling_rate2) / old_sampling_rate
+
+    collection.resample(new_sampling_rate2, force_dense=True, in_place=True)
+    assert len(old_rt.values) == math.ceil(float(len(collection['parametric gain'].values) / upsampling2))
 
 
 def test_threshold(collection):
@@ -346,12 +373,13 @@ def test_filter(collection):
     transform.Filter(collection, 'RT', query=q, by='parametric gain')
     assert len(orig.values) != len(collection['RT'].values)
     # There is some bizarro thing going on where, on travis, the result is
-    # randomly either 1536 or 3909 when running on Python 3 (on linux or mac).
+    # randomly either 1536 (when all subjects are used) or 3909 when running
+    # on Python 3 (on linux or mac).
     # Never happens locally, and I've had no luck tracking down the problem.
     # Best guess is it reflects either some non-deterministic ordering of
     # variables somewhere, or some weird precision issues when resampling to
     # dense. Needs to be tracked down and fixed.
-    assert len(collection['RT'].values) in [1536, 3909]
+    assert len(collection['RT'].values) in [96 * len(SUBJECTS), 3909]
 
 
 def test_replace(collection):