Skip to content

Commit

Permalink
Fix LabelSampler and add test (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
GFabien authored Sep 21, 2020
1 parent 040be1c commit 23b5e56
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
26 changes: 25 additions & 1 deletion tests/data/sampler/test_label_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_label_probabilities(self):
fixture = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
assert torch.all(probabilities.squeeze().eq(fixture))

def test_incosistent_shape(self):
def test_inconsistent_shape(self):
# https://github.com/fepegar/torchio/issues/234#issuecomment-675029767
sample = torchio.Subject(
im1=torchio.ScalarImage(tensor=torch.rand(2, 4, 5, 6)),
Expand All @@ -34,3 +34,27 @@ def test_incosistent_shape(self):
patch_size = 2
sampler = LabelSampler(patch_size, 'im2')
next(sampler(sample))

def test_multichannel_label_sampler(self):
sample = torchio.Subject(
label=torchio.LabelMap(
tensor=torch.tensor(
[
[[[1, 1]]],
[[[0, 1]]]
]
)
)
)
patch_size = 1
sampler = LabelSampler(
patch_size,
'label',
label_probabilities={0: 1, 1: 1}
)
# There are 2 voxels in the image, channels have same probabilities,
# 1st voxel has probability 0.5 * 0.5 + 0 * 0.5 of being chosen while
# 2nd voxel has probability 0.5 * 0.5 + 1 * 0.5 of being chosen.
probabilities = sampler.get_probability_map(sample)
fixture = torch.Tensor((1 / 4, 3 / 4))
assert torch.all(probabilities.squeeze().eq(fixture))
1 change: 0 additions & 1 deletion torchio/data/sampler/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def get_probabilities_from_label_map(
if not label_size:
continue
prob_voxels = label_probability / label_size
probability_map[mask] = prob_voxels
if multichannel:
probability_map[label] = prob_voxels * mask
else:
Expand Down

0 comments on commit 23b5e56

Please sign in to comment.