-
Notifications
You must be signed in to change notification settings - Fork 240
/
Copy pathrandom_noise.py
32 lines (25 loc) · 948 Bytes
/
random_noise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from ..torchio import INTENSITY
from ..utils import is_image_dict
from .random_transform import RandomTransform
class RandomNoise(RandomTransform):
    """Add Gaussian noise whose standard deviation is sampled per call.

    On each call a standard deviation is drawn uniformly from ``std_range``.
    Zero-mean Gaussian noise with that deviation is then added, in place, to
    the ``'data'`` of every image entry whose ``'type'`` is ``INTENSITY``.
    Other entries are left untouched.

    Args:
        std_range: ``(min, max)`` range from which the noise standard
            deviation is sampled uniformly. Requires ``0 <= min <= max``.
        seed: Forwarded unchanged to ``RandomTransform``.
        verbose: Forwarded unchanged to ``RandomTransform``.

    Raises:
        ValueError: If ``std_range`` is not a ``(min, max)`` pair with
            ``0 <= min <= max``.
    """

    def __init__(self, std_range=(0, 0.25), seed=None, verbose=False):
        # Validate eagerly. A negative lower bound would otherwise fail only
        # intermittently, whenever a negative std happens to be sampled and
        # ``normal_`` rejects it in the middle of the data pipeline.
        self._check_std_range(std_range)
        super().__init__(seed=seed, verbose=verbose)
        self.std_range = std_range

    @staticmethod
    def _check_std_range(std_range):
        """Raise ``ValueError`` unless ``std_range`` is a pair with 0 <= min <= max."""
        try:
            low, high = std_range
        except (TypeError, ValueError) as err:
            raise ValueError(
                f'std_range must be a (min, max) pair, got {std_range!r}'
            ) from err
        if low < 0 or low > high:
            raise ValueError(
                f'std_range must satisfy 0 <= min <= max, got {std_range!r}'
            )

    def apply_transform(self, sample):
        """Add noise to the intensity images of ``sample``, in place.

        Args:
            sample: Mapping whose values may be image dicts (as recognised by
                ``is_image_dict``) with ``'type'`` and ``'data'`` keys.

        Returns:
            The same ``sample`` object. The intensity images have been
            modified in place, and the sampled standard deviation has been
            stored under the ``'random_noise'`` key.
        """
        std = self.get_params(self.std_range)
        # The sampled float is stored in ``sample`` itself, so the loop below
        # also visits it; presumably ``is_image_dict`` rejects non-dict
        # values (TODO confirm).
        sample['random_noise'] = std
        for image_dict in sample.values():
            if not is_image_dict(image_dict):
                continue
            if image_dict['type'] != INTENSITY:
                continue
            add_noise(image_dict['data'], std)
        return sample

    @staticmethod
    def get_params(std_range):
        """Return a standard deviation drawn uniformly from ``std_range``."""
        std = torch.FloatTensor(1).uniform_(*std_range).item()
        return std
def add_noise(data, std):
    """Add zero-mean Gaussian noise to ``data`` in place.

    Args:
        data: Floating-point tensor to perturb. It is modified in place.
        std: Non-negative standard deviation of the noise.

    Returns:
        None. The result is written into ``data``.
    """
    # Allocate the noise with the same shape, dtype and device as ``data``.
    # ``torch.FloatTensor(*data.shape)`` always produced a float32 CPU tensor,
    # which fails the in-place add for CUDA inputs. It also fails for 0-dim
    # tensors, because ``FloatTensor()`` with no sizes has shape (0,).
    noise = torch.empty_like(data).normal_(mean=0, std=std)
    data += noise