-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtorch_rbf.py
114 lines (91 loc) · 3.42 KB
/
torch_rbf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torch.nn as nn
# RBF Layer
class RBF(nn.Module):
    """
    Transforms incoming data using a given radial basis function:
    u_{i} = rbf(||x - c_{i}|| / s_{i})

    Arguments:
        in_features: size of each input sample
        out_features: size of each output sample
        basis_func: callable applied to the tensor of scaled distances,
            which has shape (*, out_features).

    Shape:
        - Input: (*, in_features) where * is any number of leading batch
          dimensions (typically (N, in_features); an unbatched
          (in_features,) input is also accepted)
        - Output: (*, out_features), provided basis_func preserves shape
          (all the basis functions in this module are elementwise)

    Attributes:
        centres: the learnable centres of shape (out_features, in_features).
            The values are initialised from a standard normal distribution.
            Normalising inputs to have mean 0 and standard deviation 1 is
            recommended.
        log_sigmas: logarithm of the learnable scaling factors of shape
            (out_features). Initialised to 0, i.e. every s_i starts at 1.
        basis_func: the radial basis function used to transform the scaled
            distances.
    """

    def __init__(self, in_features, out_features, basis_func):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty replaces the legacy torch.Tensor(...) constructor;
        # the values are filled in by reset_parameters() below.
        self.centres = nn.Parameter(torch.empty(out_features, in_features))
        self.log_sigmas = nn.Parameter(torch.empty(out_features))
        self.basis_func = basis_func
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise centres ~ N(0, 1) and log_sigmas = 0 (unit scales)."""
        nn.init.normal_(self.centres, 0, 1)
        nn.init.constant_(self.log_sigmas, 0)

    def forward(self, input):
        """Return basis_func(||input - c_i|| / exp(log_sigma_i)) for every centre."""
        # (*, 1, in_features) - (out_features, in_features) broadcasts to
        # (*, out_features, in_features) for any number of leading dims.
        diff = input.unsqueeze(-2) - self.centres
        # vector_norm uses a zero subgradient where ||x - c|| == 0; the
        # previous sum(...).pow(0.5) produced NaN gradients whenever an
        # input coincided exactly with a centre.
        distances = torch.linalg.vector_norm(diff, dim=-1)
        # Scales are parameterised as exp(log_sigma) so they stay positive.
        return self.basis_func(distances / torch.exp(self.log_sigmas))
# RBFs
def gaussian(alpha):
    """Gaussian kernel: phi(r) = exp(-r^2)."""
    return torch.exp(-alpha.pow(2))
def linear(alpha):
    """Linear kernel: phi(r) = r (the input tensor is returned as-is)."""
    return alpha
def quadratic(alpha):
    """Quadratic kernel: phi(r) = r^2."""
    return alpha.pow(2)
def inverse_quadratic(alpha):
    """Inverse quadratic kernel: phi(r) = 1 / (1 + r^2)."""
    one = torch.ones_like(alpha)
    denominator = one + alpha.pow(2)
    return one / denominator
def multiquadric(alpha):
    """Multiquadric kernel: phi(r) = sqrt(1 + r^2)."""
    radicand = torch.ones_like(alpha) + alpha.pow(2)
    return radicand.pow(0.5)
def inverse_multiquadric(alpha):
    """Inverse multiquadric kernel: phi(r) = 1 / sqrt(1 + r^2)."""
    one = torch.ones_like(alpha)
    root = (one + alpha.pow(2)).pow(0.5)
    return one / root
def spline(alpha):
    """Spline kernel: phi(r) = r^2 * log(r + 1).

    The +1 inside the logarithm keeps the value finite at r = 0.
    """
    shifted = alpha + torch.ones_like(alpha)
    return alpha.pow(2) * torch.log(shifted)
def poisson_one(alpha):
    """First Poisson kernel: phi(r) = (r - 1) * exp(-r)."""
    decay = torch.exp(-alpha)
    return (alpha - torch.ones_like(alpha)) * decay
def poisson_two(alpha):
    """Second Poisson kernel: phi(r) = ((r - 2) / 2) * r * exp(-r)."""
    half_shifted = (alpha - 2) / 2
    return half_shifted * alpha * torch.exp(-alpha)
def matern32(alpha):
    """Matern 3/2 kernel: phi(r) = (1 + sqrt(3) r) * exp(-sqrt(3) r)."""
    sqrt3 = 3 ** 0.5
    polynomial = 1 + sqrt3 * alpha
    return polynomial * torch.exp(-sqrt3 * alpha)
def matern52(alpha):
    """Matern 5/2 kernel: phi(r) = (1 + sqrt(5) r + (5/3) r^2) * exp(-sqrt(5) r)."""
    sqrt5 = 5 ** 0.5
    polynomial = 1 + sqrt5 * alpha + (5 / 3) * alpha.pow(2)
    return polynomial * torch.exp(-sqrt5 * alpha)
def basis_func_dict():
    """
    Return a new dictionary mapping each RBF's name to its function.

    Multi-word names use spaces (e.g. 'inverse quadratic'). A fresh dict is
    built on every call.
    """
    named_bases = [
        ('gaussian', gaussian),
        ('linear', linear),
        ('quadratic', quadratic),
        ('inverse quadratic', inverse_quadratic),
        ('multiquadric', multiquadric),
        ('inverse multiquadric', inverse_multiquadric),
        ('spline', spline),
        ('poisson one', poisson_one),
        ('poisson two', poisson_two),
        ('matern32', matern32),
        ('matern52', matern52),
    ]
    return dict(named_bases)